diff --git a/ast.go b/ast.go index 304eb80..f3c4fad 100644 --- a/ast.go +++ b/ast.go @@ -54,9 +54,11 @@ type SelectItem struct { } type SelectStatement struct { - Item *[]*SelectItem - From *Token - Where *Expression + Item *[]*SelectItem + From *Token + Where *Expression + Limit *Expression + Offset *Expression } func (ss SelectStatement) GenerateCode() string { @@ -73,17 +75,24 @@ func (ss SelectStatement) GenerateCode() string { item = append(item, s) } - from := "" + code := "SELECT\n" + strings.Join(item, ",\n") if ss.From != nil { - from = fmt.Sprintf("\nFROM\n\t\"%s\"", ss.From.Value) + code += fmt.Sprintf("\nFROM\n\t\"%s\"", ss.From.Value) } - where := "" if ss.Where != nil { - where = fmt.Sprintf("\nWHERE\n\t%s", ss.Where.GenerateCode()) + code += "\nWHERE\n\t" + ss.Where.GenerateCode() } - return fmt.Sprintf("SELECT\n%s%s%s;", strings.Join(item, ",\n"), from, where) + if ss.Limit != nil { + code += "\nLIMIT\n\t" + ss.Limit.GenerateCode() + } + + if ss.Offset != nil { + code += "\nOFFSET\n\t" + ss.Limit.GenerateCode() + } + + return code + ";" } type ColumnDefinition struct { diff --git a/lexer.go b/lexer.go index 4aa2fa3..cfe8bd2 100644 --- a/lexer.go +++ b/lexer.go @@ -37,6 +37,8 @@ const ( OnKeyword Keyword = "on" PrimarykeyKeyword Keyword = "primary key" NullKeyword Keyword = "null" + LimitKeyword Keyword = "limit" + OffsetKeyword Keyword = "offset" ) // for storing SQL syntax @@ -261,6 +263,8 @@ func lexKeyword(source string, ic cursor) (*Token, cursor, bool) { OnKeyword, PrimarykeyKeyword, NullKeyword, + LimitKeyword, + OffsetKeyword, } var options []string diff --git a/memory.go b/memory.go index 8f215bc..70cc5d4 100644 --- a/memory.go +++ b/memory.go @@ -587,6 +587,33 @@ func (mb *MemoryBackend) Select(slct *SelectStatement) (*Results, error) { } } + limit := len(t.rows) + if slct.Limit != nil { + v, _, _, err := t.evaluateCell(0, *slct.Limit) + if err != nil { + return nil, err + } + + limit = int(*v.AsInt()) + } + if limit < 0 { + return nil, fmt.Errorf("Invalid, negative limit") + } + + offset := 0 + if slct.Offset != nil { + v, _, _, err := t.evaluateCell(0, *slct.Offset) + if err != nil { + return nil, err + } + + offset = int(*v.AsInt()) + } + if offset < 0 { + return nil, fmt.Errorf("Invalid, negative limit") + } + + rowIndex := -1 for i := range t.rows { result := []Cell{} isFirstRow := len(results) == 0 @@ -602,6 +629,13 @@ func (mb *MemoryBackend) Select(slct *SelectStatement) (*Results, error) { } } + rowIndex++ + if rowIndex < offset { + continue + } else if rowIndex > offset+limit-1 { + break + } + for _, col := range finalItems { value, columnName, columnType, err := t.evaluateCell(uint(i), *col.Exp) if err != nil { diff --git a/parser.go b/parser.go index 207788a..948865a 100644 --- a/parser.go +++ b/parser.go @@ -273,9 +273,12 @@ func (p Parser) parseSelectStatement(tokens []*Token, initialCursor uint, delimi cursor = newCursor } + limitToken := tokenFromKeyword(LimitKeyword) + offsetToken := tokenFromKeyword(OffsetKeyword) + _, cursor, ok = p.parseToken(tokens, cursor, whereToken) if ok { - where, newCursor, ok := p.parseExpression(tokens, cursor, []Token{delimiter}, 0) + where, newCursor, ok := p.parseExpression(tokens, cursor, []Token{limitToken, offsetToken, delimiter}, 0) if !ok { p.helpMessage(tokens, cursor, "Expected WHERE conditionals") return nil, initialCursor, false @@ -285,6 +288,30 @@ func (p Parser) parseSelectStatement(tokens []*Token, initialCursor uint, delimi cursor = newCursor } + _, cursor, ok = p.parseToken(tokens, cursor, limitToken) + if ok { + limit, newCursor, ok := p.parseExpression(tokens, cursor, []Token{offsetToken, delimiter}, 0) + if !ok { + p.helpMessage(tokens, cursor, "Expected LIMIT value") + return nil, initialCursor, false + } + + slct.Limit = limit + cursor = newCursor + } + + _, cursor, ok = p.parseToken(tokens, cursor, offsetToken) + if ok { + offset, newCursor, ok := p.parseExpression(tokens, cursor, []Token{delimiter}, 0) + if !ok { + p.helpMessage(tokens, cursor, "Expected OFFSET value") + return nil, initialCursor, false + } + + slct.Offset = offset + cursor = newCursor + } + return &slct, cursor, true } diff --git a/parser_test.go b/parser_test.go index d9d44cf..e218c2c 100644 --- a/parser_test.go +++ b/parser_test.go @@ -193,7 +193,7 @@ func TestParse(t *testing.T) { }, }, { - source: `SELECT id, name AS fullname FROM "sketchy name"`, + source: `SELECT id, name AS fullname FROM "sketchy name" LIMIT 10 OFFSET 12`, ast: &Ast{ Statements: []*Statement{ { @@ -231,6 +231,22 @@ func TestParse(t *testing.T) { Kind: IdentifierKind, Value: "sketchy name", }, + Limit: &Expression{ + Kind: LiteralKind, + Literal: &Token{ + Loc: Location{Col: 54, Line: 0}, + Kind: NumericKind, + Value: "10", + }, + }, + Offset: &Expression{ + Kind: LiteralKind, + Literal: &Token{ + Loc: Location{Col: 65, Line: 0}, + Kind: NumericKind, + Value: "12", + }, + }, }, }, }, diff --git a/repl.go b/repl.go index 1667029..59dec9c 100644 --- a/repl.go +++ b/repl.go @@ -10,7 +10,6 @@ import ( "github.com/olekukonko/tablewriter" ) - func doSelect(mb Backend, slct *SelectStatement) error { results, err := mb.Select(slct) if err != nil { @@ -236,31 +235,31 @@ repl: case CreateIndexKind: err = b.CreateIndex(ast.Statements[0].CreateIndexStatement) if err != nil { - fmt.Println("Error adding index on table", err) + fmt.Println("Error adding index on table:", err) continue repl } case CreateTableKind: err = b.CreateTable(ast.Statements[0].CreateTableStatement) if err != nil { - fmt.Println("Error creating table", err) + fmt.Println("Error creating table:", err) continue repl } case DropTableKind: err = b.DropTable(ast.Statements[0].DropTableStatement) if err != nil { - fmt.Println("Error dropping table", err) + fmt.Println("Error dropping table:", err) continue repl } case InsertKind: err = b.Insert(stmt.InsertStatement) if err != nil { - fmt.Println("Error inserting values", err) + fmt.Println("Error inserting values:", err) continue repl } case SelectKind: err := doSelect(b, stmt.SelectStatement) if err != nil { - fmt.Println("Error selecting values", err) + fmt.Println("Error selecting values:", err) continue repl } }