Skip to content

Commit

Permalink
fix XiaoMi#255 key words format to lower case
Browse files Browse the repository at this point in the history
  • Loading branch information
martianzhang committed Jun 3, 2020
1 parent 0787048 commit 342ec90
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 26 deletions.
52 changes: 26 additions & 26 deletions advisor/heuristic.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (q *Query4Audit) RulePrefixLike() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch expr := node.(type) {
case *sqlparser.ComparisonExpr:
if expr.Operator == "like" {
if strings.ToLower(expr.Operator) == "like" {
switch sqlval := expr.Right.(type) {
case *sqlparser.SQLVal:
// prefix like with '%', '_'
Expand All @@ -130,7 +130,7 @@ func (q *Query4Audit) RuleEqualLike() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch expr := node.(type) {
case *sqlparser.ComparisonExpr:
if expr.Operator == "like" {
if strings.ToLower(expr.Operator) == "like" {
switch sqlval := expr.Right.(type) {
case *sqlparser.SQLVal:
// not start with '%', '_' && not end with '%', '_'
Expand Down Expand Up @@ -397,7 +397,7 @@ func (q *Query4Audit) RuleOrderByRand() Rule {
for _, order := range n {
switch expr := order.Expr.(type) {
case *sqlparser.FuncExpr:
if expr.Name.String() == "rand" {
if strings.ToLower(expr.Name.String()) == "rand" {
rule = HeuristicRules["CLA.002"]
return false, nil
}
Expand Down Expand Up @@ -761,7 +761,7 @@ func (q *Query4Audit) RuleTblCommentCheck() Rule {
var rule = q.RuleOK()
switch node := q.Stmt.(type) {
case *sqlparser.DDL:
if node.Action != "create" {
if strings.ToLower(node.Action) != "create" {
return rule
}
if node.TableSpec == nil {
Expand Down Expand Up @@ -968,7 +968,7 @@ func (q *Query4Audit) RuleSQLCalcFoundRows() Rule {
var rule = q.RuleOK()
tkns := ast.Tokenizer(q.Query)
for _, tkn := range tkns {
if tkn.Val == "sql_calc_found_rows" {
if strings.ToLower(tkn.Val) == "sql_calc_found_rows" {
rule = HeuristicRules["KWR.001"]
break
}
Expand Down Expand Up @@ -1049,7 +1049,7 @@ func (idxAdv *IndexAdvisor) RuleImpossibleOuterJoin() Rule {

for _, l1 := range idxAdv.joinCond {
for _, l2 := range l1 {
if l2.Table != "" && l2.Table != "dual" {
if l2.Table != "" && strings.ToLower(l2.Table) != "dual" {
joinTables = append(joinTables, l2.Table)
}
}
Expand Down Expand Up @@ -1192,7 +1192,7 @@ func (q *Query4Audit) RuleImpossibleWhere() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) {
case *sqlparser.RangeCond:
if n.Operator == "between" {
if strings.ToLower(n.Operator) == "between" {
from := 0
to := 0
switch s := n.From.(type) {
Expand Down Expand Up @@ -1893,7 +1893,7 @@ func (q *Query4Audit) RuleSysdate() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) {
case *sqlparser.FuncExpr:
if n.Name.String() == "sysdate" {
if strings.ToLower(n.Name.String()) == "sysdate" {
rule = HeuristicRules["FUN.004"]
return false, nil
}
Expand Down Expand Up @@ -2161,7 +2161,7 @@ func (idxAdv *IndexAdvisor) RuleUpdatePrimaryKey() Rule {
return rule
}
for _, idx := range idxMeta.Rows {
if idx.KeyName == "PRIMARY" {
if strings.ToLower(idx.KeyName) == "primary" {
if col.Name == idx.ColumnName {
rule = HeuristicRules["CLA.016"]
return rule
Expand Down Expand Up @@ -2310,7 +2310,7 @@ func (q *Query4Audit) RuleNot() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) {
case *sqlparser.ComparisonExpr:
if strings.HasPrefix(n.Operator, "not") {
if strings.HasPrefix(strings.ToLower(n.Operator), "not") {
rule = HeuristicRules["ARG.011"]
return false, nil
}
Expand Down Expand Up @@ -2359,7 +2359,7 @@ func (q *Query4Audit) RuleUNIONUsage() Rule {
var rule = q.RuleOK()
switch s := q.Stmt.(type) {
case *sqlparser.Union:
if s.Type == "union" {
if strings.ToLower(s.Type) == "union" {
rule = HeuristicRules["SUB.002"]
}
}
Expand Down Expand Up @@ -2435,11 +2435,11 @@ func (q *Query4Audit) RuleDataDrop() Rule {
var rule = q.RuleOK()
switch s := q.Stmt.(type) {
case *sqlparser.DBDDL:
if s.Action == "drop" {
if strings.ToLower(s.Action) == "drop" {
rule = HeuristicRules["SEC.003"]
}
case *sqlparser.DDL:
if s.Action == "drop" || s.Action == "truncate" {
if strings.ToLower(s.Action) == "drop" || strings.ToLower(s.Action) == "truncate" {
rule = HeuristicRules["SEC.003"]
}
case *sqlparser.Delete:
Expand Down Expand Up @@ -2523,7 +2523,7 @@ func (q *Query4Audit) RuleTruncateTable() Rule {
var rule = q.RuleOK()
switch s := q.Stmt.(type) {
case *sqlparser.DDL:
if s.Action == "truncate" {
if strings.ToLower(s.Action) == "truncate" {
rule = HeuristicRules["SEC.001"]
}
}
Expand All @@ -2536,7 +2536,7 @@ func (q *Query4Audit) RuleIn() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch n := node.(type) {
case *sqlparser.ComparisonExpr:
switch n.Operator {
switch strings.ToLower(n.Operator) {
case "in":
switch r := n.Right.(type) {
case sqlparser.ValTuple:
Expand Down Expand Up @@ -2842,12 +2842,12 @@ func (q *Query4Audit) RulePKNotInt() Rule {
var pk sqlparser.ColIdent
switch s := q.Stmt.(type) {
case *sqlparser.DDL:
if s.Action == "create" {
if strings.ToLower(s.Action) == "create" {
if s.TableSpec == nil {
return rule
}
for _, idx := range s.TableSpec.Indexes {
if idx.Info.Type == "primary key" {
if strings.ToLower(idx.Info.Type) == "primary key" {
if len(idx.Columns) == 1 {
pk = idx.Columns[0].Column
break
Expand All @@ -2864,7 +2864,7 @@ func (q *Query4Audit) RulePKNotInt() Rule {
// 主键非int, bigint类型
for _, col := range s.TableSpec.Columns {
if pk.String() == col.Name.String() {
switch col.Type.Type {
switch strings.ToLower(col.Type.Type) {
case "int", "bigint", "integer":
if !col.Type.Unsigned {
rule = HeuristicRules["KEY.007"]
Expand Down Expand Up @@ -2971,7 +2971,7 @@ func (q *Query4Audit) RuleFulltextIndex() Rule {
for _, tk := range tks {
switch tk.Type {
case ast.TokenTypeWord:
if strings.TrimSpace(strings.ToUpper(tk.Val)) == "FULLTEXT" {
if strings.TrimSpace(strings.ToLower(tk.Val)) == "fulltext" {
rule = HeuristicRules["KEY.010"]
}
default:
Expand Down Expand Up @@ -3001,8 +3001,8 @@ func (q *Query4Audit) RuleTimestampDefault() Rule {
if option.Tp == tidb.ColumnOptionDefaultValue {
hasDefault = true
if err := option.Restore(ctx); err == nil {
if strings.HasPrefix(sb.String(), `DEFAULT '0`) ||
strings.HasPrefix(sb.String(), `DEFAULT 0`) {
if strings.HasPrefix(strings.ToLower(sb.String()), `default '0`) ||
strings.HasPrefix(strings.ToLower(sb.String()), `default 0`) {
hasDefault = false
}
}
Expand Down Expand Up @@ -3034,8 +3034,8 @@ func (q *Query4Audit) RuleTimestampDefault() Rule {
if option.Tp == tidb.ColumnOptionDefaultValue {
hasDefault = true
if err := option.Restore(ctx); err == nil {
if strings.HasPrefix(sb.String(), `DEFAULT '0`) ||
strings.HasPrefix(sb.String(), `DEFAULT 0`) {
if strings.HasPrefix(strings.ToLower(sb.String()), `default '0`) ||
strings.HasPrefix(strings.ToLower(sb.String()), `default 0`) {
hasDefault = false
}
}
Expand Down Expand Up @@ -3464,7 +3464,7 @@ func (q *Query4Audit) RuleColumnNotAllowType() Rule {

switch s := q.Stmt.(type) {
case *sqlparser.DDL:
switch s.Action {
switch strings.ToLower(s.Action) {
case "create", "alter":
tks := ast.Tokenize(q.Query)
for _, tk := range tks {
Expand Down Expand Up @@ -3536,7 +3536,7 @@ func (q *Query4Audit) RuleNoOSCKey() Rule {
var rule = q.RuleOK()
switch s := q.Stmt.(type) {
case *sqlparser.DDL:
if s.Action == "create" {
if strings.ToLower(s.Action) == "create" {
pkReg := regexp.MustCompile(`(?i)(primary\s+key)`)
if !pkReg.MatchString(q.Query) {
ukReg := regexp.MustCompile(`(?i)(unique\s+((key)|(index)))`)
Expand Down Expand Up @@ -3605,7 +3605,7 @@ func (idxAdv *IndexAdvisor) RuleMaxTextColsCount() Rule {
err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch stmt := node.(type) {
case *sqlparser.DDL:
if stmt.Action != "alter" {
if strings.ToLower(stmt.Action) != "alter" {
return true, nil
}

Expand Down
4 changes: 4 additions & 0 deletions advisor/heuristic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,7 @@ func TestRuleSysdate(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{
`select sysdate();`,
`select Sysdate();`,
}
for _, sql := range sqls {
q, err := NewQuery4Audit(sql)
Expand Down Expand Up @@ -2435,6 +2436,7 @@ func TestRuleInjection(t *testing.T) {
{
`select benchmark(10, rand())`,
`select sleep(1)`,
`select Sleep(1)`,
`select get_lock('lock_name', 1)`,
`select release_lock('lock_name')`,
},
Expand Down Expand Up @@ -2542,6 +2544,7 @@ func TestRuleTruncateTable(t *testing.T) {
common.Log.Debug("Entering function: %s", common.GetFunctionName())
sqls := []string{
`TRUNCATE TABLE tbl_name;`,
`truncate TABLE tbl_name;`,
}
for _, sql := range sqls {
q, err := NewQuery4Audit(sql)
Expand Down Expand Up @@ -2861,6 +2864,7 @@ func TestRulePKNotInt(t *testing.T) {
},
{
"CREATE TABLE tbl (a int unsigned auto_increment, b int, primary key(`a`)) engine=InnoDB;",
"CREATE TABLE `tb` ( `id` Bigint unsigned NOT NULL AUTO_INCREMENT COMMENT 'auto id', Primary key (`id`) ) ENGINE = InnoDB COMMENT 'comment'",
},
}
for _, sql := range sqls[0] {
Expand Down

0 comments on commit 342ec90

Please sign in to comment.