Skip to content

Commit

Permalink
bugfix: InsertExecutor afterImage (apache#1512)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsbxyyx authored and slievrly committed Sep 5, 2019
1 parent f99fc51 commit 0cd1d98
Show file tree
Hide file tree
Showing 9 changed files with 830 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import io.seata.rm.datasource.sql.SQLRecognizer;
import io.seata.rm.datasource.sql.struct.ColumnMeta;
import io.seata.rm.datasource.sql.struct.Null;
import io.seata.rm.datasource.sql.struct.SqlMethodExpr;
import io.seata.rm.datasource.sql.struct.SqlSequenceExpr;
import io.seata.rm.datasource.sql.struct.TableMeta;
import io.seata.rm.datasource.sql.struct.TableRecords;
import org.slf4j.Logger;
Expand All @@ -49,6 +51,8 @@ public class InsertExecutor<T, S extends Statement> extends AbstractDMLBaseExecu
private static final Logger LOGGER = LoggerFactory.getLogger(InsertExecutor.class);
protected static final String ERR_SQL_STATE = "S1009";

private static final String PLACEHOLDER = "?";

/**
* Instantiates a new Insert executor.
*
Expand All @@ -69,7 +73,8 @@ protected TableRecords beforeImage() throws SQLException {
@Override
protected TableRecords afterImage(TableRecords beforeImage) throws SQLException {
//Pk column exists or PK is just auto generated
List<Object> pkValues = containsPK() ? getPkValuesByColumn() : getPkValuesByAuto();
List<Object> pkValues = containsPK() ? getPkValuesByColumn() :
(containsColumns() ? getPkValuesByAuto() : getPkValuesByColumn());

TableRecords afterImage = buildTableRecords(pkValues);

Expand All @@ -87,56 +92,106 @@ protected boolean containsPK() {
return tmeta.containsPK(insertColumns);
}

protected boolean containsColumns() {
SQLInsertRecognizer recognizer = (SQLInsertRecognizer) sqlRecognizer;
List<String> insertColumns = recognizer.getInsertColumns();
return insertColumns != null && !insertColumns.isEmpty();
}

protected List<Object> getPkValuesByColumn() throws SQLException {
// insert values including PK
SQLInsertRecognizer recognizer = (SQLInsertRecognizer) sqlRecognizer;
List<String> insertColumns = recognizer.getInsertColumns();
String pk = getTableMeta().getPkName();
final int pkIndex = getPkIndex();
List<Object> pkValues = null;
if (statementProxy instanceof PreparedStatementProxy) {
PreparedStatementProxy preparedStatementProxy = (PreparedStatementProxy) statementProxy;
ArrayList<Object>[] paramters = preparedStatementProxy.getParameters();
int insertColumnsSize = insertColumns.size();
int cycleNums = paramters.length / insertColumnsSize;
List<Integer> pkIndexs = new ArrayList<>(cycleNums);
int firstPkIndex = 0;
for (int paramIdx = 0; paramIdx < insertColumns.size(); paramIdx++) {
if (insertColumns.get(paramIdx).equalsIgnoreCase(pk)) {
firstPkIndex = paramIdx;
break;

List<List<Object>> insertRows = recognizer.getInsertRows();
if (insertRows != null && !insertRows.isEmpty()) {
ArrayList<Object>[] parameters = preparedStatementProxy.getParameters();
final int rowSize = insertRows.size();

if (rowSize == 1) {
Object pkValue = insertRows.get(0).get(pkIndex);
if (PLACEHOLDER.equals(pkValue)) {
pkValues = parameters[pkIndex];
} else {
int finalPkIndex = pkIndex;
pkValues = insertRows.stream().map(insertRow -> insertRow.get(finalPkIndex)).collect(Collectors.toList());
}
} else {
int totalPlaceholderNum = -1;
pkValues = new ArrayList<>(rowSize);
for (int i = 0; i < rowSize; i++) {
List<Object> row = insertRows.get(i);
Object pkValue = row.get(pkIndex);
int currentRowPlaceholderNum = -1;
for (Object r : row) {
if (PLACEHOLDER.equals(r)) {
totalPlaceholderNum += 1;
currentRowPlaceholderNum += 1;
}
}
if (PLACEHOLDER.equals(pkValue)) {
int idx = pkIndex;
if (i != 0) {
idx = totalPlaceholderNum - currentRowPlaceholderNum + pkIndex;
}
ArrayList<Object> parameter = parameters[idx];
for (Object obj : parameter) {
pkValues.add(obj);
}
} else {
pkValues.add(pkValue);
}
}
}
}
for (int i = 0; i < cycleNums; i++) {
pkIndexs.add(insertColumnsSize * i + firstPkIndex);
}
if (pkIndexs.size() == 1) {
//adapter test case
pkValues = preparedStatementProxy.getParamsByIndex(pkIndexs.get(0));
} else {
pkValues = pkIndexs.stream().map(pkIndex -> paramters[pkIndex].get(0)).collect(Collectors.toList());
}
} else {
for (int paramIdx = 0; paramIdx < insertColumns.size(); paramIdx++) {
if (insertColumns.get(paramIdx).equalsIgnoreCase(pk)) {
List<List<Object>> insertRows = recognizer.getInsertRows();
pkValues = new ArrayList<>(insertRows.size());
for (List<Object> row : insertRows) {
pkValues.add(row.get(paramIdx));
}
break;
}
List<List<Object>> insertRows = recognizer.getInsertRows();
pkValues = new ArrayList<>(insertRows.size());
for (List<Object> row : insertRows) {
pkValues.add(row.get(pkIndex));
}
}
if (pkValues == null) {
throw new ShouldNeverHappenException();
}
//pk auto generated while column exists and value is null
if (pkValues.size() == 1 && pkValues.get(0) instanceof Null) {
boolean b = this.checkPkValues(pkValues);
if (!b) {
throw new NotSupportYetException("not support sql [" + sqlRecognizer.getOriginalSQL() + "]");
}
if (pkValues.size() == 1 && pkValues.get(0) instanceof SqlSequenceExpr) {
pkValues = getPkValuesBySequence(pkValues.get(0));
}
// pk auto generated while single insert primary key is expression
else if (pkValues.size() == 1 && pkValues.get(0) instanceof SqlMethodExpr) {
pkValues = getPkValuesByAuto();
}
// pk auto generated while column exists and value is null
else if (pkValues.size() > 0 && pkValues.get(0) instanceof Null) {
pkValues = getPkValuesByAuto();
}
return pkValues;
}

protected List<Object> getPkValuesBySequence(Object expr) throws SQLException {
ResultSet genKeys = null;
if (expr instanceof SqlSequenceExpr) {
SqlSequenceExpr sequenceExpr = (SqlSequenceExpr) expr;
final String sql = "SELECT " + sequenceExpr.getSequence() + ".currval FROM DUAL";
LOGGER.warn("Fail to get auto-generated keys, use \'{}\' instead. Be cautious, statement could be polluted. Recommend you set the statement to return generated keys.", sql);
genKeys = statementProxy.getConnection().createStatement().executeQuery(sql);
} else {
throw new NotSupportYetException(String.format("not support expr [%s]", expr.getClass().getName()));
}
List<Object> pkValues = new ArrayList<>();
while (genKeys.next()) {
Object v = genKeys.getObject(1);
pkValues.add(v);
}
return pkValues;
}

protected List<Object> getPkValuesByAuto() throws SQLException {
// PK is just auto generated
Expand Down Expand Up @@ -170,4 +225,66 @@ protected List<Object> getPkValuesByAuto() throws SQLException {
}
return pkValues;
}

/**
* get pk index
* @return -1 not found pk index
*/
protected int getPkIndex() {
SQLInsertRecognizer recognizer = (SQLInsertRecognizer) sqlRecognizer;
String pkName = getTableMeta().getPkName();
List<String> insertColumns = recognizer.getInsertColumns();
if (insertColumns != null && !insertColumns.isEmpty()) {
final int insertColumnsSize = insertColumns.size();
int pkIndex = -1;
for (int paramIdx = 0; paramIdx < insertColumnsSize; paramIdx++) {
if (insertColumns.get(paramIdx).equalsIgnoreCase(pkName)) {
pkIndex = paramIdx;
break;
}
}
return pkIndex;
}
int pkIndex = -1;
Map<String, ColumnMeta> allColumns = getTableMeta().getAllColumns();
for (Map.Entry<String, ColumnMeta> entry : allColumns.entrySet()) {
pkIndex++;
if (entry.getValue().getColumnName().equalsIgnoreCase(pkName)) {
break;
}
}
return pkIndex;
}

/**
* check pk values
* @param pkValues
* @return true support false not support
*/
private boolean checkPkValues(List<Object> pkValues) {
boolean pkParameterHasNull = false;
boolean pkParameterHasNotNull = false;
boolean pkParameterHasExpr = false;
if (pkValues.size() == 1) {
return true;
}
for (Object pkValue : pkValues) {
if (pkValue instanceof Null) {
pkParameterHasNull = true;
continue;
}
pkParameterHasNotNull = true;
if (pkValue instanceof SqlMethodExpr) {
pkParameterHasExpr = true;
}
}
if (pkParameterHasExpr) {
return false;
}
if (pkParameterHasNull && pkParameterHasNotNull) {
return false;
}
return true;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.expr.SQLNullExpr;
import com.alibaba.druid.sql.ast.expr.SQLValuableExpr;
import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlOutputVisitor;
import io.seata.rm.datasource.sql.SQLInsertRecognizer;
import io.seata.rm.datasource.sql.SQLParsingException;
import io.seata.rm.datasource.sql.SQLType;
import io.seata.rm.datasource.sql.struct.Null;
import io.seata.rm.datasource.sql.struct.SqlMethodExpr;

/**
* The type My sql insert recognizer.
Expand Down Expand Up @@ -102,8 +107,14 @@ public List<List<Object>> getInsertRows() {
List<Object> row = new ArrayList<>(exprs.size());
rows.add(row);
for (SQLExpr expr : valuesClause.getValues()) {
if (expr instanceof SQLValuableExpr) {
if (expr instanceof SQLNullExpr) {
row.add(Null.get());
} else if (expr instanceof SQLValuableExpr) {
row.add(((SQLValuableExpr)expr).getValue());
} else if (expr instanceof SQLVariantRefExpr) {
row.add(((SQLVariantRefExpr)expr).getName());
} else if (expr instanceof SQLMethodInvokeExpr) {
row.add(new SqlMethodExpr());
} else {
throw new SQLParsingException("Unknown SQLExpr: " + expr.getClass() + " " + expr);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.expr.SQLNullExpr;
import com.alibaba.druid.sql.ast.expr.SQLSequenceExpr;
import com.alibaba.druid.sql.ast.expr.SQLValuableExpr;
import com.alibaba.druid.sql.ast.expr.SQLVariantRefExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement;
import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleInsertStatement;
Expand All @@ -27,6 +31,9 @@
import io.seata.rm.datasource.sql.SQLParsingException;
import io.seata.rm.datasource.sql.SQLType;
import io.seata.rm.datasource.sql.druid.BaseRecognizer;
import io.seata.rm.datasource.sql.struct.Null;
import io.seata.rm.datasource.sql.struct.SqlMethodExpr;
import io.seata.rm.datasource.sql.struct.SqlSequenceExpr;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -103,8 +110,19 @@ public List<List<Object>> getInsertRows() {
List<Object> row = new ArrayList<>(exprs.size());
rows.add(row);
for (SQLExpr expr : valuesClause.getValues()) {
if (expr instanceof SQLValuableExpr) {
if (expr instanceof SQLNullExpr) {
row.add(Null.get());
} else if (expr instanceof SQLValuableExpr) {
row.add(((SQLValuableExpr)expr).getValue());
} else if (expr instanceof SQLVariantRefExpr) {
row.add(((SQLVariantRefExpr)expr).getName());
} else if (expr instanceof SQLMethodInvokeExpr) {
row.add(new SqlMethodExpr());
} else if (expr instanceof SQLSequenceExpr) {
SQLSequenceExpr sequenceExpr = ((SQLSequenceExpr) expr);
String sequence = sequenceExpr.getSequence().getSimpleName();
String function = sequenceExpr.getFunction().name;
row.add(new SqlSequenceExpr(sequence, function));
} else {
throw new SQLParsingException("Unknown SQLExpr: " + expr.getClass() + " " + expr);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright 1999-2019 Seata.io Group.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.seata.rm.datasource.sql.struct;

/**
* TODO
* sql method invoke expression
* @author jsbxyyx
*/
public class SqlMethodExpr {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 1999-2019 Seata.io Group.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.seata.rm.datasource.sql.struct;

/**
* TODO
* sql sequence expression
* @author jsbxyyx
*/
public class SqlSequenceExpr {

private String sequence;
private String function;

public SqlSequenceExpr() {}

public SqlSequenceExpr(String sequence, String function) {
this.sequence = sequence;
this.function = function;
}

public String getSequence() {
return sequence;
}

public void setSequence(String sequence) {
this.sequence = sequence;
}

public String getFunction() {
return function;
}

public void setFunction(String function) {
this.function = function;
}
}
Loading

0 comments on commit 0cd1d98

Please sign in to comment.