Skip to content

Commit

Permalink
Fix issues#9510 , support rewrite SQL correctly when using sharding a…
Browse files Browse the repository at this point in the history
…nd encrypt together (apache#9749)

* fix#9510.refactor SubstitutableColumnNameToken to SubstitutableColumnsToken,support rewrite SQL correctly when using sharding and encrypt together.

* fix#9510.fix sql parser case of column-projection stop-index value.

* fix#9510.fix SubstitutableColumnsToken.toString(RouteUnit), use logic table name to instead of  actualTableName when get actual table name from RouteUnit is empty.

Co-authored-by: huanghao-jk <[email protected]>
  • Loading branch information
huanghao495430759 and huanghao-jk authored Mar 22, 2021
1 parent 51d8c02 commit dc96405
Show file tree
Hide file tree
Showing 17 changed files with 386 additions and 219 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,18 @@
import org.apache.shardingsphere.encrypt.rewrite.aware.QueryWithCipherColumnAware;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLTokenGenerator;
import org.apache.shardingsphere.encrypt.rule.EncryptTable;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnNameToken;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.type.WhereAvailable;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumn;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnsToken;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
import org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionBuilder;
import org.apache.shardingsphere.sql.parser.sql.common.util.ColumnExtractor;
import org.apache.shardingsphere.sql.parser.sql.common.util.ExpressionBuilder;

import java.util.Collection;
import java.util.LinkedHashSet;
Expand All @@ -55,9 +56,9 @@ protected boolean isGenerateSQLTokenForEncrypt(final SQLStatementContext sqlStat
}

@Override
public Collection<SubstitutableColumnNameToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
public Collection<SubstitutableColumnsToken> generateSQLTokens(final SQLStatementContext sqlStatementContext) {
Preconditions.checkState(((WhereAvailable) sqlStatementContext).getWhere().isPresent());
Collection<SubstitutableColumnNameToken> result = new LinkedHashSet<>();
Collection<SubstitutableColumnsToken> result = new LinkedHashSet<>();
ExpressionSegment expression = ((WhereAvailable) sqlStatementContext).getWhere().get().getExpr();
ExpressionBuilder expressionBuilder = new ExpressionBuilder(expression);
Collection<AndPredicate> andPredicates = new LinkedList<>(expressionBuilder.extractAndPredicates().getAndPredicates());
Expand All @@ -67,8 +68,8 @@ public Collection<SubstitutableColumnNameToken> generateSQLTokens(final SQLState
return result;
}

private Collection<SubstitutableColumnNameToken> generateSQLTokens(final SQLStatementContext sqlStatementContext, final AndPredicate andPredicate) {
Collection<SubstitutableColumnNameToken> result = new LinkedList<>();
private Collection<SubstitutableColumnsToken> generateSQLTokens(final SQLStatementContext sqlStatementContext, final AndPredicate andPredicate) {
Collection<SubstitutableColumnsToken> result = new LinkedList<>();
for (ExpressionSegment each : andPredicate.getPredicates()) {
Optional<ColumnSegment> column = ColumnExtractor.extract(each);
if (!column.isPresent()) {
Expand All @@ -78,18 +79,23 @@ private Collection<SubstitutableColumnNameToken> generateSQLTokens(final SQLStat
if (!encryptTable.isPresent() || !encryptTable.get().findEncryptorName(column.get().getIdentifier().getValue()).isPresent()) {
continue;
}
int startIndex = column.get().getOwner().isPresent() ? column.get().getOwner().get().getStopIndex() + 2 : column.get().getStartIndex();
int startIndex = column.get().getOwner().isPresent() ? column.get().getOwner().get().getStartIndex() : column.get().getStartIndex();
int stopIndex = column.get().getStopIndex();
Optional<String> tableName = sqlStatementContext.getTablesContext().findTableName(column.get(), schema);
String owner = column.get().getOwner().isPresent() ? column.get().getOwner().get().getIdentifier().getValue() : "";
if (!queryWithCipherColumn) {
Optional<String> plainColumn = encryptTable.get().findPlainColumn(column.get().getIdentifier().getValue());
if (plainColumn.isPresent()) {
result.add(new SubstitutableColumnNameToken(startIndex, stopIndex, plainColumn.get()));
result.add(new SubstitutableColumnsToken(startIndex, stopIndex,
new SubstitutableColumn(tableName.get(), owner, plainColumn.get(), column.get().getIdentifier().getQuoteCharacter(), Optional.empty())));
continue;
}
}
Optional<String> assistedQueryColumn = encryptTable.get().findAssistedQueryColumn(column.get().getIdentifier().getValue());
SubstitutableColumnNameToken encryptColumnNameToken = assistedQueryColumn.map(columnName -> new SubstitutableColumnNameToken(startIndex, stopIndex, columnName))
.orElseGet(() -> new SubstitutableColumnNameToken(startIndex, stopIndex, encryptTable.get().getCipherColumn(column.get().getIdentifier().getValue())));
SubstitutableColumnsToken encryptColumnNameToken = assistedQueryColumn.map(columnName -> new SubstitutableColumnsToken(startIndex, stopIndex,
new SubstitutableColumn(tableName.get(), owner, columnName, column.get().getIdentifier().getQuoteCharacter(), Optional.empty())))
.orElseGet(() -> new SubstitutableColumnsToken(startIndex, stopIndex, new SubstitutableColumn(tableName.get(), owner,
encryptTable.get().getCipherColumn(column.get().getIdentifier().getValue()), column.get().getIdentifier().getQuoteCharacter(), Optional.empty())));
result.add(encryptColumnNameToken);
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.shardingsphere.encrypt.rewrite.token.generator.impl;

import com.google.common.base.Joiner;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.aware.QueryWithCipherColumnAware;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLTokenGenerator;
Expand All @@ -30,7 +29,8 @@
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnNameToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumn;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnsToken;
import org.apache.shardingsphere.sql.parser.sql.common.constant.QuoteCharacter;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
Expand All @@ -41,7 +41,6 @@
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;

/**
Expand All @@ -58,17 +57,17 @@ protected boolean isGenerateSQLTokenForEncrypt(final SQLStatementContext sqlStat
}

@Override
public Collection<SubstitutableColumnNameToken> generateSQLTokens(final SelectStatementContext selectStatementContext) {
public Collection<SubstitutableColumnsToken> generateSQLTokens(final SelectStatementContext selectStatementContext) {
ProjectionsSegment projectionsSegment = selectStatementContext.getSqlStatement().getProjections();
// TODO process multiple tables
String tableName = selectStatementContext.getAllSimpleTableSegments().iterator().next().getTableName().getIdentifier().getValue();
return getEncryptRule().findEncryptTable(tableName).map(
encryptTable -> generateSQLTokens(projectionsSegment, tableName, selectStatementContext, encryptTable)).orElseGet(Collections::emptyList);
}

private Collection<SubstitutableColumnNameToken> generateSQLTokens(final ProjectionsSegment segment, final String tableName,
final SelectStatementContext selectStatementContext, final EncryptTable encryptTable) {
Collection<SubstitutableColumnNameToken> result = new LinkedList<>();
private Collection<SubstitutableColumnsToken> generateSQLTokens(final ProjectionsSegment segment, final String tableName,
final SelectStatementContext selectStatementContext, final EncryptTable encryptTable) {
Collection<SubstitutableColumnsToken> result = new LinkedList<>();
for (ProjectionSegment each : segment.getProjections()) {
if (each instanceof ColumnProjectionSegment) {
if (encryptTable.getLogicColumns().contains(((ColumnProjectionSegment) each).getColumn().getIdentifier().getValue())) {
Expand All @@ -93,28 +92,28 @@ private boolean isToGeneratedSQLToken(final ProjectionSegment projectionSegment,
return ownerSegment.map(segment -> selectStatementContext.getTablesContext().findTableNameFromSQL(segment.getIdentifier().getValue()).equalsIgnoreCase(tableName)).orElse(true);
}

private SubstitutableColumnNameToken generateSQLToken(final ColumnProjectionSegment segment, final String tableName) {
private SubstitutableColumnsToken generateSQLToken(final ColumnProjectionSegment segment, final String tableName) {
String encryptColumnName = getEncryptColumnName(tableName, segment.getColumn().getIdentifier().getValue());
if (!segment.getAlias().isPresent()) {
encryptColumnName += " AS " + segment.getColumn().getIdentifier().getValue();
}
return segment.getColumn().getOwner().isPresent() ? new SubstitutableColumnNameToken(segment.getColumn().getOwner().get().getStopIndex() + 2, segment.getStopIndex(), encryptColumnName)
: new SubstitutableColumnNameToken(segment.getStartIndex(), segment.getStopIndex(), encryptColumnName);
String owner = segment.getColumn().getOwner().isPresent() ? segment.getColumn().getOwner().get().getIdentifier().getValue() : null;
Optional<String> alias = segment.getAlias().isPresent() ? segment.getAlias() : Optional.ofNullable(segment.getColumn().getIdentifier().getValue());
return new SubstitutableColumnsToken(segment.getStartIndex(), segment.getStopIndex(),
new SubstitutableColumn(tableName, owner, encryptColumnName, segment.getColumn().getIdentifier().getQuoteCharacter(), alias));
}

private SubstitutableColumnNameToken generateSQLToken(final ShorthandProjectionSegment segment,
final ShorthandProjection shorthandProjection, final String tableName, final EncryptTable encryptTable, final DatabaseType databaseType) {
private SubstitutableColumnsToken generateSQLToken(final ShorthandProjectionSegment segment,
final ShorthandProjection shorthandProjection, final String tableName, final EncryptTable encryptTable, final DatabaseType databaseType) {
String owner = segment.getOwner().isPresent() ? segment.getOwner().get().getIdentifier().getValue() : "";
SubstitutableColumnsToken substitutableColumnsToken = new SubstitutableColumnsToken(segment.getStartIndex(), segment.getStopIndex());
QuoteCharacter quoteCharacter = databaseType.getQuoteCharacter();
List<String> shorthandExtensionProjections = new LinkedList<>();
for (ColumnProjection each : shorthandProjection.getActualColumns()) {
if (encryptTable.getLogicColumns().contains(each.getName())) {
shorthandExtensionProjections.add(new ColumnProjection(null == each.getOwner() ? null : quoteCharacter.wrap(each.getOwner()),
quoteCharacter.wrap(getEncryptColumnName(tableName, each.getName())), each.getName()).getExpressionWithAlias());
substitutableColumnsToken.addColumn(new SubstitutableColumn(tableName, !owner.isEmpty() ? owner : each.getOwner(),
getEncryptColumnName(tableName, each.getName()), quoteCharacter, Optional.ofNullable(each.getName())));
} else {
shorthandExtensionProjections.add(null == each.getOwner() ? quoteCharacter.wrap(each.getName()) : quoteCharacter.wrap(each.getOwner()) + "." + quoteCharacter.wrap(each.getName()));
substitutableColumnsToken.addColumn(new SubstitutableColumn(tableName, !owner.isEmpty() ? owner : each.getOwner(), each.getName(), quoteCharacter, each.getAlias()));
}
}
return new SubstitutableColumnNameToken(segment.getStartIndex(), segment.getStopIndex(), Joiner.on(", ").join(shorthandExtensionProjections));
return substitutableColumnsToken;
}

private String getEncryptColumnName(final String tableName, final String logicEncryptColumnName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
package org.apache.shardingsphere.encrypt.rewrite.token.generator.impl;

import com.google.common.base.Preconditions;
import lombok.Setter;
import org.apache.shardingsphere.encrypt.rewrite.token.generator.BaseEncryptSQLTokenGenerator;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.aware.SchemaMetaDataAware;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumn;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnsToken;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.infra.rewrite.sql.token.generator.CollectionSQLTokenGenerator;
import org.apache.shardingsphere.infra.rewrite.sql.token.pojo.generic.SubstitutableColumnNameToken;

import java.util.Collection;
import java.util.LinkedList;
Expand All @@ -34,8 +38,10 @@
/**
* Insert cipher column name token generator.
*/
public final class InsertCipherNameTokenGenerator extends BaseEncryptSQLTokenGenerator implements CollectionSQLTokenGenerator<InsertStatementContext> {

@Setter
public final class InsertCipherNameTokenGenerator extends BaseEncryptSQLTokenGenerator implements SchemaMetaDataAware, CollectionSQLTokenGenerator<InsertStatementContext> {
private ShardingSphereSchema schema;

@Override
protected boolean isGenerateSQLTokenForEncrypt(final SQLStatementContext sqlStatementContext) {
if (!(sqlStatementContext instanceof InsertStatementContext)) {
Expand All @@ -46,14 +52,18 @@ protected boolean isGenerateSQLTokenForEncrypt(final SQLStatementContext sqlStat
}

@Override
public Collection<SubstitutableColumnNameToken> generateSQLTokens(final InsertStatementContext insertStatementContext) {
public Collection<SubstitutableColumnsToken> generateSQLTokens(final InsertStatementContext insertStatementContext) {
Optional<InsertColumnsSegment> sqlSegment = insertStatementContext.getSqlStatement().getInsertColumns();
Preconditions.checkState(sqlSegment.isPresent());
Map<String, String> logicAndCipherColumns = getEncryptRule().getLogicAndCipherColumns(insertStatementContext.getSqlStatement().getTable().getTableName().getIdentifier().getValue());
Collection<SubstitutableColumnNameToken> result = new LinkedList<>();
Collection<SubstitutableColumnsToken> result = new LinkedList<>();
for (ColumnSegment each : sqlSegment.get().getColumns()) {
if (logicAndCipherColumns.containsKey(each.getIdentifier().getValue())) {
result.add(new SubstitutableColumnNameToken(each.getStartIndex(), each.getStopIndex(), logicAndCipherColumns.get(each.getIdentifier().getValue())));
Optional<String> tableName = insertStatementContext.getTablesContext().findTableName(each, schema);
String owner = each.getOwner().isPresent() ? each.getOwner().get().getIdentifier().getValue() : "";
SubstitutableColumnsToken token = new SubstitutableColumnsToken(each.getStartIndex(), each.getStopIndex(),
new SubstitutableColumn(tableName.get(), owner, logicAndCipherColumns.get(each.getIdentifier().getValue()), each.getIdentifier().getQuoteCharacter(), Optional.empty()));
result.add(token);
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@
</rewrite-assertion>

<!-- FIXME #9510 should rewrite owner table as sharding table -->
<!-- <rewrite-assertion id="select_with_unqualified_shorthand">-->
<!-- <input sql="SELECT * FROM t_account" />-->
<!-- <output sql="SELECT `t_account_0`.`account_id`, `t_account_0`.`cipher_password` AS password, `t_account_0`.`cipher_amount` AS amount, `t_account_0`.`status` FROM t_account_0" />-->
<!-- <output sql="SELECT `t_account_1`.`account_id`, `t_account_1`.`cipher_password` AS password, `t_account_1`.`cipher_amount` AS amount, `t_account_1`.`status` FROM t_account_1" />-->
<!-- </rewrite-assertion>-->
<rewrite-assertion id="select_with_unqualified_shorthand">
<input sql="SELECT * FROM t_account" />
<output sql="SELECT `t_account_0`.`account_id`, `t_account_0`.`cipher_password` AS password, `t_account_0`.`cipher_amount` AS amount, `t_account_0`.`status` FROM t_account_0" />
<output sql="SELECT `t_account_1`.`account_id`, `t_account_1`.`cipher_password` AS password, `t_account_1`.`cipher_amount` AS amount, `t_account_1`.`status` FROM t_account_1" />
</rewrite-assertion>

<rewrite-assertion id="select_with_qualified_shorthand">
<input sql="SELECT a.* FROM t_account a" />
Expand All @@ -74,9 +74,9 @@
</rewrite-assertion>

<!-- FIXME #9510 should rewrite owner table as sharding table -->
<!--<rewrite-assertion id="select_with_table_qualified_shorthand">-->
<!--<input sql="SELECT t_account.* FROM t_account" />-->
<!--<output sql="SELECT t_account_0.account_id, t_account_0.cipher_password AS password, t_account_0.cipher_amount AS amount, t_account_0.status FROM t_account_0" />-->
<!--<output sql="SELECT t_account_1.account_id, t_account_1.cipher_password AS password, t_account_1.cipher_amount AS amount, t_account_1.status FROM t_account_1" />-->
<!--</rewrite-assertion>-->
<rewrite-assertion id="select_with_table_qualified_shorthand">
<input sql="SELECT t_account.* FROM t_account" />
<output sql="SELECT `t_account_0`.`account_id`, `t_account_0`.`cipher_password` AS password, `t_account_0`.`cipher_amount` AS amount, `t_account_0`.`status` FROM t_account_0" />
<output sql="SELECT `t_account_1`.`account_id`, `t_account_1`.`cipher_password` AS password, `t_account_1`.`cipher_amount` AS amount, `t_account_1`.`status` FROM t_account_1" />
</rewrite-assertion>
</rewrite-assertions>
Loading

0 comments on commit dc96405

Please sign in to comment.