Skip to content

Commit

Permalink
*: add specified columns for LOAD DATA INFILE Syntax (pingcap#3240)
Browse files Browse the repository at this point in the history
* : add specified columns for LOAD DATA INFILE Syntax
  • Loading branch information
IANTHEREAL authored May 15, 2017
1 parent 9b8a445 commit 213954d
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 14 deletions.
8 changes: 8 additions & 0 deletions ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ type LoadDataStmt struct {
IsLocal bool
Path string
Table *TableName
Columns []*ColumnName
FieldsInfo *FieldsClause
LinesInfo *LinesClause
}
Expand All @@ -675,6 +676,13 @@ func (n *LoadDataStmt) Accept(v Visitor) (Node, bool) {
}
n.Table = node.(*TableName)
}
for i, val := range n.Columns {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Columns[i] = node.(*ColumnName)
}
return v.Leave(n)
}

Expand Down
5 changes: 5 additions & 0 deletions context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ type Context interface {

GetSessionManager() util.SessionManager

// RefreshTxnCtx commits the old transaction without retry,
// and creates a new transaction.
// For now it is only used by load data and batch insert.
RefreshTxnCtx() error

// ActivePendingTxn receives the pending transaction from the transaction channel.
// It should be called right before we build an executor.
ActivePendingTxn() error
Expand Down
12 changes: 10 additions & 2 deletions executor/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,17 +282,25 @@ func (b *executorBuilder) buildLoadData(v *plan.LoadData) Executor {
b.err = errors.Errorf("Can not get table %d", v.Table.TableInfo.ID)
return nil
}
insertVal := &InsertValues{ctx: b.ctx, Table: tbl, Columns: v.Columns}
tableCols := tbl.WritableCols()
columns, err := insertVal.getColumns(tableCols)
if err != nil {
b.err = errors.Trace(err)
return nil
}

return &LoadData{
IsLocal: v.IsLocal,
loadDataInfo: &LoadDataInfo{
row: make([]types.Datum, len(tbl.Cols())),
insertVal: &InsertValues{ctx: b.ctx, Table: tbl},
row: make([]types.Datum, len(columns)),
insertVal: insertVal,
Path: v.Path,
Table: tbl,
FieldsInfo: v.FieldsInfo,
LinesInfo: v.LinesInfo,
Ctx: b.ctx,
columns: columns,
},
}
}
Expand Down
6 changes: 4 additions & 2 deletions executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,13 @@ func (e *DeleteExec) Open() error {
}

// NewLoadDataInfo returns a LoadDataInfo structure, and it's only used for tests now.
func NewLoadDataInfo(row []types.Datum, ctx context.Context, tbl table.Table) *LoadDataInfo {
func NewLoadDataInfo(row []types.Datum, ctx context.Context, tbl table.Table, cols []*table.Column) *LoadDataInfo {
return &LoadDataInfo{
row: row,
insertVal: &InsertValues{ctx: ctx, Table: tbl},
Table: tbl,
Ctx: ctx,
columns: cols,
}
}

Expand All @@ -290,6 +291,7 @@ type LoadDataInfo struct {
FieldsInfo *ast.FieldsClause
LinesInfo *ast.LinesClause
Ctx context.Context
columns []*table.Column
}

// SetBatchCount sets the number of rows to insert in a batch.
Expand Down Expand Up @@ -506,7 +508,7 @@ func (e *LoadDataInfo) insertData(cols []string) {
}
e.row[i].SetString(cols[i])
}
row, err := e.insertVal.fillRowData(e.Table.Cols(), e.row, true)
row, err := e.insertVal.fillRowData(e.columns, e.row, true)
if err != nil {
warnLog := fmt.Sprintf("Load Data: insert data:%v failed:%v", e.row, errors.ErrorStack(err))
e.insertVal.handleLoadDataWarnings(err, warnLog)
Expand Down
42 changes: 38 additions & 4 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/types"
Expand Down Expand Up @@ -631,7 +632,7 @@ func (s *testSuite) TestLoadData(c *C) {
c.Assert(err, NotNil)
tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test")
ctx := tk.Se.(context.Context)
ld := makeLoadDataInfo(4, ctx, c)
ld := makeLoadDataInfo(4, nil, ctx, c)

deleteSQL := "delete from load_data_test"
selectSQL := "select * from load_data_test;"
Expand Down Expand Up @@ -789,7 +790,7 @@ func (s *testSuite) TestLoadDataEscape(c *C) {
tk.MustExec("CREATE TABLE load_data_test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
tk.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test")
ctx := tk.Se.(context.Context)
ld := makeLoadDataInfo(2, ctx, c)
ld := makeLoadDataInfo(2, nil, ctx, c)
// test escape
tests := []testCase{
// data1 = nil, data2 != nil
Expand All @@ -805,15 +806,48 @@ func (s *testSuite) TestLoadDataEscape(c *C) {
checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL)
}

func makeLoadDataInfo(column int, ctx context.Context, c *C) (ld *executor.LoadDataInfo) {
// TestLoadDataSpecifiedCoumns exercises LOAD DATA with an explicit column
// list (c1, c2). It reuses the test cases of TestLoadDataEscape; every
// expected row additionally carries an id value and c3 = 0 for the two
// columns that are not part of the loaded column list.
func (s *testSuite) TestLoadDataSpecifiedCoumns(c *C) {
	defer func() {
		s.cleanEnv(c)
		testleak.AfterTest(c)()
	}()

	selectSQL := "select * from load_data_test;"
	deleteSQL := "delete from load_data_test"

	kit := testkit.NewTestKit(c, s.store)
	kit.MustExec("use test; drop table if exists load_data_test;")
	kit.MustExec(`create table load_data_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int, c2 varchar(255) default "def", c3 int default 0);`)
	kit.MustExec("load data local infile '/tmp/nonexistence.csv' into table load_data_test (c1, c2)")
	sctx := kit.Se.(context.Context)
	// Only two datums per row are needed: one for c1 and one for c2.
	info := makeLoadDataInfo(2, []string{"c1", "c2"}, sctx, c)

	// Expected value of c2 for the case that feeds every supported escape sequence.
	ctrlChars := string([]byte{'\r', '\t', '\n', 0, 26, '\b'})

	// In every case: data1 = nil, data2 != nil.
	cases := []testCase{
		{nil, []byte("7\ta string\n"), []string{"1|7|a string|0"}, nil},
		{nil, []byte("8\tstr \\t\n"), []string{"2|8|str \t|0"}, nil},
		{nil, []byte("9\tstr \\n\n"), []string{"3|9|str \n|0"}, nil},
		{nil, []byte("10\tboth \\t\\n\n"), []string{"4|10|both \t\n|0"}, nil},
		{nil, []byte("11\tstr \\\\\n"), []string{"5|11|str \\|0"}, nil},
		{nil, []byte("12\t\\r\\t\\n\\0\\Z\\b\n"), []string{"6|12|" + ctrlChars + "|0"}, nil},
	}
	checkCases(cases, info, c, kit, sctx, selectSQL, deleteSQL)
}

func makeLoadDataInfo(column int, specifiedColumns []string, ctx context.Context, c *C) (ld *executor.LoadDataInfo) {
domain := sessionctx.GetDomain(ctx)
is := domain.InfoSchema()
c.Assert(is, NotNil)
tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("load_data_test"))
c.Assert(err, IsNil)
columns := tbl.Cols()
// filter specified columns
if len(specifiedColumns) > 0 {
columns, err = table.FindCols(columns, specifiedColumns)
c.Assert(err, IsNil)
}
fields := &ast.FieldsClause{Terminated: "\t"}
lines := &ast.LinesClause{Starting: "", Terminated: "\n"}
ld = executor.NewLoadDataInfo(make([]types.Datum, column), ctx, tbl)
ld = executor.NewLoadDataInfo(make([]types.Datum, column), ctx, tbl, columns)
ld.SetBatchCount(0)
ld.FieldsInfo = fields
ld.LinesInfo = lines
Expand Down
14 changes: 13 additions & 1 deletion parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ import (
ColumnName "column name"
ColumnNameList "column name list"
ColumnNameListOpt "column name list opt"
ColumnNameListOptWithBrackets "column name list opt with brackets"
ColumnSetValue "insert statement set value by column name"
ColumnSetValueList "insert statement set value by column name list"
CommitStmt "COMMIT statement"
Expand Down Expand Up @@ -1218,6 +1219,16 @@ ColumnNameListOpt:
$$ = $1.([]*ast.ColumnName)
}

/* ColumnNameListOptWithBrackets is an optional ColumnNameListOpt enclosed in parentheses. */
ColumnNameListOptWithBrackets:
/* EMPTY */
{
// No parenthesized list was given: yield an empty, non-nil column slice.
$$ = []*ast.ColumnName{}
}
| '(' ColumnNameListOpt ')'
{
// Pass through the column names parsed between the parentheses.
$$ = $2.([]*ast.ColumnName)
}

CommitStmt:
"COMMIT"
{
Expand Down Expand Up @@ -6272,11 +6283,12 @@ RevokeStmt:
* See https://dev.mysql.com/doc/refman/5.7/en/load-data.html
*******************************************************************************************/
LoadDataStmt:
"LOAD" "DATA" LocalOpt "INFILE" stringLit "INTO" "TABLE" TableName Fields Lines
"LOAD" "DATA" LocalOpt "INFILE" stringLit "INTO" "TABLE" TableName Fields Lines ColumnNameListOptWithBrackets
{
x := &ast.LoadDataStmt{
Path: $5,
Table: $8.(*ast.TableName),
Columns: $11.([]*ast.ColumnName),
}
if $3 != nil {
x.IsLocal = true
Expand Down
10 changes: 10 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,16 @@ func (s *testParserSuite) TestDMLStmt(c *C) {
{"load data local infile '/tmp/t.csv' into table t lines starting by 'ab' terminated by 'xy'", true},
{"load data local infile '/tmp/t.csv' into table t fields terminated by 'ab' lines terminated by 'xy'", true},
{"load data local infile '/tmp/t.csv' into table t terminated by 'xy' fields terminated by 'ab'", false},
{"load data infile '/tmp/t.csv' into table t (a,b)", true},
{"load data local infile '/tmp/t.csv' into table t (a,b)", true},
{"load data local infile '/tmp/t.csv' into table t fields terminated by 'ab' (a,b)", true},
{"load data local infile '/tmp/t.csv' into table t columns terminated by 'ab' (a,b)", true},
{"load data local infile '/tmp/t.csv' into table t fields terminated by 'ab' enclosed by 'b' (a,b)", true},
{"load data local infile '/tmp/t.csv' into table t fields terminated by 'ab' enclosed by 'b' escaped by '*' (a,b)", true},
{"load data local infile '/tmp/t.csv' into table t lines starting by 'ab' (a,b)", true},
{"load data local infile '/tmp/t.csv' into table t lines starting by 'ab' terminated by 'xy' (a,b)", true},
{"load data local infile '/tmp/t.csv' into table t fields terminated by 'ab' lines terminated by 'xy' (a,b)", true},
{"load data local infile '/tmp/t.csv' into table t (a,b) fields terminated by 'ab'", false},

// select for update
{"SELECT * from t for update", true},
Expand Down
1 change: 1 addition & 0 deletions plan/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ func (b *planBuilder) buildLoadData(ld *ast.LoadDataStmt) Plan {
IsLocal: ld.IsLocal,
Path: ld.Path,
Table: ld.Table,
Columns: ld.Columns,
FieldsInfo: ld.FieldsInfo,
LinesInfo: ld.LinesInfo,
}
Expand Down
1 change: 1 addition & 0 deletions plan/plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ type LoadData struct {
IsLocal bool
Path string
Table *ast.TableName
Columns []*ast.ColumnName
FieldsInfo *ast.FieldsClause
LinesInfo *ast.LinesClause
}
Expand Down
7 changes: 2 additions & 5 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,10 +593,7 @@ func insertDataWithCommit(prevData, curData []byte, loadDataInfo *executor.LoadD
break
}
// Make sure that there are no retries when committing.
if err = loadDataInfo.Ctx.Txn().Commit(); err != nil {
return nil, errors.Trace(err)
}
if err = loadDataInfo.Ctx.NewTxn(); err != nil {
if err = loadDataInfo.Ctx.RefreshTxnCtx(); err != nil {
return nil, errors.Trace(err)
}
curData = prevData
Expand Down Expand Up @@ -654,7 +651,7 @@ func (cc *clientConn) handleLoadData(loadDataInfo *executor.LoadDataInfo) error

txn := loadDataInfo.Ctx.Txn()
if err != nil {
if txn.Valid() {
if txn != nil && txn.Valid() {
if err1 := txn.Rollback(); err1 != nil {
log.Errorf("load data rollback failed: %v", err1)
}
Expand Down
9 changes: 9 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,15 @@ func (s *session) prepareTxnCtx() {
}
}

// RefreshTxnCtx implements the RefreshTxnCtx method of the context.Context interface.
// It commits the session's current transaction through doCommit and, only when
// that succeeds, starts a fresh transaction with NewTxn.
func (s *session) RefreshTxnCtx() error {
	err := s.doCommit()
	if err != nil {
		return errors.Trace(err)
	}
	return errors.Trace(s.NewTxn())
}

// ActivePendingTxn implements Context.ActivePendingTxn interface.
func (s *session) ActivePendingTxn() error {
if s.txn != nil && s.txn.Valid() {
Expand Down
5 changes: 5 additions & 0 deletions util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ func (c *Context) NewTxn() error {
return nil
}

// RefreshTxnCtx implements the context.Context interface.
// The mock only starts a new transaction through NewTxn; it does not commit
// anything beforehand.
func (c *Context) RefreshTxnCtx() error {
	err := c.NewTxn()
	return errors.Trace(err)
}

// ActivePendingTxn implements the context.Context interface.
func (c *Context) ActivePendingTxn() error {
if c.txn != nil {
Expand Down

0 comments on commit 213954d

Please sign in to comment.