Skip to content

Commit

Permalink
[parser] Fix several lexer bugs involving special comments (pingcap#342)
Browse files Browse the repository at this point in the history
* lexer: ensure /*! ... */ follows the same SQL mode as the original parser
* lexer: forward stmtText() to specialComment parser if exists
* lexer: ensure invalid tokens in an optimizer hint won't send the parser into an infinite loop
  • Loading branch information
kennytm authored and ti-chi-bot committed Oct 9, 2021
1 parent e8b48ae commit 9ff76da
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
27 changes: 24 additions & 3 deletions parser/lexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type Scanner struct {
}

// specialCommentScanner is the contract for the nested scanners that
// startWithSlash installs into Scanner.specialComment (optimizerHintScanner
// for optimizer hints and mysqlSpecificCodeScanner for "/*!" comments).
// NOTE(review): presumably this is the declared type of the specialComment
// field; the Scanner struct is not visible here, so confirm.
type specialCommentScanner interface {
// stmtTexter is embedded so that Scanner.stmtText can forward to the active
// special-comment scanner when one is set.
stmtTexter
// scan yields a token code together with its position and literal text.
scan() (tok int, pos Pos, lit string)
}

Expand Down Expand Up @@ -85,11 +86,18 @@ func (s *optimizerHintScanner) scan() (tok int, pos Pos, lit string) {
pos.Line += s.Pos.Line
pos.Col += s.Pos.Col
pos.Offset += s.Pos.Offset
if tok == 0 {
switch tok {
case 0:
if !s.end {
tok = hintEnd
s.end = true
}
case invalid:
// an optimizer hint is allowed to contain invalid characters, the
// remaining hints are just ignored.
// force advance the lexer even when encountering an invalid character
// to prevent infinite parser loop. (see issue #336)
s.r.inc()
}
return
}
Expand All @@ -110,6 +118,10 @@ func (s *Scanner) reset(sql string) {
}

func (s *Scanner) stmtText() string {
if s.specialComment != nil {
return s.specialComment.stmtText()
}

endPos := s.r.pos().Offset
if s.r.s[endPos-1] == '\n' {
endPos = endPos - 1 // trim new line
Expand Down Expand Up @@ -220,6 +232,15 @@ func (s *Scanner) EnableWindowFunc(val bool) {
s.supportWindowFunc = val
}

// InheritScanner builds a fresh scanner over sql that carries over the
// receiver's configuration: the SQL mode and the window-function switch.
// All other state of the new scanner starts at its zero value.
func (s *Scanner) InheritScanner(sql string) *Scanner {
	child := new(Scanner)
	child.r = reader{s: sql}
	// Copy only the configuration, never the parent's reading position.
	child.sqlMode = s.sqlMode
	child.supportWindowFunc = s.supportWindowFunc
	return child
}

// NewScanner returns a new scanner object.
func NewScanner(s string) *Scanner {
return &Scanner{r: reader{s: s}}
Expand Down Expand Up @@ -396,7 +417,7 @@ func startWithSlash(s *Scanner) (tok int, pos Pos, lit string) {
end := len(comment) - 2
sql := comment[begin:end]
s.specialComment = &optimizerHintScanner{
Scanner: NewScanner(sql),
Scanner: s.InheritScanner(sql),
Pos: Pos{
pos.Line,
pos.Col,
Expand All @@ -413,7 +434,7 @@ func startWithSlash(s *Scanner) (tok int, pos Pos, lit string) {
if strings.HasPrefix(comment, "/*!") {
sql := specCodePattern.ReplaceAllStringFunc(comment, TrimComment)
s.specialComment = &mysqlSpecificCodeScanner{
Scanner: NewScanner(sql),
Scanner: s.InheritScanner(sql),
Pos: Pos{
pos.Line,
pos.Col,
Expand Down
32 changes: 32 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,38 @@ func (s *testParserSuite) TestSimple(c *C) {
c.Assert(vExpr.Kind(), Equals, types.KindUint64)
}

// TestSpecialComments exercises lexer regressions around MySQL-specific
// comments (/*! ... */) and optimizer-hint comments (/*+ ... */).
func (s *testParserSuite) TestSpecialComments(c *C) {
	// The same statement is parsed under two SQL modes below.
	const backslashInComment = `SELECT /*! '\' */;`

	p := parser.New()

	// Case 1: the content of /*! ... */ must be lexed under the parser's SQL
	// mode. The statement is rejected by default and accepted once
	// NO_BACKSLASH_ESCAPES is enabled.
	_, err := p.ParseOneStmt(backslashInComment, "", "")
	c.Assert(err, NotNil)

	p.SetSQLMode(mysql.ModeNoBackslashEscapes)
	stmt, err := p.ParseOneStmt(backslashInComment, "", "")
	c.Assert(err, IsNil)
	c.Assert(stmt, FitsTypeOf, &ast.SelectStmt{})

	// Case 2: several statements inside a single /*! ... */ must not crash
	// the parser (issue #330).
	stmts, _, err := p.Parse("/*! SET x = 1; SELECT 2 */", "", "")
	c.Assert(err, IsNil)
	c.Assert(stmts, HasLen, 2)
	first, second := stmts[0], stmts[1]
	c.Assert(first, FitsTypeOf, &ast.SetStmt{})
	c.Assert(first.Text(), Equals, "SET x = 1;")
	c.Assert(second, FitsTypeOf, &ast.SelectStmt{})
	// It is unclear whether this text is the right result: MySQL itself
	// treats multiple statements inside the comment as a syntax error.
	c.Assert(second.Text(), Equals, "/*! SET x = 1; SELECT 2 */")

	// Case 3: invalid text inside an optimizer hint must not make the parser
	// loop forever (issue #336).
	stmt, err = p.ParseOneStmt("SELECT /*+ 😅 */ SLEEP(1);", "", "")
	c.Assert(err, IsNil)
	sel, ok := stmt.(*ast.SelectStmt)
	c.Assert(ok, IsTrue)
	c.Assert(sel.TableHints, HasLen, 0)
}

type testCase struct {
src string
ok bool
Expand Down

0 comments on commit 9ff76da

Please sign in to comment.