Skip to content

Commit

Permalink
feat: support multi-insert
Browse files Browse the repository at this point in the history
feat: support insert type verification
  • Loading branch information
xeonds committed Sep 10, 2024
1 parent e7cae64 commit 5553ed0
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 103 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ sqlc currently supports the following SQL statements:
- `CREATE TABLE IDENTIFIER ( table_columns );`
- `SHOW TABLES;`
- `SHOW DATABASES;`
- `INSERT INTO IDENTIFIER ( column1, column2, ... ) VALUES ( value1, value2, ... );`
- `INSERT INTO IDENTIFIER ( column1, column2, ... ) VALUES ( value1, value2, ... ) [ ( value1, value2, ... ) ... ];`
- `UPDATE IDENTIFIER SET IDENTIFIER EQUALS value [ WHERE condition ];`
- `DELETE FROM IDENTIFIER [ WHERE condition ];`
- `DROP TABLE IDENTIFIER;`
Expand Down
12 changes: 9 additions & 3 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
- insert type verification
- interactive sql tab completion&arrow navigation
- headless mode for sql file execution
- [x]支持INSERT后跟随多个VALUES
- [x]插入类型验证
- [ ]交互tab补全&方向键移动+历史命令
- [ ]多行sql语句
- [ ]注释
- [x]无头模式执行sql文件
- [ ]多表join查询
- [ ]索引功能
- [ ]视图功能
93 changes: 2 additions & 91 deletions src/ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type expr =
| CreateTable of string * (string * data_type) list
| ShowTables
| ShowDatabases
| InsertInto of string * string list * value list
| InsertInto of string * string list * value list list
| Select of string list * string * (condition option)
| Update of string * string * value * (condition option)
| Delete of string * (condition option)
Expand All @@ -28,93 +28,4 @@ and condition =
| Equal of string * value
| And of condition * condition
| Or of condition * condition
| Not of condition

(* Show parsed expressions *)
let rec show_expr = function
| CreateDatabase name -> "CreateDatabase " ^ name
| UseDatabase name -> "UseDatabase " ^ name
| CreateTable (name, columns) -> "CreateTable " ^ name ^ " (" ^ (String.concat ", " (List.map show_column columns)) ^ ")"
| ShowTables -> "ShowTables"
| ShowDatabases -> "ShowDatabases"
| InsertInto (name, columns, values) -> "InsertInto " ^ name ^ " (" ^ (String.concat ", " columns) ^ ") VALUES (" ^ (String.concat ", " (List.map show_value values)) ^ ")"
| Select (columns, table, opt_where) -> "Select " ^ (String.concat ", " columns) ^ " FROM " ^ table ^ (match opt_where with Some(cond) -> " WHERE " ^ show_condition cond | None -> "")
| Update (table, column, value, opt_where) -> "Update " ^ table ^ " SET " ^ column ^ " = " ^ show_value value ^ (match opt_where with Some(cond) -> " WHERE " ^ show_condition cond | None -> "")
| Delete (table, opt_where) -> "Delete FROM " ^ table ^ (match opt_where with Some(cond) -> " WHERE " ^ show_condition cond | None -> "")
| DropTable name -> "DropTable " ^ name
| DropDatabase name -> "DropDatabase " ^ name
| Exit -> "Exit"
and show_column (name, data_type) = name ^ " " ^ (match data_type with
| IntType -> "INT"
| StringType -> "STRING"
| FloatType -> "FLOAT"
| BoolType -> "BOOL")
and show_value = function
| IntValue v -> string_of_int v
| StringValue v -> "\"" ^ v ^ "\""
| FloatValue v -> string_of_float v
| BoolValue v -> string_of_bool v
and show_condition = function
| LessThan (col, value) -> col ^ " < " ^ show_value value
| GreaterThan (col, value) -> col ^ " > " ^ show_value value
| LessEqual (col, value) -> col ^ " <= " ^ show_value value
| GreaterEqual (col, value) -> col ^ " >= " ^ show_value value
| NotEqual (col, value) -> col ^ " != " ^ show_value value
| Equal (col, value) -> col ^ " = " ^ show_value value
| And (cond1, cond2) -> show_condition cond1 ^ " AND " ^ show_condition cond2
| Or (cond1, cond2) -> show_condition cond1 ^ " OR " ^ show_condition cond2
| Not cond -> "NOT " ^ show_condition cond

(* Generate OCaml code *)
let rec generate_code = function
| CreateDatabase name -> Printf.sprintf "create_database \"%s\"" name
| UseDatabase name -> Printf.sprintf "use_database \"%s\"" name
| CreateTable (name, columns) ->
let cols = columns |> List.map (fun (col, typ) -> Printf.sprintf "(\"%s\", %s)" col (string_of_data_type typ)) |> String.concat "; " in
Printf.sprintf "create_table \"%s\" [%s]" name cols
| ShowTables -> "show_tables ()"
| ShowDatabases -> "show_databases ()"
| InsertInto (table, cols, vals) ->
let cols_str = String.concat "; " (List.map (Printf.sprintf "\"%s\"") cols) in
let vals_str = String.concat "; " (List.map string_of_value vals) in
Printf.sprintf "insert_into \"%s\" [%s] [%s]" table cols_str vals_str
| Select (cols, table, cond) ->
let cols_str = String.concat "; " (List.map (Printf.sprintf "\"%s\"") cols) in
let cond_str = match cond with
| Some c -> generate_condition_code c
| None -> "None" in
Printf.sprintf "select [%s] \"%s\" %s" cols_str table cond_str
| Update (table, col, value, cond) ->
let value_str = string_of_value value in
let cond_str = match cond with
| Some c -> generate_condition_code c
| None -> "None" in
Printf.sprintf "update \"%s\" \"%s\" %s %s" table col value_str cond_str
| Delete (table, cond) ->
let cond_str = match cond with
| Some c -> generate_condition_code c
| None -> "None" in
Printf.sprintf "delete \"%s\" %s" table cond_str
| DropTable name -> Printf.sprintf "drop_table \"%s\"" name
| DropDatabase name -> Printf.sprintf "drop_database \"%s\"" name
| Exit -> "exit_program ()"
and string_of_value = function
| IntValue i -> string_of_int i
| StringValue s -> Printf.sprintf "\"%s\"" s
| FloatValue f -> string_of_float f
| BoolValue b -> string_of_bool b
and string_of_data_type = function
| IntType -> "IntType"
| StringType -> "StringType"
| FloatType -> "FloatType"
| BoolType -> "BoolType"
and generate_condition_code = function
| LessThan (col, value) -> Printf.sprintf "LessThan(\"%s\", %s)" col (string_of_value value)
| GreaterThan (col, value) -> Printf.sprintf "GreaterThan(\"%s\", %s)" col (string_of_value value)
| LessEqual (col, value) -> Printf.sprintf "LessEqual(\"%s\", %s)" col (string_of_value value)
| GreaterEqual (col, value) -> Printf.sprintf "GreaterEqual(\"%s\", %s)" col (string_of_value value)
| NotEqual (col, value) -> Printf.sprintf "NotEqual(\"%s\", %s)" col (string_of_value value)
| Equal (col, value) -> Printf.sprintf "Equal(\"%s\", %s)" col (string_of_value value)
| And (c1, c2) -> Printf.sprintf "And(%s, %s)" (generate_condition_code c1) (generate_condition_code c2)
| Or (c1, c2) -> Printf.sprintf "Or(%s, %s)" (generate_condition_code c1) (generate_condition_code c2)
| Not c -> Printf.sprintf "Not(%s)" (generate_condition_code c)
| Not of condition
44 changes: 38 additions & 6 deletions src/eval.ml
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,35 @@ let show_databases () =
| exception Sys_error msg -> Printf.printf "Error: %s\n" msg

(* 将value转换为字符串 *)
let value_to_string = function
let string_of_value = function
| IntValue v -> string_of_int v
| StringValue v -> v
| FloatValue v -> string_of_float v
| BoolValue v -> string_of_bool v

let value_of_string = function
| "true" -> BoolValue true
| "false" -> BoolValue false
| s -> match int_of_string_opt s with
| Some i -> IntValue i
| None -> match float_of_string_opt s with
| Some f -> FloatValue f
| None -> StringValue s

let type_of_string string = match value_of_string string with
| IntValue _ -> IntType
| StringValue _ -> StringType
| FloatValue _ -> FloatType
| BoolValue _ -> BoolType

let type_of_data data = match data with
| IntValue _ -> IntType
| StringValue _ -> StringType
| FloatValue _ -> FloatType
| BoolValue _ -> BoolType

(* 条件表达式 *)

(* 条件表达式求值 *)
let rec eval_cond cond row headers = match cond with
| LessThan (col, value) -> (match List.assoc col (List.mapi (fun i h -> (h, i)) headers), value with
Expand Down Expand Up @@ -113,14 +136,23 @@ let insert_into table_name columns values =
let csvOut = Csv.to_channel (open_out_gen [Open_append] 0o666 table_path) in
let headers = Csv.next csvIn in
let types = List.map2 (fun h t -> (h, type_of_name t)) headers (Csv.next csvIn) in
Csv.output_record csvOut (List.map (fun header ->
List.iteri (fun row value -> Csv.output_record csvOut (List.map (fun header ->
match List.assoc_opt header (List.mapi (fun i h -> (h, i)) columns) with
| Some index -> value_to_string (List.nth values index)
| None -> value_to_string (match List.assoc header types with
| Some index -> (
let _,t = List.nth types index in
let tt = type_of_data (List.nth value index) in
if t != tt then Printf.printf "Type mismatch for row %d, column %s\n; Replaced with default value" row header;
if t == tt then string_of_value(List.nth value index)
else string_of_value (match t with
| IntType -> IntValue 0
| FloatType -> FloatValue 0.0
| StringType -> StringValue ""
| BoolType -> BoolValue false))
| None -> string_of_value (match List.assoc header types with
| IntType -> IntValue 0
| FloatType -> FloatValue 0.0
| StringType -> StringValue ""
| BoolType -> BoolValue false)) headers);
| BoolType -> BoolValue false)) headers)) values;
Csv.close_in csvIn;
Csv.close_out csvOut;
else Printf.printf "Table %s does not exist.\n" table_name
Expand Down Expand Up @@ -166,7 +198,7 @@ let update_table table_name column value condition =
let row_match_cond = match condition with
| None -> true
| Some cond -> (eval_cond cond row headers) in
if row_match_cond then List.mapi (fun j v -> if j == col_index then value_to_string value else v) row
if row_match_cond then List.mapi (fun j v -> if j == col_index then string_of_value value else v) row
else row) records in
let csv = Csv.to_channel (open_out table_path) in
Csv.output_record csv headers;
Expand Down
9 changes: 7 additions & 2 deletions src/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

main:
| statement SEMICOLON { $1 }
| EOF { Exit }

statement:
| SELECT columns FROM IDENTIFIER opt_where { Select($2, $4, $5) }
Expand All @@ -36,7 +37,7 @@ statement:
| CREATE TABLE IDENTIFIER LPAREN table_columns RPAREN { CreateTable($3, $5) }
| SHOW TABLES { ShowTables }
| SHOW DATABASES { ShowDatabases }
| INSERT INTO IDENTIFIER LPAREN columns RPAREN VALUES LPAREN values RPAREN { InsertInto($3, $5, $9) }
| INSERT INTO IDENTIFIER LPAREN columns RPAREN VALUES values { InsertInto($3, $5, $8) }
| UPDATE IDENTIFIER SET IDENTIFIER EQUALS value opt_where { Update($2, $4, $6, $7) }
| DELETE FROM IDENTIFIER opt_where { Delete($3, $4) }
| DROP TABLE IDENTIFIER { DropTable $3 }
Expand All @@ -56,7 +57,11 @@ columns:
| IDENTIFIER { [$1] }

values:
| value COMMA values { $1 :: $3 }
| LPAREN values_def RPAREN values { $2 :: $4 }
| LPAREN values_def RPAREN { [$2] }

values_def:
| value COMMA values_def { $1 :: $3 }
| value { [$1] }

value:
Expand Down

0 comments on commit 5553ed0

Please sign in to comment.