Skip to content

Commit

Permalink
Merge pull request apache#803 from ma-xiao-guang-64/dev
Browse files Browse the repository at this point in the history
optimize batch insert
  • Loading branch information
terrymanu authored May 6, 2018
2 parents 77e358d + e2b1bd5 commit 5ead479
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

Expand All @@ -59,6 +60,7 @@ public ShardingConditions optimize() {
List<AndCondition> andConditions = insertStatement.getConditions().getOrCondition().getAndConditions();
List<InsertValue> insertValues = insertStatement.getInsertValues().getInsertValues();
List<ShardingCondition> result = new ArrayList<>(andConditions.size());
Iterator<Number> generatedKeys = null;
int count = 0;
for (AndCondition each : andConditions) {
InsertValue insertValue = insertValues.get(count);
Expand All @@ -67,38 +69,41 @@ public ShardingConditions optimize() {

String logicTableName = insertStatement.getTables().getSingleTableName();
Optional<Column> generateKeyColumn = shardingRule.getGenerateKeyColumn(logicTableName);
String expression;
InsertShardingCondition insertShardingCondition;
if (-1 != insertStatement.getGenerateKeyColumnIndex() || !generateKeyColumn.isPresent()) {
expression = insertValue.getExpression();
insertShardingCondition = new InsertShardingCondition(insertValue.getExpression(), currentParameters);
} else {
if (null == generatedKeys) {
generatedKeys = generatedKey.getGeneratedKeys().iterator();
}
String expression;
Number currentGeneratedKey = generatedKeys.next();
if (0 == parameters.size()) {
expression = insertValue.getExpression().substring(0, insertValue.getExpression().length() - 1) + ", " + generatedKey.getGeneratedKeys().get(count).toString() + ")";
expression = insertValue.getExpression().substring(0, insertValue.getExpression().length() - 1) + ", " + currentGeneratedKey.toString() + ")";
} else {
expression = insertValue.getExpression().substring(0, insertValue.getExpression().length() - 1) + ", ?)";
currentParameters.add(generatedKey.getGeneratedKeys().get(count));
currentParameters.add(currentGeneratedKey);
}
insertShardingCondition = new InsertShardingCondition(expression, currentParameters);
insertShardingCondition.getShardingValues().add(getShardingCondition(generateKeyColumn.get(), currentGeneratedKey));
}
InsertShardingCondition insertShardingCondition = new InsertShardingCondition(expression, currentParameters);
insertShardingCondition.getShardingValues().addAll(getShardingCondition(each));
if (-1 == insertStatement.getGenerateKeyColumnIndex() && generateKeyColumn.isPresent()) {
insertShardingCondition.getShardingValues().add(getShardingCondition(generateKeyColumn.get(), generatedKey.getGeneratedKeys().get(count)));
}
result.add(insertShardingCondition);
count++;
}
return new ShardingConditions(result);
}

private ListShardingValue getShardingCondition(final Column column, final Number value) {
return new ListShardingValue<>(column.getTableName(), column.getName(),
new GeneratedKeyCondition(column, -1, value).getConditionValues(parameters));
}

private Collection<ListShardingValue> getShardingCondition(final AndCondition andCondition) {
Collection<ListShardingValue> result = new LinkedList<>();
for (Condition each : andCondition.getConditions()) {
result.add(new ListShardingValue<>(each.getColumn().getTableName(), each.getColumn().getName(), each.getConditionValues(parameters)));
}
return result;
}

private ListShardingValue getShardingCondition(final Column column, final Number value) {
return new ListShardingValue<>(column.getTableName(), column.getName(),
new GeneratedKeyCondition(column, -1, value).getConditionValues(parameters));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ private void processInsertShardingCondition(final TableUnit tableUnit, final Ins
if (dataNode.getDataSourceName().equals(tableUnit.getDataSourceName()) && dataNode.getTableName().equals(tableUnit.getRoutingTables().iterator().next().getActualTableName())) {
expressions.add(shardingCondition.getInsertValueExpression());
parameters.addAll(shardingCondition.getParameters());
break;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,15 @@ private Collection<String> routeDataSources(final TableRule tableRule, final Lis
if (databaseShardingValues.isEmpty()) {
return availableTargetDatabases;
}
Collection<String> result = shardingRule.getDatabaseShardingStrategy(tableRule).doSharding(availableTargetDatabases, databaseShardingValues);
Collection<String> result = new LinkedHashSet<>(shardingRule.getDatabaseShardingStrategy(tableRule).doSharding(availableTargetDatabases, databaseShardingValues));
Preconditions.checkState(!result.isEmpty(), "no database route info");
return result;
}

private Collection<DataNode> routeTables(final TableRule tableRule, final String routedDataSource, final List<ShardingValue> tableShardingValues) {
Collection<String> availableTargetTables = tableRule.getActualTableNames(routedDataSource);
Collection<String> routedTables = tableShardingValues.isEmpty() ? availableTargetTables
: shardingRule.getTableShardingStrategy(tableRule).doSharding(availableTargetTables, tableShardingValues);
Collection<String> routedTables = new LinkedHashSet<>(tableShardingValues.isEmpty() ? availableTargetTables
: shardingRule.getTableShardingStrategy(tableRule).doSharding(availableTargetTables, tableShardingValues));
Preconditions.checkState(!routedTables.isEmpty(), "no table route info");
Collection<DataNode> result = new LinkedList<>();
for (String each : routedTables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ public void assertOptimizeWithGeneratedKey() {
assertThat(((InsertShardingCondition) actual.getShardingConditions().get(1)).getInsertValueExpression(), is("(?, ?, ?)"));
assertThat(actual.getShardingConditions().get(0).getShardingValues().size(), is(2));
assertThat(actual.getShardingConditions().get(1).getShardingValues().size(), is(2));
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(0).getShardingValues().get(0), 10);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(0).getShardingValues().get(1), 1);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(1).getShardingValues().get(0), 11);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(1).getShardingValues().get(1), 2);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(0).getShardingValues().get(0), 1);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(0).getShardingValues().get(1), 10);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(1).getShardingValues().get(0), 2);
assertShardingValue((ListShardingValue) actual.getShardingConditions().get(1).getShardingValues().get(1), 11);
}

@Test
Expand Down

0 comments on commit 5ead479

Please sign in to comment.