diff --git a/element/column.go b/element/column.go
index e22db34..74cf56d 100644
--- a/element/column.go
+++ b/element/column.go
@@ -20,6 +20,7 @@ const (
 type Column struct {
 	Node
 	Typ     *types.FieldType
+	StrTyp  string
 	Options []*ast.ColumnOption
 }
diff --git a/postgres-parser/ast.go b/postgres-parser/ast.go
new file mode 100644
index 0000000..29be168
--- /dev/null
+++ b/postgres-parser/ast.go
@@ -0,0 +1 @@
+package postgres_parser
diff --git a/postgres-parser/parser.go b/postgres-parser/parser.go
new file mode 100644
index 0000000..fda2aef
--- /dev/null
+++ b/postgres-parser/parser.go
@@ -0,0 +1,351 @@
+package postgres_parser
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/pingcap/parser/ast"
+	"github.com/sunary/sqlize/element"
+)
+
+// Parser builds an element.Migration from PostgreSQL DDL.
+type Parser struct {
+	Migration element.Migration
+	s         *Scanner
+
+	// current token & literal
+	token Token
+	lit   string
+}
+
+// NewParser returns a Parser ready for use.
+func NewParser() *Parser {
+	return &Parser{
+		token: ILLEGAL,
+		lit:   "",
+	}
+}
+
+// Parse scans sql and dispatches on the leading keyword of each statement.
+func (p *Parser) Parse(sql string) error {
+	p.s = NewScanner(strings.NewReader(sql))
+
+	for {
+		p.next()
+		switch p.token {
+		case CREATE:
+			err := p.parseCreateTable()
+			if err != nil {
+				return err
+			}
+
+		case ALTER:
+			// TODO: ALTER TABLE is not supported yet
+
+		case EOF:
+			return nil
+
+		default:
+			return p.expectErr(CREATE, ALTER)
+		}
+	}
+}
+
+var _ = `
+CREATE TABLE IF NOT EXISTS inventory
+(
+    id SERIAL,
+    sku VARCHAR NOT NULL,
+    warehouse VARCHAR NOT NULL,
+    available_quantity FLOAT8 NOT NULL DEFAULT 0,
+    group_id INTEGER NOT NULL DEFAULT 0, -- for partition purpose
+    created_at timestamptz(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    updated_at timestamptz(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    deleted_at timestamptz(6),
+    PRIMARY KEY (id, group_id)
+) PARTITION BY LIST (group_id);
+`
+
+func (p *Parser) parseCreateTable() error {
+	p.next()
+	if p.token != TABLE {
+		return p.expectErr(TABLE)
+	}
+
+	p.next()
+	if p.token == IF {
+		err := p.expectedNext(NOT, EXISTS)
+		if err != nil {
+			return err
+		}
+
+		p.next()
+	}
+
+	p.Migration.AddTable(*element.NewTable(p.lit))
+	p.Migration.Using(p.lit)
+
+	err := p.expectedNext(LBRACE)
+	if err != nil {
+		return err
+	}
+
+	for {
+		p.next()
+		switch p.token {
+		case INDEX:
+			err := p.parseIndex()
+			if err != nil {
+				return err
+			}
+			if p.token == RBRACE {
+				return nil
+			}
+
+		case PRIMARY:
+			err := p.expectedNext(KEY)
+			if err != nil {
+				return err
+			}
+			err = p.parseIndex()
+			if err != nil {
+				return err
+			}
+			if p.token == RBRACE {
+				return nil
+			}
+
+		case RBRACE:
+			return nil
+
+		default:
+			// capture the column name before advancing to its type
+			name := p.lit
+			p.next()
+			if p.token == COMMENT {
+				_, err = p.parseString()
+				if err != nil {
+					return err
+				}
+
+				p.next()
+			} else {
+				err := p.parseColumn(name)
+				if err != nil {
+					return err
+				}
+				if p.token == RBRACE {
+					return nil
+				}
+			}
+		}
+	}
+}
+
+func (p *Parser) parseIndex() error {
+	idx := element.Index{
+		Node: element.Node{
+			Name: "",
+		},
+	}
+
+	// advance past the introducing keyword (INDEX, or the KEY of PRIMARY KEY)
+	p.next()
+	if p.token == LPAREN {
+		p.next()
+		for p.token == IDENT {
+			idx.Columns = append(idx.Columns, p.lit)
+			p.next()
+			if p.token == COMMA {
+				p.next()
+			}
+		}
+
+		if p.token != RPAREN {
+			return p.expectErr(RPAREN)
+		}
+	} else if p.token == IDENT {
+		idx.Columns = append(idx.Columns, p.lit)
+	} else {
+		return p.expectErr(COLUMN_NAME)
+	}
+
+	p.next()
+
+	if p.token == LBRACK {
+		commaAllowed := false
+
+		for {
+			p.next()
+			switch {
+			case p.token == UNIQUE:
+				idx.Typ = ast.IndexKeyTypeUnique
+
+			case p.token == COMMA:
+				if !commaAllowed {
+					return p.expectErr(INDEX_NAME)
+				}
+
+			case p.token == RBRACK:
+				p.next()
+				p.Migration.AddIndex("", idx)
+				return nil
+
+			default:
+				return p.expectErr(PK, UNIQUE)
+			}
+			commaAllowed = !commaAllowed
+		}
+	}
+
+	p.Migration.AddIndex("", idx)
+	return nil
+}
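For orientation, here is a minimal driver for the parser above. Note that the grammar parseCreateTable actually implements is DBML-flavored rather than plain SQL: the table body is delimited by braces and per-column settings sit in brackets, so the braced form below parses while the SQL sample embedded in parser.go does not yet. This is a sketch using only the exported names from this diff (NewParser, Parse, Migration):

package main

import (
	"fmt"

	postgres_parser "github.com/sunary/sqlize/postgres-parser"
)

func main() {
	p := postgres_parser.NewParser()
	// Brace-delimited table body with bracketed column settings,
	// matching the tokens parseCreateTable expects (LBRACE, LBRACK, ...).
	err := p.Parse(`CREATE TABLE inventory {
		id INTEGER [pk],
		sku VARCHAR(64) [not null, unique]
	}`)
	if err != nil {
		fmt.Println("parse error:", err)
		return
	}
	// p.Migration now holds the parsed table, columns, and options.
	fmt.Println("ok")
}
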
+func (p *Parser) parseColumn(name string) error {
+	col := element.Column{
+		Node: element.Node{
+			Name: name,
+		},
+	}
+	if p.token != IDENT {
+		return p.expectErr(INTEGER, VARCHAR)
+	}
+
+	col.StrTyp = p.lit
+	p.next()
+
+	// parse the optional type length and bracketed settings
+	switch p.token {
+	case LPAREN:
+		p.next()
+		if p.token != INTEGER {
+			return p.expectErr(INTEGER)
+		}
+
+		col.StrTyp = fmt.Sprintf("%s(%s)", col.StrTyp, p.lit)
+		p.next()
+		if p.token != RPAREN {
+			return p.expectErr(RPAREN)
+		}
+
+		p.next()
+		if p.token != LBRACK {
+			break
+		}
+		fallthrough
+
+	case LBRACK:
+		// parse the column settings, e.g. [pk, not null, default: 0]
+		err := p.parseColumnSettings(&col)
+		if err != nil {
+			return err
+		}
+
+		p.next() // consume ']'
+	}
+
+	p.Migration.AddColumn("", col)
+	return nil
+}
+
+func (p *Parser) parseColumnSettings(col *element.Column) error {
+	commaAllowed := false
+
+	for {
+		p.next()
+		switch p.token {
+		case PK:
+			col.Options = append(col.Options, &ast.ColumnOption{
+				Tp: ast.ColumnOptionPrimaryKey,
+			})
+
+		case PRIMARY:
+			p.next()
+			if p.token != KEY {
+				return p.expectErr(KEY)
+			}
+			col.Options = append(col.Options, &ast.ColumnOption{
+				Tp: ast.ColumnOptionPrimaryKey,
+			})
+
+		case NOT:
+			p.next()
+			if p.token != NULL {
+				return p.expectErr(NULL)
+			}
+			col.Options = append(col.Options, &ast.ColumnOption{
+				Tp: ast.ColumnOptionNotNull,
+			})
+
+		case UNIQUE:
+			col.Options = append(col.Options, &ast.ColumnOption{
+				Tp: ast.ColumnOptionUniqKey,
+			})
+
+		case INCR:
+			col.Options = append(col.Options, &ast.ColumnOption{
+				Tp: ast.ColumnOptionAutoIncrement,
+			})
+
+		case DEFAULT:
+			p.next()
+			if p.token != COLON {
+				return p.expectErr(COLON)
+			}
+			p.next()
+			switch p.token {
+			case STRING, DSTRING, TSTRING, INTEGER, NUMBER, EXPR:
+				col.Options = append(col.Options, &ast.ColumnOption{
+					Tp:       ast.ColumnOptionDefaultValue,
+					StrValue: p.lit,
+				})
+
+			default:
+				return p.expectErr(DEFAULT)
+			}
+
+		case COMMA:
+			if !commaAllowed {
+				return p.expectErr(PK, UNIQUE)
+			}
+
+		case RBRACK:
+			return nil
+
+		default:
+			return p.expectErr(PRIMARY, PK, UNIQUE)
+		}
+
+		commaAllowed = !commaAllowed
+	}
+}
+
+func (p *Parser) parseString() (string, error) {
+	p.next()
+	switch p.token {
+	case STRING, DSTRING, TSTRING:
+		return p.lit, nil
+
+	default:
+		return "", p.expectErr(STRING, DSTRING, TSTRING)
+	}
+}
+
+// next advances to the next token, skipping comments.
+func (p *Parser) next() {
+	for {
+		p.token, p.lit = p.s.Read()
+		if p.token != COMMENT {
+			break
+		}
+	}
+}
+
+// expectedNext consumes the given keywords in order, failing on the first mismatch.
+func (p *Parser) expectedNext(keywords ...Token) error {
+	for _, k := range keywords {
+		p.next()
+		if p.token != k {
+			return p.expectErr(k)
+		}
+	}
+
+	return nil
+}
+
+func (p *Parser) expectErr(toks ...Token) error {
+	l, c := p.s.LineInfo()
+	expected := make([]string, len(toks))
+	for i := range toks {
+		quote := "'"
+		if tokens[toks[i]] == "'" {
+			quote = "\""
+		}
+		expected[i] = quote + tokens[toks[i]] + quote
+	}
+
+	return fmt.Errorf("[%d:%d] invalid token '%s', expected: %s", l, c, p.lit, strings.Join(expected, ","))
+}
diff --git a/postgres-parser/scanner.go b/postgres-parser/scanner.go
new file mode 100644
index 0000000..7d0b8b9
--- /dev/null
+++ b/postgres-parser/scanner.go
@@ -0,0 +1,265 @@
+package postgres_parser
+
+import (
+	"bufio"
+	"bytes"
+	"io"
+)
+
+const eof = rune(0)
+
+// Scanner represents a lexical scanner.
+type Scanner struct {
+	r  *bufio.Reader
+	ch rune // for peek
+	l  uint
+	c  uint
+}
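To make the option mapping concrete, here is roughly the element.Column value that parseColumn above should produce for a column declared as `id INTEGER [pk, not null, default: 0]`. This is an illustration of the element and pingcap ast types used in this diff, not output captured from a test:

package main

import (
	"fmt"

	"github.com/pingcap/parser/ast"
	"github.com/sunary/sqlize/element"
)

func main() {
	// Expected result of parseColumn for: id INTEGER [pk, not null, default: 0]
	col := element.Column{
		Node:   element.Node{Name: "id"},
		StrTyp: "INTEGER",
		Options: []*ast.ColumnOption{
			{Tp: ast.ColumnOptionPrimaryKey},                  // pk
			{Tp: ast.ColumnOptionNotNull},                     // not null
			{Tp: ast.ColumnOptionDefaultValue, StrValue: "0"}, // default: 0
		},
	}
	fmt.Println(col.Name, col.StrTyp, len(col.Options))
}
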
+// NewScanner returns a new instance of Scanner.
+func NewScanner(r io.Reader) *Scanner {
+	s := &Scanner{r: bufio.NewReader(r), l: 1, c: 0}
+	s.next()
+	return s
+}
+
+// Read returns the next token and its literal value.
+func (s *Scanner) Read() (tok Token, lit string) {
+	for isWhitespace(s.ch) {
+		s.next()
+	}
+
+	switch {
+	case isLetter(s.ch):
+		return s.scanIdent()
+
+	case isDigit(s.ch):
+		return s.scanNumber()
+
+	default:
+		// read the individual character
+		ch := s.ch
+		lit := string(ch)
+		s.next()
+		switch ch {
+		case eof:
+			return EOF, ""
+
+		case '-':
+			if s.ch == '-' {
+				return COMMENT, s.scanComment(s.ch)
+			}
+
+			return SUB, lit
+
+		case '<':
+			return LT, lit
+
+		case '>':
+			return GT, lit
+
+		case '(':
+			return LPAREN, lit
+
+		case '[':
+			return LBRACK, lit
+
+		case '{':
+			return LBRACE, lit
+
+		case ')':
+			return RPAREN, lit
+
+		case ']':
+			return RBRACK, lit
+
+		case '}':
+			return RBRACE, lit
+
+		case ';':
+			return SEMICOLON, lit
+
+		case ':':
+			return COLON, lit
+
+		case ',':
+			return COMMA, lit
+
+		case '.':
+			return PERIOD, lit
+
+		case '`':
+			return s.scanExpression()
+
+		case '\'', '"':
+			return s.scanString(ch)
+
+		case '/':
+			if s.ch == '/' {
+				return COMMENT, s.scanComment(s.ch)
+			}
+
+			return ILLEGAL, string(ch)
+
+		case '#':
+			return COMMENT, s.scanComment(0)
+		}
+
+		return ILLEGAL, string(ch)
+	}
+}
+
+func (s *Scanner) scanComment(nextChar rune) string {
+	var buf bytes.Buffer
+	if nextChar != 0 {
+		buf.WriteRune(nextChar)
+	}
+
+	for s.ch != '\n' && s.ch != eof {
+		buf.WriteRune(s.ch)
+		s.next()
+	}
+
+	return buf.String()
+}
+
+func (s *Scanner) scanNumber() (Token, string) {
+	var buf bytes.Buffer
+	countDot := 0
+	for isDigit(s.ch) || (s.ch == '.' && countDot < 2) {
+		if s.ch == '.' {
+			countDot++
+		}
+
+		buf.WriteRune(s.ch)
+		s.next()
+	}
+
+	if countDot < 1 {
+		return INTEGER, buf.String()
+	} else if countDot > 1 {
+		return ILLEGAL, buf.String()
+	}
+
+	return NUMBER, buf.String()
+}
+
+func (s *Scanner) scanString(quo rune) (Token, string) {
+	switch quo {
+	case '"':
+		lit, ok := s.scanTo(quo)
+		if ok {
+			return DSTRING, lit
+		}
+		return ILLEGAL, lit
+
+	case '\'':
+		if s.ch != '\'' {
+			lit, ok := s.scanTo(quo)
+			if ok {
+				return STRING, lit
+			}
+			return ILLEGAL, lit
+		}
+
+		// handle triple-quoted string
+		var buf bytes.Buffer
+		s.next()
+		if s.ch == '\'' { // triple quote string
+			s.next()
+			count := 0
+			for count < 3 {
+				switch s.ch {
+				case '\'':
+					count++
+				case eof:
+					return ILLEGAL, buf.String()
+				}
+				buf.WriteRune(s.ch)
+				s.next()
+			}
+			return TSTRING, buf.String()[:buf.Len()-count]
+		}
+		// two quotes in a row: the empty string ''
+		return STRING, buf.String()
+
+	default:
+		return ILLEGAL, string(eof)
+	}
+}
+
+func (s *Scanner) scanExpression() (Token, string) {
+	lit, ok := s.scanTo('`')
+	if ok {
+		return EXPR, lit
+	}
+
+	return ILLEGAL, lit
+}
+
+func (s *Scanner) scanTo(stop rune) (string, bool) {
+	var buf bytes.Buffer
+	for {
+		switch s.ch {
+		case stop:
+			s.next()
+			return buf.String(), true
+
+		case '\n', eof:
+			return buf.String(), false
+
+		default:
+			buf.WriteRune(s.ch)
+			s.next()
+		}
+	}
+}
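A quick sanity check of the scanner in isolation; a sketch assuming only Scanner.Read above and the token list defined below. Note that type names such as VARCHAR deliberately come back as IDENT, since they are not in the keywords map:

package main

import (
	"fmt"
	"strings"

	postgres_parser "github.com/sunary/sqlize/postgres-parser"
)

func main() {
	s := postgres_parser.NewScanner(strings.NewReader("sku VARCHAR(64) -- item code"))
	for {
		tok, lit := s.Read()
		if tok == postgres_parser.EOF {
			break
		}
		// Prints, in order: IDENT "sku", IDENT "VARCHAR", LPAREN "(",
		// INTEGER "64", RPAREN ")", COMMENT "-- item code"
		fmt.Printf("%v %q\n", tok, lit)
	}
}
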
+func (s *Scanner) scanIdent() (tok Token, lit string) {
+	var buf bytes.Buffer
+	for {
+		buf.WriteRune(s.ch)
+		s.next()
+
+		if !isLetter(s.ch) && !isDigit(s.ch) && s.ch != '_' && s.ch != '.' {
+			break
+		}
+	}
+
+	return Lookup(buf.String()), buf.String()
+}
+
+func (s *Scanner) next() {
+	ch, _, err := s.r.ReadRune()
+	if err != nil {
+		s.ch = eof
+		return
+	}
+
+	if ch == '\n' {
+		s.l++
+		s.c = 0
+	} else {
+		s.c++
+	}
+
+	s.ch = ch
+}
+
+// LineInfo returns the current line and column.
+func (s *Scanner) LineInfo() (uint, uint) {
+	return s.l, s.c
+}
+
+// isWhitespace returns true if the rune is a space, tab, or newline.
+func isWhitespace(ch rune) bool { return ch == ' ' || ch == '\t' || ch == '\n' }
+
+// isLetter returns true if the rune is a letter.
+func isLetter(ch rune) bool { return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') }
+
+// isDigit returns true if the rune is a digit.
+func isDigit(ch rune) bool { return ch >= '0' && ch <= '9' }
+
+// isIdentChar returns true if the rune can be used in an unquoted identifier.
+func isIdentChar(ch rune) bool { return isLetter(ch) || isDigit(ch) || ch == '_' }
+
+// isIdentFirstChar returns true if the rune can be used as the first char in an unquoted identifier.
+func isIdentFirstChar(ch rune) bool { return isLetter(ch) || ch == '_' }
diff --git a/postgres-parser/token.go b/postgres-parser/token.go
new file mode 100644
index 0000000..e11d393
--- /dev/null
+++ b/postgres-parser/token.go
@@ -0,0 +1,362 @@
+package postgres_parser
+
+import (
+	"strings"
+)
+
+// Token is a lexical token of the SQL dialect handled by this package.
+type Token int
+
+// The list of language tokens.
+const (
+	// ILLEGAL, EOF, WS are special tokens.
+	ILLEGAL Token = iota
+	EOF
+	WS
+	COMMENT
+
+	literalBeg
+	// IDENT and the following are literal tokens.
+	IDENT       // main
+	BOUNDPARAM  // $param
+	NUMBER      // 12345.67
+	INTEGER     // 12345
+	VARCHAR     // varchar
+	DURATIONVAL // 13h
+	STRING      // 'abc'
+	DSTRING     // "abc"
+	TSTRING     // '''abc'''
+	BADSTRING   // 'abc
+	BADESCAPE   // \q
+	TRUE        // true
+	FALSE       // false
+	REGEX       // Regular expressions
+	BADREGEX    // `.*
+	EXPR        // `now()`
+	literalEnd
+
+	operatorBeg
+	// ADD and the following are operators.
+	ADD // +
+	SUB // -
+	MUL // *
+	DIV // /
+
+	AND // AND
+	OR  // OR
+
+	EQ       // =
+	NEQ      // !=
+	EQREGEX  // =~
+	NEQREGEX // !~
+	LT       // <
+	LTE      // <=
+	GT       // >
+	GTE      // >=
+	operatorEnd
+
+	LPAREN // (
+	LBRACK // [
+	LBRACE // {
+	COMMA  // ,
+	PERIOD // .
+
+	RPAREN    // )
+	RBRACK    // ]
+	RBRACE    // }
+	SEMICOLON // ;
+	COLON     // :
+
+	keywordBeg
+	// ALL and the following are keywords.
+	ALL
+	ALTER
+	ANY
+	AS
+	ASC
+	BEGIN
+	BY
+	CREATE
+	TABLE
+	CONTINUOUS
+	DATABASE
+	DATABASES
+	DEFAULT
+	DELETE
+	DESC
+	DESTINATIONS
+	DIAGNOSTICS
+	DISTINCT
+	DROP
+	DURATION
+	END
+	EVERY
+	EXISTS
+	EXPLAIN
+	FIELD
+	FOR
+	FROM
+	GRANT
+	GRANTS
+	GROUP
+	GROUPS
+	IF
+	IN
+	INCR
+	INDEX
+	INF
+	INSERT
+	INTO
+	KEY
+	KEYS
+	KILL
+	LIMIT
+	MEASUREMENT
+	MEASUREMENTS
+	NAME
+	NOT
+	NULL
+	OFFSET
+	ON
+	ORDER
+	PASSWORD
+	PK
+	POLICY
+	POLICIES
+	PRIMARY
+	PRIVILEGES
+	QUERIES
+	QUERY
+	READ
+	REPLICATION
+	RESAMPLE
+	RETENTION
+	REVOKE
+	SELECT
+	SERIES
+	SET
+	SHOW
+	SHARD
+	SHARDS
+	SLIMIT
+	SOFFSET
+	STATS
+	SUBSCRIPTION
+	SUBSCRIPTIONS
+	TAG
+	TO
+	UNIQUE
+	USER
+	USERS
+	VALUES
+	WHERE
+	WITH
+	WRITE
+	keywordEnd
+
+	aliasBeg
+	TABLE_NAME
+	COLUMN_NAME
+	INDEX_NAME
+	aliasEnd
+)
+
+var tokens = [...]string{
+	ILLEGAL: "ILLEGAL",
+	EOF:     "EOF",
+	WS:      "WS",
+	COMMENT: "COMMENT",
+
+	IDENT:       "IDENT",
+	NUMBER:      "NUMBER",
+	INTEGER:     "INTEGER",
+	VARCHAR:     "VARCHAR",
+	DURATIONVAL: "DURATIONVAL",
+	STRING:      "STRING",
+	DSTRING:     "DSTRING",
+	TSTRING:     "TSTRING",
+	BADSTRING:   "BADSTRING",
+	BADESCAPE:   "BADESCAPE",
+	TRUE:        "TRUE",
+	FALSE:       "FALSE",
+	REGEX:       "REGEX",
+	EXPR:        "EXPR",
+
+	ADD: "+",
+	SUB: "-",
+	MUL: "*",
+	DIV: "/",
+
+	AND: "AND",
+	OR:  "OR",
+
+	EQ:       "=",
+	NEQ:      "!=",
+	EQREGEX:  "=~",
+	NEQREGEX: "!~",
+	LT:       "<",
+	LTE:      "<=",
+	GT:       ">",
+	GTE:      ">=",
+
+	LPAREN: "(",
+	LBRACK: "[",
+	LBRACE: "{",
+	COMMA:  ",",
+	PERIOD: ".",
+
+	RPAREN:    ")",
+	RBRACK:    "]",
+	RBRACE:    "}",
+	SEMICOLON: ";",
+	COLON:     ":",
+
+	ALL:           "ALL",
+	ALTER:         "ALTER",
+	ANY:           "ANY",
+	AS:            "AS",
+	ASC:           "ASC",
+	BEGIN:         "BEGIN",
+	BY:            "BY",
+	CREATE:        "CREATE",
+	TABLE:         "TABLE",
+	CONTINUOUS:    "CONTINUOUS",
+	DATABASE:      "DATABASE",
+	DATABASES:     "DATABASES",
+	DEFAULT:       "DEFAULT",
+	DELETE:        "DELETE",
+	DESC:          "DESC",
+	DESTINATIONS:  "DESTINATIONS",
+	DIAGNOSTICS:   "DIAGNOSTICS",
+	DISTINCT:      "DISTINCT",
+	DROP:          "DROP",
+	DURATION:      "DURATION",
+	END:           "END",
+	EVERY:         "EVERY",
+	EXISTS:        "EXISTS",
+	EXPLAIN:       "EXPLAIN",
+	FIELD:         "FIELD",
+	FOR:           "FOR",
+	FROM:          "FROM",
+	GRANT:         "GRANT",
+	GRANTS:        "GRANTS",
+	GROUP:         "GROUP",
+	GROUPS:        "GROUPS",
+	IF:            "IF",
+	IN:            "IN",
+	INCR:          "INCR",
+	INDEX:         "INDEX",
+	INF:           "INF",
+	INSERT:        "INSERT",
+	INTO:          "INTO",
+	KEY:           "KEY",
+	KEYS:          "KEYS",
+	KILL:          "KILL",
+	LIMIT:         "LIMIT",
+	MEASUREMENT:   "MEASUREMENT",
+	MEASUREMENTS:  "MEASUREMENTS",
+	NAME:          "NAME",
+	NOT:           "NOT",
+	NULL:          "NULL",
+	OFFSET:        "OFFSET",
+	ON:            "ON",
+	ORDER:         "ORDER",
+	PASSWORD:      "PASSWORD",
+	PK:            "PK",
+	POLICY:        "POLICY",
+	POLICIES:      "POLICIES",
+	PRIMARY:       "PRIMARY",
+	PRIVILEGES:    "PRIVILEGES",
+	QUERIES:       "QUERIES",
+	QUERY:         "QUERY",
+	READ:          "READ",
+	REPLICATION:   "REPLICATION",
+	RESAMPLE:      "RESAMPLE",
+	RETENTION:     "RETENTION",
+	REVOKE:        "REVOKE",
+	SELECT:        "SELECT",
+	SERIES:        "SERIES",
+	SET:           "SET",
+	SHOW:          "SHOW",
+	SHARD:         "SHARD",
+	SHARDS:        "SHARDS",
+	SLIMIT:        "SLIMIT",
+	SOFFSET:       "SOFFSET",
+	STATS:         "STATS",
+	SUBSCRIPTION:  "SUBSCRIPTION",
+	SUBSCRIPTIONS: "SUBSCRIPTIONS",
+	TAG:           "TAG",
+	TO:            "TO",
+	UNIQUE:        "UNIQUE",
+	USER:          "USER",
+	USERS:         "USERS",
+	VALUES:        "VALUES",
+	WHERE:         "WHERE",
+	WITH:          "WITH",
+	WRITE:         "WRITE",
+
+	TABLE_NAME:  "table name",
+	COLUMN_NAME: "column name",
+	INDEX_NAME:  "index name",
+}
+
+var keywords map[string]Token
+
+func init() {
+	keywords = make(map[string]Token)
+	for tok := keywordBeg + 1; tok < keywordEnd; tok++ {
+		keywords[strings.ToLower(tokens[tok])] = tok
+	}
+
+	for _, tok := range []Token{AND, OR} {
+		keywords[strings.ToLower(tokens[tok])] = tok
+	}
+
+	keywords["true"] = TRUE
+	keywords["false"] = FALSE
+}
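The keywords map is built from the string table above, which is why every keyword the parser matches on (TABLE, IF, EXISTS, PK, INCR, INDEX, PRIMARY) needs an entry there; Lookup, defined just below, falls back to IDENT for anything unlisted. A small illustration:

package main

import (
	"fmt"

	postgres_parser "github.com/sunary/sqlize/postgres-parser"
)

func main() {
	// Keyword lookup is case-insensitive.
	fmt.Println(postgres_parser.Lookup("table") == postgres_parser.TABLE) // true
	fmt.Println(postgres_parser.Lookup("pk") == postgres_parser.PK)       // true
	// Type names are not keywords; they scan as plain identifiers.
	fmt.Println(postgres_parser.Lookup("varchar") == postgres_parser.IDENT) // true
}
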
+// String returns the string representation of the token.
+func (tok Token) String() string {
+	if tok >= 0 && tok < Token(len(tokens)) {
+		return tokens[tok]
+	}
+
+	return ""
+}
+
+// Precedence returns the operator precedence of the binary operator token.
+func (tok Token) Precedence() int {
+	switch tok {
+	case OR:
+		return 1
+
+	case AND:
+		return 2
+
+	case EQ, NEQ, EQREGEX, NEQREGEX, LT, LTE, GT, GTE:
+		return 3
+
+	case ADD, SUB:
+		return 4
+
+	case MUL, DIV:
+		return 5
+	}
+
+	return 0
+}
+
+// isOperator returns true for operator tokens.
+func (tok Token) isOperator() bool { return tok > operatorBeg && tok < operatorEnd }
+
+// tokstr returns a literal if provided, otherwise returns the token string.
+func tokstr(tok Token, lit string) string {
+	if lit != "" {
+		return lit
+	}
+
+	return tok.String()
+}
+
+// Lookup returns the token associated with a given string.
+func Lookup(ident string) Token {
+	if tok, ok := keywords[strings.ToLower(ident)]; ok {
+		return tok
+	}
+
+	return IDENT
+}
+
+// Pos specifies the line and character position of a token.
+// The Char and Line are both zero-based indexes.
+type Pos struct {
+	Line int
+	Char int
+}
diff --git a/sqlize.go b/sqlize.go
index 22ff70a..c53a9d5 100644
--- a/sqlize.go
+++ b/sqlize.go
@@ -9,6 +9,7 @@ import (
 	"github.com/sunary/sqlize/mysql-parser"
+	"github.com/sunary/sqlize/postgres-parser"
 	"github.com/sunary/sqlize/sql-builder"
 	"github.com/sunary/sqlize/utils"
 )
@@ -25,6 +26,7 @@ type Sqlize struct {
 	isLower        bool
 	sqlBuilder     *sql_builder.SqlBuilder
 	mysqlParser    *mysql_parser.Parser
+	postgresParser *postgres_parser.Parser
 }
@@ -59,6 +61,7 @@ func NewSqlize(opts ...SqlizeOption) *Sqlize {
 		sqlBuilder:  sb,
 		mysqlParser: mysql_parser.NewParser(o.isLower),
+		postgresParser: postgres_parser.NewParser(),
 	}
 }
@@ -73,6 +76,10 @@ func (s *Sqlize) FromObjects(objs ...interface{}) error {
 }
 
 func (s *Sqlize) FromString(sql string) error {
+	if s.isPostgres {
+		return s.postgresParser.Parse(sql)
+	}
+
 	return s.mysqlParser.Parser(sql)
 }
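End to end, the new parser is reached through Sqlize.FromString when isPostgres is set. That flag is referenced in FromString but not wired up anywhere in this diff, so the option name below, WithPostgresql, is hypothetical and stands in for whatever option sets it:

package main

import (
	"fmt"

	"github.com/sunary/sqlize"
)

func main() {
	// sqlize.WithPostgresql is a placeholder name: this diff does not
	// show the option that sets Sqlize.isPostgres.
	s := sqlize.NewSqlize(sqlize.WithPostgresql())
	err := s.FromString(`CREATE TABLE inventory {
		id INTEGER [pk],
		sku VARCHAR(64) [not null]
	}`)
	fmt.Println(err)
}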