diff --git a/ast.go b/ast.go index 1a72f9b..b30561b 100644 --- a/ast.go +++ b/ast.go @@ -52,9 +52,13 @@ const ( insertKind ) -type ast struct { - slct *SelectStatement - crtTbl *CreateTableStatement - inst *InsertStatement +type Statement struct { + SelectStatement *SelectStatement + CreateTableStatement *CreateTableStatement + InsertStatement *InsertStatement kind astKind } + +type Ast struct { + Statements []*Statement +} diff --git a/cmd/main.go b/cmd/main.go index 827ecb7..1f56eac 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,11 +1,18 @@ package main -import "github.com/eatonphil/gosql" +import ( + "bytes" + "fmt" + + "github.com/eatonphil/gosql" +) func main() { mb := gosql.NewMemoryBackend() - ast, err := gosql.Parse("CREATE TABLE users (id INT, name TEXT); INSERT INTO users VALUES (1, 'Admin'); SELECT id, name FROM users") + source := bytes.NewBufferString("CREATE TABLE users (id INT, name TEXT); INSERT INTO users VALUES (1, 'Admin'); SELECT id, name FROM users") + + ast, err := gosql.Parse(source) if err != nil { panic(err) } @@ -20,12 +27,12 @@ func main() { panic(err) } - results, err = mb.Select(ast.Statements[2].SelectStatement) + results, err := mb.Select(ast.Statements[2].SelectStatement) if err != nil { panic(err) } - for _, col := range results.columns { + for _, col := range results.Columns { fmt.Printf("| %s ", col.Name) } fmt.Println() @@ -37,9 +44,9 @@ func main() { typ := results.Columns[i].Type s := "" switch typ { - case gosql.Int: - s = fmt.Printf("%d", cell.AsInt()) - case gosql.Text: + case gosql.IntType: + s = fmt.Sprintf("%d", cell.AsInt()) + case gosql.TextType: s = cell.AsText() } diff --git a/interface.go b/interface.go index 4313712..8ff0a64 100644 --- a/interface.go +++ b/interface.go @@ -11,8 +11,16 @@ type Cell interface { AsInt() int } +type Results struct { + Columns []struct{ + Type ColumnType + Name string + } + Rows [][]Cell +} + type Backend interface { CreateTable(*CreateTableStatement) error Insert(*InsertStatement) error - Select(*SelectStatement) ([][]Cell, error) + Select(*SelectStatement) (*Results, error) } diff --git a/memory.go b/memory.go index f1fb9cf..8d639a2 100644 --- a/memory.go +++ b/memory.go @@ -52,7 +52,7 @@ func (mb *MemoryBackend) tokenToCell(t *token) MemoryCell { return nil } -func (mb *MemoryBackend) Select(slct *SelectStatement) ([][]MemoryCell, error) { +func (mb *MemoryBackend) Select(slct *SelectStatement) (*Results, error) { table := table{} if slct.from != nil && slct.from.table != nil { @@ -63,17 +63,20 @@ func (mb *MemoryBackend) Select(slct *SelectStatement) ([][]MemoryCell, error) { } } - results := [][]MemoryCell{} - if slct.item == nil || len(*slct.item) == 0 { - return results, nil + return &Results{}, nil } + results := [][]Cell{} + columns := []struct{ + Type ColumnType + Name string + }{} if len(table.rows) > 0 { for _, row := range table.rows { - result := []MemoryCell{} + result := []Cell{} - resultRow := []MemoryCell{} + resultRow := []Cell{} for _, col := range *slct.item { if col.asterisk { // TODO: handle asterisk @@ -93,6 +96,14 @@ func (mb *MemoryBackend) Select(slct *SelectStatement) ([][]MemoryCell, error) { found := false for i, tableCol := range table.columns { if tableCol == lit.value { + columns = append(columns, struct{ + Type ColumnType + Name string + }{ + Type: table.columnTypes[i], + Name: lit.value, + }) + resultRow = append(resultRow, row[i]) found = true break @@ -107,6 +118,18 @@ func (mb *MemoryBackend) Select(slct *SelectStatement) ([][]MemoryCell, error) { } if lit.kind == numericKind || lit.kind == stringKind { + columnType := IntType + if lit.kind == stringKind { + columnType = TextType + } + + columns = append(columns, struct{ + Type ColumnType + Name string + }{ + Type: columnType, + Name: col.exp.literal.value, + }) resultRow = append(resultRow, mb.tokenToCell(lit)) continue } @@ -128,15 +151,18 @@ func (mb *MemoryBackend) Select(slct *SelectStatement) ([][]MemoryCell, error) { } } - return results, nil + return &Results{ + Columns: columns, + Rows: results, + }, nil } -func (mb *MemoryBackend) Insert(inst *InsertStatement) { - +func (mb *MemoryBackend) Insert(inst *InsertStatement) (error) { + return nil } -func (mb *MemoryBackend) CreateTable(crt *CreateTableStatement) { - +func (mb *MemoryBackend) CreateTable(crt *CreateTableStatement) (error) { + return nil } func NewMemoryBackend() *MemoryBackend { diff --git a/parser.go b/parser.go index da798ec..7505947 100644 --- a/parser.go +++ b/parser.go @@ -193,19 +193,20 @@ func parseSelectStatement(tokens []*token, initialCursor uint, delimiter token) return &slct, cursor, true } -func parse(source io.Reader) (*ast, error) { +func Parse(source io.Reader) (*Ast, error) { tokens, err := lex(source) if err != nil { return nil, err } - a := ast{} + a := Ast{} cursor := uint(0) for cursor < uint(len(tokens)) { + stmt := &Statement{} slct, newCursor, ok := parseSelectStatement(tokens, cursor, semicolonToken) if ok { - a.kind = selectKind - a.slct = slct + stmt.kind = selectKind + stmt.SelectStatement = slct cursor = newCursor } @@ -213,6 +214,8 @@ func parse(source io.Reader) (*ast, error) { return nil, errors.New("Failed to parse") } + a.Statements = append(a.Statements, stmt) + if !expectToken(tokens, cursor, semicolonToken) { helpMessage(tokens, cursor, "Expected semi-colon delimiter between statements") return nil, errors.New("Missing semi-colon between statements") diff --git a/parser_test.go b/parser_test.go index ac1419c..b916a94 100644 --- a/parser_test.go +++ b/parser_test.go @@ -10,24 +10,28 @@ import ( func TestParse(t *testing.T) { tests := []struct { source string - ast *ast + ast *Ast }{ { source: "SELECT *, exclusive", - ast: &ast{ - kind: selectKind, - slct: &SelectStatement{ - item: &[]*selectItem{ - { - asterisk: true, - }, - { - exp: &expression{ - kind: literalKind, - literal: &token{ - loc: location{col: 9, line: 0}, - kind: identifierKind, - value: "exclusive", + ast: &Ast{ + Statements: []*Statement{ + { + kind: selectKind, + SelectStatement: &SelectStatement{ + item: &[]*selectItem{ + { + asterisk: true, + }, + { + exp: &expression{ + kind: literalKind, + literal: &token{ + loc: location{col: 9, line: 0}, + kind: identifierKind, + value: "exclusive", + }, + }, }, }, }, @@ -37,41 +41,45 @@ func TestParse(t *testing.T) { }, { source: "SELECT id, name AS fullname FROM users", - ast: &ast{ - kind: selectKind, - slct: &SelectStatement{ - item: &[]*selectItem{ - { - exp: &expression{ - kind: literalKind, - literal: &token{ - loc: location{col: 6, line: 0}, - kind: identifierKind, - value: "id", + ast: &Ast{ + Statements: []*Statement{ + { + kind: selectKind, + SelectStatement: &SelectStatement{ + item: &[]*selectItem{ + { + exp: &expression{ + kind: literalKind, + literal: &token{ + loc: location{col: 6, line: 0}, + kind: identifierKind, + value: "id", + }, + }, + }, + { + exp: &expression{ + kind: literalKind, + literal: &token{ + loc: location{col: 10, line: 0}, + kind: identifierKind, + value: "name", + }, + }, + as: &token{ + loc: location{col: 18, line: 0}, + kind: identifierKind, + value: "fullname", + }, }, }, - }, - { - exp: &expression{ - kind: literalKind, - literal: &token{ - loc: location{col: 10, line: 0}, + from: &fromItem{ + table: &token{ + loc: location{col: 32, line: 0}, kind: identifierKind, - value: "name", + value: "users", }, }, - as: &token{ - loc: location{col: 18, line: 0}, - kind: identifierKind, - value: "fullname", - }, - }, - }, - from: &fromItem{ - table: &token{ - loc: location{col: 32, line: 0}, - kind: identifierKind, - value: "users", }, }, }, @@ -80,7 +88,7 @@ func TestParse(t *testing.T) { } for _, test := range tests { - ast, err := parse(bytes.NewBufferString(test.source)) + ast, err := Parse(bytes.NewBufferString(test.source)) assert.Nil(t, err, test.source) assert.Equal(t, test.ast, ast, test.source) }