Skip to content

Commit

Permalink
Escape invalid names
Browse files Browse the repository at this point in the history
  • Loading branch information
NiseVoid committed Sep 15, 2020
1 parent 03bf7cc commit 76ea5be
Show file tree
Hide file tree
Showing 16 changed files with 143 additions and 101 deletions.
4 changes: 0 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ all: setup test lint

.PHONY: setup
setup:
@go install ./qb-generator
@printf "Running go generate ...\n"
@go generate ./...
@printf "Getting dependencies ...\n"
@go get -t ./...
@printf "\n\n"

Expand Down
7 changes: 6 additions & 1 deletion driver/msqb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,19 @@ func (d Driver) ValueString(i int) string {
return `@p` + strconv.Itoa(i)
}

// BoolString formats a boolean in a format supported by PostgreSQL
// BoolString formats a boolean in a format supported by MSSQL
func (d Driver) BoolString(v bool) string {
if v {
return `1`
}
return `0`
}

// EscapeCharacter returns the correct escape character for MSSQL
func (d Driver) EscapeCharacter() string {
return `"`
}

// UpsertSQL implements qb.Driver
func (d Driver) UpsertSQL(t *qb.Table, _ []qb.Field, q qb.Query) (string, []interface{}) {
panic(`mssql does not support upsert`)
Expand Down
7 changes: 6 additions & 1 deletion driver/myqb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@ func (d Driver) ValueString(_ int) string {
return `?`
}

// BoolString formats a boolean in a format supported by PostgreSQL
// BoolString formats a boolean in a format supported by MySQL
func (d Driver) BoolString(v bool) string {
if v {
return `1`
}
return `0`
}

// EscapeCharacter returns the correct escape character for MySQL
func (d Driver) EscapeCharacter() string {
return "`"
}

// UpsertSQL implements qb.Driver
func (d Driver) UpsertSQL(t *qb.Table, _ []qb.Field, q qb.Query) (string, []interface{}) {
usql, values := q.SQL(qb.NewSQLBuilder(d))
Expand Down
5 changes: 5 additions & 0 deletions driver/pgqb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ func (d Driver) BoolString(v bool) string {
return strconv.FormatBool(v)
}

// EscapeCharacter returns the correct escape character for PostgreSQL
func (d Driver) EscapeCharacter() string {
return `"`
}

// UpsertSQL implements qb.Driver
func (d Driver) UpsertSQL(t *qb.Table, conflict []qb.Field, q qb.Query) (string, []interface{}) {
c := qb.NewContext(d, qb.NoAlias())
Expand Down
8 changes: 6 additions & 2 deletions internal/tests/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ func initDatabase(driverName, connectionString string) *sql.DB {
panic(err)
}

_, _ = db.Exec(`DROP TABLE one, two`)

dropQuery := `DROP TABLE one, "two $#!"`
sql := createSQL
if driverName != `postgres` {
sql = strings.Replace(sql, `timestamp`, `datetime`, -1)
}
if driverName == `mysql` {
sql = strings.Replace(sql, `"`, "`", -1)
dropQuery = strings.Replace(dropQuery, `"`, "`", -1)
}

_, _ = db.Exec(dropQuery)
_, err = db.Exec(sql)
if err != nil {
panic(err)
Expand Down
7 changes: 4 additions & 3 deletions internal/tests/dbstring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ func getMssqlDBString() string {
const createSQL = `
CREATE TABLE one (
ID int PRIMARY KEY,
Name varchar(50) NOT NULL,
CreatedAt timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP
"$Name. #()" varchar(50) NOT NULL,
created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE two (
CREATE TABLE "two $#!" (
OneID int,
Number int,
Comment varchar(100) NOT NULL,
Expand Down
6 changes: 3 additions & 3 deletions internal/tests/internal/model/db.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
"name": "one",
"fields": [
{"name": "ID", "read_only": true, "data_type": "int"},
{"name": "Name", "data_type": "varchar", "size": 50},
{"name": "CreatedAt", "read_only": true, "data_type": "timestamp"}
{"name": "$Name. #()", "data_type": "varchar", "size": 50},
{"name": "created_at", "read_only": true, "data_type": "timestamp"}
]
},
{
"name": "two",
"name": "two $#!",
"alias": "tw",
"fields": [
{"name": "OneID", "data_type": "int"},
Expand Down
18 changes: 9 additions & 9 deletions internal/tests/internal/model/tables.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 14 additions & 14 deletions qb-architect/internal/db/msarchitect/msmodel/tables.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 11 additions & 11 deletions qb-architect/internal/db/myarchitect/mymodel/tables.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 18 additions & 18 deletions qb-architect/internal/db/pgarchitect/pgmodel/tables.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 22 additions & 18 deletions qb-generator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package main // import "git.ultraware.nl/NiseVoid/qb/qb-generator"
import (
"encoding/json"
"flag"
"html/template"
"io"
"log"
"os"
"os/exec"
"regexp"
"strings"
"text/template"

"git.ultraware.nl/NiseVoid/qb/internal/filter"
)
Expand All @@ -31,12 +32,14 @@ type table struct {
Table string
TableString string
Alias string
Escape bool
Fields []field
}

type field struct {
Name string
String string
Escape bool
ReadOnly bool
DataType dataType
}
Expand Down Expand Up @@ -120,31 +123,31 @@ func getDataType(t string, size int, null bool) dataType {
return dataType{``, size, null}
}

var escapeRE = regexp.MustCompile(`[^a-zA-Z0-9_$]`)

func shouldEscape(s string) bool {
return escapeRE.MatchString(s)
}

func newField(f inputField) field {
return field{cleanName(f.String, false, fTrim), f.String, f.ReadOnly, getDataType(f.DataType, f.Size, f.Nullable)}
return field{cleanName(f.String, fTrim), f.String, shouldEscape(f.String), f.ReadOnly, getDataType(f.DataType, f.Size, f.Nullable)}
}

var nameReplacer = strings.NewReplacer(
` `, `_`,
`-`, `_`,
`$`, `_`,
`.`, `_`,
)
func removeSchema(s string) string {
parts := strings.Split(s, `.`)
return parts[len(parts)-1]
}

func cleanName(s string, trimSchema bool, f filter.Filters) string {
if trimSchema {
parts := strings.Split(s, `.`)
s = parts[len(parts)-1]
}
var nameRE = regexp.MustCompile(`[^a-zA-Z0-9_]`)

target := s
func cleanName(s string, f filter.Filters) string {
for _, re := range f {
target = re.ReplaceAllString(target, ``)
s = re.ReplaceAllString(s, ``)
}

target = nameReplacer.Replace(target)
s = nameRE.ReplaceAllString(s, `_`)

parts := strings.Split(target, `_`)
parts := strings.Split(s, `_`)
for k := range parts {
upper := false
for _, v := range fullUpperList {
Expand Down Expand Up @@ -219,9 +222,10 @@ func generateCode(out io.Writer, input []inputTable) error {
tables := make([]table, len(input))
for k, v := range input {
t := &tables[k]
t.Table = cleanName(v.String, true, tTrim)
t.Table = cleanName(removeSchema(v.String), tTrim)
t.Alias = v.Alias
t.TableString = v.String
t.Escape = shouldEscape(removeSchema(v.String))

for _, f := range v.Fields {
t.Fields = append(t.Fields, newField(f))
Expand Down
Loading

0 comments on commit 76ea5be

Please sign in to comment.