From d0aec474d95806e3d9c2281671fbbc5d04adc546 Mon Sep 17 00:00:00 2001
From: Phil Eaton
Date: Sat, 14 Mar 2020 19:49:08 -0400
Subject: [PATCH] Upgrade lexer (#2)

* Use cursor-style lexer with full control

* Some progress

* Fixes for tests

* Add tests for case preservation
---
 cmd/main.go    |   5 +-
 lexer.go       | 359 ++++++++++++++++++++++++++++++-------------------
 lexer_test.go  | 296 ++++++++++++++++++++++++++++++++++++++--
 parser.go      |   8 +-
 parser_test.go |  29 ++--
 5 files changed, 522 insertions(+), 175 deletions(-)

diff --git a/cmd/main.go b/cmd/main.go
index d59f028..554cc08 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -2,7 +2,6 @@ package main
 
 import (
 	"bufio"
-	"bytes"
 	"fmt"
 	"os"
 	"strings"
@@ -20,9 +19,7 @@ func main() {
 		text, err := reader.ReadString('\n')
 		text = strings.Replace(text, "\n", "", -1)
 
-		source := bytes.NewBufferString(text)
-
-		ast, err := gosql.Parse(source)
+		ast, err := gosql.Parse(text)
 		if err != nil {
 			panic(err)
 		}
diff --git a/lexer.go b/lexer.go
index 14ff353..74bceb1 100644
--- a/lexer.go
+++ b/lexer.go
@@ -2,7 +2,6 @@ package gosql
 
 import (
 	"fmt"
-	"io"
 	"strings"
 )
 
@@ -52,99 +51,163 @@ type token struct {
 	loc   location
 }
 
+type cursor struct {
+	pointer uint
+	loc     location
+}
+
 func (t *token) equals(other *token) bool {
 	return t.value == other.value && t.kind == other.kind
 }
 
-func (t *token) finalizeSymbol() bool {
-	switch t.value {
-	case "*":
-		break
-	case ";":
-		break
-	case "(":
-		break
-	case ")":
+type lexer func(string, cursor) (*token, cursor, bool)
+
+func lexSymbol(source string, ic cursor) (*token, cursor, bool) {
+	c := source[ic.pointer]
+	cur := ic
+	cur.loc.col++
+	cur.pointer++
+
+	switch c {
+	// Syntax that should be thrown away
+	case '\n':
+		cur.loc.line++
+		cur.loc.col = 0
+		fallthrough
+	case '\t':
+		fallthrough
+	case ' ':
+		return nil, cur, true
+
+	// Syntax that should be kept
+	case ',':
+		fallthrough
+	case '(':
+		fallthrough
+	case ')':
+		fallthrough
+	case ';':
+		fallthrough
+	case '*':
 		break
+
+	// Unknown character
 	default:
-		return false
+		return nil, ic, false
 	}
 
-	t.kind = symbolKind
-	return true
+	return &token{
+		value: string(c),
+		loc:   ic.loc,
+		kind:  symbolKind,
+	}, cur, true
 }
 
-func (t *token) finalizeKeyword() bool {
-	switch strings.ToLower(t.value) {
-	case "select":
-		break
-	case "from":
-		break
-	case "as":
-		break
-	case "table":
-		break
-	case "create":
-		break
-	case "insert":
-		break
-	case "into":
-		break
-	case "values":
-		break
-	case "int":
-		break
-	case "text":
-		break
-	default:
-		return false
+func lexKeyword(source string, ic cursor) (*token, cursor, bool) {
+	cur := ic
+	keywords := []keyword{
+		selectKeyword,
+		insertKeyword,
+		valuesKeyword,
+		tableKeyword,
+		createKeyword,
+		fromKeyword,
+		intoKeyword,
+		textKeyword,
+		intKeyword,
+		asKeyword,
 	}
 
-	t.value = strings.ToLower(t.value)
-	t.kind = keywordKind
-	return true
-}
+	var value []byte
+	var skipList []int
+	var match string
+
+	// Stop at the end of the source: a partial keyword prefix is not a match
+	for cur.pointer < uint(len(source)) {
+		value = append(value, source[cur.pointer])
+		cur.pointer++
+
+	keyword:
+		for i, keyword := range keywords {
+			for _, skip := range skipList {
+				if i == skip {
+					continue keyword
+				}
+			}
 
-func (t *token) finalizeNumeric() bool {
-	if len(t.value) == 0 {
-		return false
+			// Deal with cases like INT vs INTO
+			if string(keyword) == strings.ToLower(string(value)) {
+				skipList = append(skipList, i)
+				if len(keyword) > len(match) {
+					match = string(keyword)
+				}
+
+				continue
+			}
+
+			sharesPrefix := strings.ToLower(string(value)) == string(keyword)[:cur.pointer-ic.pointer]
+			tooLong := len(value) > len(keyword)
+			if tooLong || !sharesPrefix {
+				skipList = append(skipList, i)
+			}
+		}
+
+		if len(skipList) == len(keywords) {
+			break
+		}
 	}
 
+	if match == "" {
+		return nil, ic, false
+	}
+
+	// Set pointer and col exactly because of partial matches
+	// while iterating over keywords
+	cur.pointer = ic.pointer + uint(len(match))
+	cur.loc.col = cur.loc.col + uint(len(match))
+
+	return &token{
+		value: match,
+		kind:  keywordKind,
+		loc:   ic.loc,
+	}, cur, true
+}
+
+func lexNumeric(source string, ic cursor) (*token, cursor, bool) {
+	cur := ic
+
 	periodFound := false
 	expMarkerFound := false
 
-	i := 0
-	for i < len(t.value) {
-		c := t.value[i]
+	for ; cur.pointer < uint(len(source)); cur.pointer++ {
+		c := source[cur.pointer]
+		cur.loc.col++
 
 		isDigit := c >= '0' && c <= '9'
 		isPeriod := c == '.'
 		isExpMarker := c == 'e'
 
 		// Must start with a digit or period
-		if i == 0 {
+		if cur.pointer == ic.pointer {
 			if !isDigit && !isPeriod {
-				return false
+				return nil, ic, false
 			}
 
 			periodFound = isPeriod
-			i++
 			continue
 		}
 
 		if isPeriod {
 			if periodFound {
-				return false
+				return nil, ic, false
 			}
 
 			periodFound = true
-			i++
 			continue
 		}
 
 		if isExpMarker {
 			if expMarkerFound {
-				return false
+				return nil, ic, false
 			}
 
 			// No periods allowed after expMarker
@@ -152,134 +215,150 @@ func (t *token) finalizeNumeric() bool {
 			expMarkerFound = true
 
 			// expMarker must be followed by digits
-			if i == len(t.value)-1 {
-				return false
+			if cur.pointer == uint(len(source)-1) {
+				return nil, ic, false
 			}
 
-			cNext := t.value[i+1]
+			cNext := source[cur.pointer+1]
 			if cNext == '-' || cNext == '+' {
-				i++
+				cur.pointer++
+				cur.loc.col++
 			}
 
-			i++
 			continue
 		}
 
 		if !isDigit {
-			return false
+			break
 		}
+	}
 
-		i++
+	// No characters accumulated
+	if cur.pointer == ic.pointer {
+		return nil, ic, false
 	}
 
-	t.kind = numericKind
-	return true
+	return &token{
+		value: source[ic.pointer:cur.pointer],
+		loc:   ic.loc,
+		kind:  numericKind,
+	}, cur, true
 }
 
-func (t *token) finalizeIdentifier() bool {
-	t.kind = identifierKind
-	return true
-}
+func lexCharacterDelimited(source string, ic cursor, delimiter byte) (*token, cursor, bool) {
+	cur := ic
 
-func (t *token) finalizeString() bool {
-	if len(t.value) == 0 {
-		return false
+	if len(source[cur.pointer:]) == 0 {
+		return nil, ic, false
 	}
 
-	if t.value[0] == '\'' && t.value[len(t.value)-1] == '\'' {
-		t.kind = stringKind
-		t.value = t.value[1 : len(t.value)-1]
-		return true
+	if source[cur.pointer] != delimiter {
+		return nil, ic, false
 	}
 
-	return false
-}
+	cur.loc.col++
+	cur.pointer++
+
+	var value []byte
+	for ; cur.pointer < uint(len(source)); cur.pointer++ {
+		c := source[cur.pointer]
+
+		if c == delimiter {
+			// SQL escapes are via double characters, not backslash.
+			if cur.pointer+1 >= uint(len(source)) || source[cur.pointer+1] != delimiter {
+				return &token{
+					value: string(value),
+					loc:   ic.loc,
+					kind:  stringKind,
+				}, cur, true
+			} else {
+				value = append(value, delimiter)
+				cur.pointer++
+				cur.loc.col++
+			}
+		}
 
-func (t *token) finalize() bool {
-	if t.finalizeSymbol() {
-		return true
+		value = append(value, c)
+		cur.loc.col++
 	}
 
-	if t.finalizeKeyword() {
-		return true
+	return nil, ic, false
+}
+
+func lexIdentifier(source string, ic cursor) (*token, cursor, bool) {
+	// Handle separately if it is a double-quoted identifier
+	if token, newCursor, ok := lexCharacterDelimited(source, ic, '"'); ok {
+		return token, newCursor, true
 	}
 
-	if t.finalizeNumeric() {
-		return true
+	cur := ic
+
+	c := source[cur.pointer]
+	// Other characters count too, but ignoring non-ascii for now
+	isAlphabetical := (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
+	if !isAlphabetical {
+		return nil, ic, false
 	}
+	cur.pointer++
+	cur.loc.col++
+
+	value := []byte{c}
+	for ; cur.pointer < uint(len(source)); cur.pointer++ {
+		c = source[cur.pointer]
+
+		// Other characters count too, but ignoring non-ascii for now
+		isAlphabetical := (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
+		isNumeric := c >= '0' && c <= '9'
+		if isAlphabetical || isNumeric || c == '$' || c == '_' {
+			value = append(value, c)
+			cur.loc.col++
+			continue
+		}
 
-	if t.finalizeString() {
-		return true
+		break
 	}
 
-	if t.finalizeIdentifier() {
-		return true
+	if len(value) == 0 {
+		return nil, ic, false
 	}
 
-	return false
+	return &token{
+		// Unquoted identifiers are case-insensitive
+		value: strings.ToLower(string(value)),
+		loc:   ic.loc,
+		kind:  identifierKind,
+	}, cur, true
 }
 
-func lex(source io.Reader) ([]*token, error) {
-	buf := make([]byte, 1)
-	tokens := []*token{}
-	current := token{}
-	var line uint = 0
-	var col uint = 0
-
-	for {
-		_, err := source.Read(buf)
-		if err != nil && err != io.EOF {
-			return nil, err
-		}
+func lexString(source string, ic cursor) (*token, cursor, bool) {
+	return lexCharacterDelimited(source, ic, '\'')
+}
 
-		// Add semi-colon for EOF
-		var c byte = ';'
-		if err == nil {
-			c = buf[0]
-		}
+func lex(source string) ([]*token, error) {
+	tokens := []*token{}
+	cur := cursor{}
 
-		switch c {
-		case '\n':
-			line++
-			col = 0
-			continue
-		case ' ':
-			fallthrough
-		case ',':
-			fallthrough
-		case '(':
-			fallthrough
-		case ')':
-			fallthrough
-		case ';':
-			if !current.finalize() {
-				return nil, fmt.Errorf("Unexpected token '%s' at %d:%d", current.value, current.loc.line, current.loc.col)
-			}
+lex:
+	for cur.pointer < uint(len(source)) {
+		lexers := []lexer{lexKeyword, lexSymbol, lexString, lexNumeric, lexIdentifier}
+		for _, l := range lexers {
+			if token, newCursor, ok := l(source, cur); ok {
+				cur = newCursor
 
-			if current.value != "" {
-				copy := current
-				tokens = append(tokens, &copy)
-			}
+				// Omit nil tokens for valid but empty syntax like newlines
+				if token != nil {
+					tokens = append(tokens, token)
+				}
 
-			if c == ';' || c == ',' || c == '(' || c == ')' {
-				tokens = append(tokens, &token{
-					loc:   location{col: col, line: line},
-					value: string(c),
-					kind:  symbolKind,
-				})
+				continue lex
 			}
-
-			current = token{}
-			current.loc.col = col
-			current.loc.line = line
-		default:
-			current.value += string(c)
 		}
 
-		if err == io.EOF {
-			break
+		hint := ""
+		if len(tokens) > 0 {
+			hint = " after " + tokens[len(tokens)-1].value
 		}
-		col++
+		return nil, fmt.Errorf("Unable to lex token%s, at %d:%d", hint, cur.loc.line, cur.loc.col)
 	}
 
 	return tokens, nil
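Note on the design above: each component lexer takes the full source string plus a cursor and returns a token, an advanced cursor, and whether it matched; lex() just tries them in order at every position. A minimal sketch of driving the new entry point from inside the gosql package (the demo function and its input are illustrative, not part of this patch; a separate file would also need to import "fmt"):

	func demoLex() {
		tokens, err := lex("select name from users;")
		if err != nil {
			panic(err)
		}
		for _, tok := range tokens {
			// Each token records the line and column where it started.
			fmt.Printf("%d:%d kind=%v value=%q\n", tok.loc.line, tok.loc.col, tok.kind, tok.value)
		}
	}
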
diff --git a/lexer_test.go b/lexer_test.go
index 51a59e2..85aa0e8 100644
--- a/lexer_test.go
+++ b/lexer_test.go
@@ -1,20 +1,24 @@
 package gosql
 
 import (
-	"bytes"
+	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 )
 
-func TestToken_finalizeNumeric(t *testing.T) {
+func TestToken_lexNumeric(t *testing.T) {
 	tests := []struct {
 		number bool
 		value  string
 	}{
 		{
 			number: true,
-			value:  "123",
+			value:  "105",
+		},
+		{
+			number: true,
+			value:  "105 ",
 		},
 		{
 			number: true,
@@ -69,12 +73,174 @@ func TestToken_finalizeNumeric(t *testing.T) {
 			number: false,
 			value:  "1ee4",
 		},
+		{
+			number: false,
+			value:  " 1",
+		},
+	}
+
+	for _, test := range tests {
+		tok, _, ok := lexNumeric(test.value, cursor{})
+		assert.Equal(t, test.number, ok, test.value)
+		if ok {
+			assert.Equal(t, strings.TrimSpace(test.value), tok.value, test.value)
+		}
+	}
+}
+
+func TestToken_lexString(t *testing.T) {
+	tests := []struct {
+		string bool
+		value  string
+	}{
+		{
+			string: true,
+			value:  "'abc'",
+		},
+		{
+			string: true,
+			value:  "'a b'",
+		},
+		{
+			string: true,
+			value:  "'a' ",
+		},
+		{
+			string: true,
+			value:  "'a '' b'",
+		},
+		// false tests
+		{
+			string: false,
+			value:  "'",
+		},
+		{
+			string: false,
+			value:  "",
+		},
+		{
+			string: false,
+			value:  " 'foo'",
+		},
+	}
+
+	for _, test := range tests {
+		tok, _, ok := lexString(test.value, cursor{})
+		assert.Equal(t, test.string, ok, test.value)
+		if ok {
+			test.value = strings.TrimSpace(test.value)
+			assert.Equal(t, test.value[1:len(test.value)-1], tok.value, test.value)
+		}
+	}
+}
+
+func TestToken_lexIdentifier(t *testing.T) {
+	tests := []struct {
+		identifier bool
+		input      string
+		value      string
+	}{
+		{
+			identifier: true,
+			input:      "abc",
+			value:      "abc",
+		},
+		{
+			identifier: true,
+			input:      "abc ",
+			value:      "abc",
+		},
+		{
+			identifier: true,
+			input:      `" abc "`,
+			value:      ` abc `,
+		},
+		{
+			identifier: true,
+			input:      "a9$",
+			value:      "a9$",
+		},
+		{
+			identifier: true,
+			input:      "userName",
+			value:      "username",
+		},
+		{
+			identifier: true,
+			input:      `"userName"`,
+			value:      "userName",
+		},
+		// false tests
+		{
+			identifier: false,
+			input:      `"`,
+		},
+		{
+			identifier: false,
+			input:      "_sadsfa",
+		},
+		{
+			identifier: false,
+			input:      "9sadsfa",
+		},
+		{
+			identifier: false,
+			input:      " abc",
+		},
+	}
+
+	for _, test := range tests {
+		tok, _, ok := lexIdentifier(test.input, cursor{})
+		assert.Equal(t, test.identifier, ok, test.input)
+		if ok {
+			assert.Equal(t, test.value, tok.value, test.input)
+		}
+	}
+}
+
+func TestToken_lexKeyword(t *testing.T) {
+	tests := []struct {
+		keyword bool
+		value   string
+	}{
+		{
+			keyword: true,
+			value:   "select ",
+		},
+		{
+			keyword: true,
+			value:   "from",
+		},
+		{
+			keyword: true,
+			value:   "as",
+		},
+		{
+			keyword: true,
+			value:   "SELECT",
+		},
+		{
+			keyword: true,
+			value:   "into",
+		},
+		// false tests
+		{
+			keyword: false,
+			value:   " into",
+		},
+		{
+			keyword: false,
+			value:   "flubbrety",
+		},
+	}
 
 	for _, test := range tests {
-		tok := token{}
-		tok.value = test.value
-		assert.Equal(t, test.value, tok.value, test.number)
+		tok, _, ok := lexKeyword(test.value, cursor{})
+		assert.Equal(t, test.keyword, ok, test.value)
+		if ok {
+			test.value = strings.TrimSpace(test.value)
+			assert.Equal(t, strings.ToLower(test.value), tok.value, test.value)
+		}
 	}
 }
 
@@ -93,20 +259,121 @@ func TestLex(t *testing.T) {
 					kind:  keywordKind,
 				},
 				{
-					loc:   location{col: 6, line: 0},
+					loc:   location{col: 7, line: 0},
 					value: "1",
 					kind:  numericKind,
 				},
+			},
+			err: nil,
+		},
+		{
+			input: "CREATE TABLE u (id INT, name TEXT)",
+			tokens: []token{
+				{
+					loc:   location{col: 0, line: 0},
+					value: string(createKeyword),
+					kind:  keywordKind,
+				},
 				{
-					loc:   location{col: 8, line: 0},
-					value: ";",
+					loc:   location{col: 7, line: 0},
+					value: string(tableKeyword),
+					kind:  keywordKind,
+				},
+				{
+					loc:   location{col: 13, line: 0},
+					value: "u",
+					kind:  identifierKind,
+				},
+				{
+					loc:   location{col: 15, line: 0},
+					value: "(",
+					kind:  symbolKind,
+				},
+				{
+					loc:   location{col: 16, line: 0},
+					value: "id",
+					kind:  identifierKind,
+				},
+				{
+					loc:   location{col: 19, line: 0},
+					value: "int",
+					kind:  keywordKind,
+				},
+				{
+					loc:   location{col: 22, line: 0},
+					value: ",",
+					kind:  symbolKind,
+				},
+				{
+					loc:   location{col: 24, line: 0},
+					value: "name",
+					kind:  identifierKind,
+				},
+				{
+					loc:   location{col: 29, line: 0},
+					value: "text",
+					kind:  keywordKind,
+				},
+				{
+					loc:   location{col: 33, line: 0},
+					value: ")",
+					kind:  symbolKind,
+				},
+			},
+		},
+		{
+			input: "insert into users values (105, 233)",
+			tokens: []token{
+				{
+					loc:   location{col: 0, line: 0},
+					value: string(insertKeyword),
+					kind:  keywordKind,
+				},
+				{
+					loc:   location{col: 7, line: 0},
+					value: string(intoKeyword),
+					kind:  keywordKind,
+				},
+				{
+					loc:   location{col: 12, line: 0},
+					value: "users",
+					kind:  identifierKind,
+				},
+				{
+					loc:   location{col: 18, line: 0},
+					value: string(valuesKeyword),
+					kind:  keywordKind,
+				},
+				{
+					loc:   location{col: 25, line: 0},
+					value: "(",
+					kind:  symbolKind,
+				},
+				{
+					loc:   location{col: 26, line: 0},
+					value: "105",
+					kind:  numericKind,
+				},
+				{
+					loc:   location{col: 30, line: 0},
+					value: ",",
+					kind:  symbolKind,
+				},
+				{
+					loc:   location{col: 32, line: 0},
+					value: "233",
+					kind:  numericKind,
+				},
+				{
+					loc:   location{col: 36, line: 0},
+					value: ")",
 					kind:  symbolKind,
 				},
 			},
 			err: nil,
 		},
 		{
-			input: "SELECT id FROM users",
+			input: "SELECT id FROM users;",
 			tokens: []token{
 				{
 					loc:   location{col: 0, line: 0},
 					value: string(selectKeyword),
 					kind:  keywordKind,
 				},
 				{
-					loc:   location{col: 6, line: 0},
+					loc:   location{col: 7, line: 0},
 					value: "id",
 					kind:  identifierKind,
 				},
 				{
-					loc:   location{col: 9, line: 0},
+					loc:   location{col: 10, line: 0},
 					value: string(fromKeyword),
 					kind:  keywordKind,
 				},
 				{
-					loc:   location{col: 14, line: 0},
+					loc:   location{col: 15, line: 0},
 					value: "users",
 					kind:  identifierKind,
 				},
@@ -139,8 +406,9 @@ func TestLex(t *testing.T) {
 	}
 
 	for _, test := range tests {
-		tokens, err := lex(bytes.NewBufferString(test.input))
+		tokens, err := lex(test.input)
 		assert.Equal(t, test.err, err, test.input)
+		assert.Equal(t, len(test.tokens), len(tokens), test.input)
 
 		for i, tok := range tokens {
 			assert.Equal(t, &test.tokens[i], tok, test.input)
diff --git a/parser.go b/parser.go
index d8a6ea1..0f5ed5e 100644
--- a/parser.go
+++ b/parser.go
@@ -3,7 +3,6 @@ package gosql
 import (
 	"errors"
 	"fmt"
-	"io"
 )
 
 func tokenFromKeyword(k keyword) token {
@@ -383,12 +382,17 @@ func parseStatement(tokens []*token, initialCursor uint, delimiter token) (*Stat
 	return nil, initialCursor, false
 }
 
-func Parse(source io.Reader) (*Ast, error) {
+func Parse(source string) (*Ast, error) {
 	tokens, err := lex(source)
 	if err != nil {
 		return nil, err
 	}
 
+	semicolonToken := tokenFromSymbol(semicolonSymbol)
+	if len(tokens) > 0 && !tokens[len(tokens)-1].equals(&semicolonToken) {
+		tokens = append(tokens, &semicolonToken)
+	}
+
 	a := Ast{}
 	cursor := uint(0)
 	for cursor < uint(len(tokens)) {
diff --git a/parser_test.go b/parser_test.go
index 3c7a735..b464f54 100644
--- a/parser_test.go
+++ b/parser_test.go
@@ -1,7 +1,6 @@
 package gosql
 
 import (
-	"bytes"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -20,14 +19,14 @@ func TestParse(t *testing.T) {
 					Kind: InsertKind,
 					InsertStatement: &InsertStatement{
 						table: token{
-							loc:   location{col: 11, line: 0},
+							loc:   location{col: 12, line: 0},
 							kind:  identifierKind,
 							value: "users",
 						},
 						values: &[]*expression{
 							{
 								literal: &token{
-									loc:   location{col: 25, line: 0},
+									loc:   location{col: 26, line: 0},
 									kind:  numericKind,
 									value: "105",
 								},
@@ -35,7 +34,7 @@ func TestParse(t *testing.T) {
 							},
 							{
 								literal: &token{
-									loc:   location{col: 30, line: 0},
+									loc:   location{col: 32, line: 0},
 									kind:  numericKind,
 									value: "233",
 								},
@@ -55,31 +54,31 @@ func TestParse(t *testing.T) {
 					Kind: CreateTableKind,
 					CreateTableStatement: &CreateTableStatement{
 						name: token{
-							loc:   location{col: 12, line: 0},
+							loc:   location{col: 13, line: 0},
 							kind:  identifierKind,
 							value: "users",
 						},
 						cols: &[]*columnDefinition{
 							{
 								name: token{
-									loc:   location{col: 19, line: 0},
+									loc:   location{col: 20, line: 0},
 									kind:  identifierKind,
 									value: "id",
 								},
 								datatype: token{
-									loc:   location{col: 22, line: 0},
+									loc:   location{col: 23, line: 0},
 									kind:  keywordKind,
 									value: "int",
 								},
 							},
 							{
 								name: token{
-									loc:   location{col: 27, line: 0},
+									loc:   location{col: 28, line: 0},
 									kind:  identifierKind,
 									value: "name",
 								},
 								datatype: token{
-									loc:   location{col: 32, line: 0},
+									loc:   location{col: 33, line: 0},
 									kind:  keywordKind,
 									value: "text",
 								},
@@ -105,7 +104,7 @@ func TestParse(t *testing.T) {
 							exp: &expression{
 								kind: literalKind,
 								literal: &token{
-									loc:   location{col: 9, line: 0},
+									loc:   location{col: 10, line: 0},
 									kind:  identifierKind,
 									value: "exclusive",
 								},
@@ -129,7 +128,7 @@ func TestParse(t *testing.T) {
 							exp: &expression{
 								kind: literalKind,
 								literal: &token{
-									loc:   location{col: 6, line: 0},
+									loc:   location{col: 7, line: 0},
 									kind:  identifierKind,
 									value: "id",
 								},
@@ -139,13 +138,13 @@ func TestParse(t *testing.T) {
 							exp: &expression{
 								kind: literalKind,
 								literal: &token{
-									loc:   location{col: 10, line: 0},
+									loc:   location{col: 11, line: 0},
 									kind:  identifierKind,
 									value: "name",
 								},
 							},
 							as: &token{
-								loc:   location{col: 18, line: 0},
+								loc:   location{col: 19, line: 0},
 								kind:  identifierKind,
 								value: "fullname",
 							},
@@ -153,7 +152,7 @@ func TestParse(t *testing.T) {
 					},
 					from: &fromItem{
 						table: &token{
-							loc:   location{col: 32, line: 0},
+							loc:   location{col: 33, line: 0},
 							kind:  identifierKind,
 							value: "users",
 						},
@@ -166,7 +165,7 @@ func TestParse(t *testing.T) {
 	}
 
 	for _, test := range tests {
-		ast, err := Parse(bytes.NewBufferString(test.source))
+		ast, err := Parse(test.source)
 		assert.Nil(t, err, test.source)
 		assert.Equal(t, test.ast, ast, test.source)
 	}
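
With Parse now taking a string (and appending a trailing semicolon when one is missing), callers no longer need to wrap input in a bytes.Buffer. A minimal usage sketch against the public API, assuming the module path github.com/eatonphil/gosql (the program itself is illustrative, not part of this patch):

	package main

	import (
		"fmt"

		"github.com/eatonphil/gosql"
	)

	func main() {
		// The trailing semicolon is optional after this patch.
		ast, err := gosql.Parse("SELECT id FROM users")
		if err != nil {
			panic(err)
		}
		fmt.Printf("%#v\n", ast)
	}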