Skip to content

Commit

Permalink
split AbstractUnsupportedOperationPreparedStatement and AbstractUnsup…
Browse files Browse the repository at this point in the history
…portedOperationStatement
terrymanu committed Aug 30, 2017

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 45ee7bf commit 07f5b77
Showing 9 changed files with 356 additions and 114 deletions.
Original file line number Diff line number Diff line change
@@ -19,7 +19,6 @@

import com.dangdang.ddframe.rdb.sharding.exception.ShardingJdbcException;
import com.dangdang.ddframe.rdb.sharding.jdbc.adapter.invocation.SetParameterMethodInvocation;
import com.dangdang.ddframe.rdb.sharding.jdbc.core.connection.ShardingConnection;
import com.dangdang.ddframe.rdb.sharding.jdbc.unsupported.AbstractUnsupportedOperationPreparedStatement;
import lombok.Getter;

@@ -32,11 +31,13 @@
import java.sql.Date;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.SQLXML;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
@@ -48,15 +49,175 @@
*/
public abstract class AbstractPreparedStatementAdapter extends AbstractUnsupportedOperationPreparedStatement {

private boolean closed;

private boolean poolable;

private int fetchSize;

private final List<SetParameterMethodInvocation> setParameterMethodInvocations = new LinkedList<>();

@Getter
private final List<Object> parameters = new ArrayList<>();

protected AbstractPreparedStatementAdapter(final ShardingConnection shardingConnection, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
super(shardingConnection, resultSetType, resultSetConcurrency, resultSetHoldability);
@Override
public final void close() throws SQLException {
closed = true;
getRoutedPreparedStatements().clear();
Collection<SQLException> exceptions = new LinkedList<>();
for (PreparedStatement each : getRoutedPreparedStatements()) {
try {
each.close();
} catch (final SQLException ex) {
exceptions.add(ex);
}
}
throwSQLExceptionIfNecessary(exceptions);
}

@Override
public final boolean isClosed() throws SQLException {
return closed;
}

@Override
public final boolean isPoolable() throws SQLException {
return poolable;
}

@Override
public final void setPoolable(final boolean poolable) throws SQLException {
this.poolable = poolable;
if (getRoutedPreparedStatements().isEmpty()) {
recordMethodInvocation(PreparedStatement.class, "setPoolable", new Class[] {boolean.class}, new Object[] {poolable});
return;
}
for (PreparedStatement each : getRoutedPreparedStatements()) {
each.setPoolable(poolable);
}
}

@Override
public final int getFetchSize() throws SQLException {
return fetchSize;
}

@Override
public final void setFetchSize(final int rows) throws SQLException {
this.fetchSize = rows;
if (getRoutedPreparedStatements().isEmpty()) {
recordMethodInvocation(PreparedStatement.class, "setFetchSize", new Class[] {int.class}, new Object[] {rows});
return;
}
for (PreparedStatement each : getRoutedPreparedStatements()) {
each.setFetchSize(rows);
}
}

@Override
public final void setEscapeProcessing(final boolean enable) throws SQLException {
if (getRoutedPreparedStatements().isEmpty()) {
recordMethodInvocation(PreparedStatement.class, "setEscapeProcessing", new Class[] {boolean.class}, new Object[] {enable});
return;
}
for (PreparedStatement each : getRoutedPreparedStatements()) {
each.setEscapeProcessing(enable);
}
}

@Override
public final void cancel() throws SQLException {
for (PreparedStatement each : getRoutedPreparedStatements()) {
each.cancel();
}
}

@Override
public final int getUpdateCount() throws SQLException {
long result = 0;
boolean hasResult = false;
for (PreparedStatement each : getRoutedPreparedStatements()) {
if (each.getUpdateCount() > -1) {
hasResult = true;
}
result += each.getUpdateCount();
}
if (result > Integer.MAX_VALUE) {
result = Integer.MAX_VALUE;
}
return hasResult ? Long.valueOf(result).intValue() : -1;
}

@Override
public SQLWarning getWarnings() throws SQLException {
return null;
}

@Override
public void clearWarnings() throws SQLException {
}

@Override
public final boolean getMoreResults() throws SQLException {
return false;
}

@Override
public final boolean getMoreResults(final int current) throws SQLException {
return false;
}

@Override
public final int getMaxFieldSize() throws SQLException {
return getRoutedPreparedStatements().isEmpty() ? 0 : getRoutedPreparedStatements().iterator().next().getMaxFieldSize();
}

@Override
public final void setMaxFieldSize(final int max) throws SQLException {
if (getRoutedPreparedStatements().isEmpty()) {
recordMethodInvocation(PreparedStatement.class, "setMaxFieldSize", new Class[] {int.class}, new Object[] {max});
return;
}
for (PreparedStatement each : getRoutedPreparedStatements()) {
each.setMaxFieldSize(max);
}
}

// TODO Confirm MaxRows for multiple databases is need special handle. eg: 10 statements maybe MaxRows / 10
@Override
public final int getMaxRows() throws SQLException {
return getRoutedPreparedStatements().isEmpty() ? -1 : getRoutedPreparedStatements().iterator().next().getMaxRows();
}

@Override
public final void setMaxRows(final int max) throws SQLException {
if (getRoutedPreparedStatements().isEmpty()) {
recordMethodInvocation(PreparedStatement.class, "setMaxRows", new Class[] {int.class}, new Object[] {max});
return;
}
for (PreparedStatement each : getRoutedPreparedStatements()) {
each.setMaxRows(max);
}
}

@Override
public final int getQueryTimeout() throws SQLException {
return getRoutedPreparedStatements().isEmpty() ? 0 : getRoutedPreparedStatements().iterator().next().getQueryTimeout();
}

@Override
public final void setQueryTimeout(final int seconds) throws SQLException {
if (getRoutedPreparedStatements().isEmpty()) {
recordMethodInvocation(PreparedStatement.class, "setQueryTimeout", new Class[] {int.class}, new Object[] {seconds});
return;
}
for (PreparedStatement each : getRoutedPreparedStatements()) {
each.setQueryTimeout(seconds);
}
}

protected abstract Collection<PreparedStatement> getRoutedPreparedStatements();

@Override
public final void setNull(final int parameterIndex, final int sqlType) throws SQLException {
setParameter(parameterIndex, null);
Original file line number Diff line number Diff line change
@@ -18,7 +18,6 @@
package com.dangdang.ddframe.rdb.sharding.jdbc.adapter;

import com.dangdang.ddframe.rdb.sharding.jdbc.unsupported.AbstractUnsupportedOperationStatement;
import lombok.RequiredArgsConstructor;

import java.sql.SQLException;
import java.sql.SQLWarning;
@@ -29,21 +28,18 @@
/**
* Adapter for {@code Statement}.
*
* @author zhangliang
* @author gaohongtao
*/
@RequiredArgsConstructor
public abstract class AbstractStatementAdapter extends AbstractUnsupportedOperationStatement {

private final Class<? extends Statement> recordTargetClass;

private boolean closed;

private boolean poolable;

private int fetchSize;

@Override
@SuppressWarnings("unchecked")
public final void close() throws SQLException {
closed = true;
getRoutedStatements().clear();
@@ -72,7 +68,7 @@ public final boolean isPoolable() throws SQLException {
public final void setPoolable(final boolean poolable) throws SQLException {
this.poolable = poolable;
if (getRoutedStatements().isEmpty()) {
recordMethodInvocation(recordTargetClass, "setPoolable", new Class[] {boolean.class}, new Object[] {poolable});
recordMethodInvocation(Statement.class, "setPoolable", new Class[] {boolean.class}, new Object[] {poolable});
return;
}
for (Statement each : getRoutedStatements()) {
@@ -89,7 +85,7 @@ public final int getFetchSize() throws SQLException {
public final void setFetchSize(final int rows) throws SQLException {
this.fetchSize = rows;
if (getRoutedStatements().isEmpty()) {
recordMethodInvocation(recordTargetClass, "setFetchSize", new Class[] {int.class}, new Object[] {rows});
recordMethodInvocation(Statement.class, "setFetchSize", new Class[] {int.class}, new Object[] {rows});
return;
}
for (Statement each : getRoutedStatements()) {
@@ -100,7 +96,7 @@ public final void setFetchSize(final int rows) throws SQLException {
@Override
public final void setEscapeProcessing(final boolean enable) throws SQLException {
if (getRoutedStatements().isEmpty()) {
recordMethodInvocation(recordTargetClass, "setEscapeProcessing", new Class[] {boolean.class}, new Object[] {enable});
recordMethodInvocation(Statement.class, "setEscapeProcessing", new Class[] {boolean.class}, new Object[] {enable});
return;
}
for (Statement each : getRoutedStatements()) {
@@ -115,17 +111,6 @@ public final void cancel() throws SQLException {
}
}

@Override
public final void setCursorName(final String name) throws SQLException {
if (getRoutedStatements().isEmpty()) {
recordMethodInvocation(recordTargetClass, "setCursorName", new Class[] {String.class}, new Object[] {name});
return;
}
for (Statement each : getRoutedStatements()) {
each.setCursorName(name);
}
}

@Override
public final int getUpdateCount() throws SQLException {
long result = 0;
@@ -151,9 +136,6 @@ public SQLWarning getWarnings() throws SQLException {
public void clearWarnings() throws SQLException {
}

/*
* Only store procedures will support multiple ResetSets, so don't support here.
*/
@Override
public final boolean getMoreResults() throws SQLException {
return false;
@@ -172,7 +154,7 @@ public final int getMaxFieldSize() throws SQLException {
@Override
public final void setMaxFieldSize(final int max) throws SQLException {
if (getRoutedStatements().isEmpty()) {
recordMethodInvocation(recordTargetClass, "setMaxFieldSize", new Class[] {int.class}, new Object[] {max});
recordMethodInvocation(Statement.class, "setMaxFieldSize", new Class[] {int.class}, new Object[] {max});
return;
}
for (Statement each : getRoutedStatements()) {
@@ -189,7 +171,7 @@ public final int getMaxRows() throws SQLException {
@Override
public final void setMaxRows(final int max) throws SQLException {
if (getRoutedStatements().isEmpty()) {
recordMethodInvocation(recordTargetClass, "setMaxRows", new Class[] {int.class}, new Object[] {max});
recordMethodInvocation(Statement.class, "setMaxRows", new Class[] {int.class}, new Object[] {max});
return;
}
for (Statement each : getRoutedStatements()) {
@@ -205,13 +187,13 @@ public final int getQueryTimeout() throws SQLException {
@Override
public final void setQueryTimeout(final int seconds) throws SQLException {
if (getRoutedStatements().isEmpty()) {
recordMethodInvocation(recordTargetClass, "setQueryTimeout", new Class[] {int.class}, new Object[] {seconds});
recordMethodInvocation(Statement.class, "setQueryTimeout", new Class[] {int.class}, new Object[] {seconds});
return;
}
for (Statement each : getRoutedStatements()) {
each.setQueryTimeout(seconds);
}
}

protected abstract Collection<? extends Statement> getRoutedStatements();
protected abstract Collection<Statement> getRoutedStatements();
}
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
import com.google.common.base.Preconditions;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.RequiredArgsConstructor;

import java.sql.Connection;
import java.sql.ResultSet;
@@ -35,6 +36,7 @@
*
* @author zhangliang
*/
@RequiredArgsConstructor
@Getter
public final class MasterSlaveStatement extends AbstractStatementAdapter {

@@ -57,14 +59,6 @@ public MasterSlaveStatement(final MasterSlaveConnection connection, final int re
this(connection, resultSetType, resultSetConcurrency, ResultSet.HOLD_CURSORS_OVER_COMMIT);
}

public MasterSlaveStatement(final MasterSlaveConnection connection, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
super(Statement.class);
this.connection = connection;
this.resultSetType = resultSetType;
this.resultSetConcurrency = resultSetConcurrency;
this.resultSetHoldability = resultSetHoldability;
}

@Override
public ResultSet executeQuery(final String sql) throws SQLException {
Collection<Connection> connections = connection.getConnection(sql);
@@ -155,7 +149,8 @@ public ResultSet getResultSet() throws SQLException {
return routedStatement.getResultSet();
}

protected Collection<? extends Statement> getRoutedStatements() {
@Override
protected Collection<Statement> getRoutedStatements() {
return Collections.singletonList(routedStatement);
}
}
Original file line number Diff line number Diff line change
@@ -24,15 +24,20 @@
import com.dangdang.ddframe.rdb.sharding.executor.type.prepared.PreparedStatementUnit;
import com.dangdang.ddframe.rdb.sharding.jdbc.adapter.AbstractPreparedStatementAdapter;
import com.dangdang.ddframe.rdb.sharding.jdbc.core.connection.ShardingConnection;
import com.dangdang.ddframe.rdb.sharding.jdbc.core.resultset.GeneratedKeysResultSet;
import com.dangdang.ddframe.rdb.sharding.jdbc.core.resultset.ShardingResultSet;
import com.dangdang.ddframe.rdb.sharding.merger.MergeEngine;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.context.GeneratedKey;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.sql.dml.insert.InsertStatement;
import com.dangdang.ddframe.rdb.sharding.parsing.parser.sql.dql.select.SelectStatement;
import com.dangdang.ddframe.rdb.sharding.routing.PreparedStatementRoutingEngine;
import com.dangdang.ddframe.rdb.sharding.routing.SQLExecutionUnit;
import com.dangdang.ddframe.rdb.sharding.routing.SQLRouteResult;
import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.Iterators;
import lombok.AccessLevel;
import lombok.Getter;

import java.sql.Connection;
import java.sql.PreparedStatement;
@@ -51,14 +56,34 @@
* @author zhangliang
* @author caohao
*/
@Getter
public final class ShardingPreparedStatement extends AbstractPreparedStatementAdapter {

private final ShardingConnection connection;

private final int resultSetType;

private final int resultSetConcurrency;

private final int resultSetHoldability;

private final PreparedStatementRoutingEngine routingEngine;

private final List<BatchPreparedStatementUnit> batchStatementUnits = new LinkedList<>();

private final List<List<Object>> parameterSets = new LinkedList<>();

private final Collection<PreparedStatement> routedPreparedStatements = new LinkedList<>();

@Getter(AccessLevel.NONE)
private boolean returnGeneratedKeys;

@Getter(AccessLevel.NONE)
private SQLRouteResult routeResult;

@Getter(AccessLevel.NONE)
private ResultSet currentResultSet;

public ShardingPreparedStatement(final ShardingConnection shardingConnection, final String sql) {
this(shardingConnection, sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, ResultSet.HOLD_CURSORS_OVER_COMMIT);
}
@@ -70,13 +95,16 @@ public ShardingPreparedStatement(final ShardingConnection shardingConnection, fi
public ShardingPreparedStatement(final ShardingConnection shardingConnection, final String sql, final int autoGeneratedKeys) {
this(shardingConnection, sql);
if (RETURN_GENERATED_KEYS == autoGeneratedKeys) {
markReturnGeneratedKeys();
returnGeneratedKeys = true;
}
}

public ShardingPreparedStatement(final ShardingConnection shardingConnection, final String sql, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
super(shardingConnection, resultSetType, resultSetConcurrency, resultSetHoldability);
routingEngine = new PreparedStatementRoutingEngine(sql, shardingConnection.getShardingContext());
public ShardingPreparedStatement(final ShardingConnection connection, final String sql, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
this.connection = connection;
this.resultSetType = resultSetType;
this.resultSetConcurrency = resultSetConcurrency;
this.resultSetHoldability = resultSetHoldability;
routingEngine = new PreparedStatementRoutingEngine(sql, connection.getShardingContext());
}

@Override
@@ -85,12 +113,12 @@ public ResultSet executeQuery() throws SQLException {
try {
Collection<PreparedStatementUnit> preparedStatementUnits = route();
List<ResultSet> resultSets = new PreparedStatementExecutor(
getConnection().getShardingContext().getExecutorEngine(), getRouteResult().getSqlStatement().getType(), preparedStatementUnits, getParameters()).executeQuery();
result = new ShardingResultSet(resultSets, new MergeEngine(resultSets, (SelectStatement) getRouteResult().getSqlStatement()).merge());
getConnection().getShardingContext().getExecutorEngine(), routeResult.getSqlStatement().getType(), preparedStatementUnits, getParameters()).executeQuery();
result = new ShardingResultSet(resultSets, new MergeEngine(resultSets, (SelectStatement) routeResult.getSqlStatement()).merge());
} finally {
clearBatch();
}
setCurrentResultSet(result);
currentResultSet = result;
return result;
}

@@ -99,7 +127,7 @@ public int executeUpdate() throws SQLException {
try {
Collection<PreparedStatementUnit> preparedStatementUnits = route();
return new PreparedStatementExecutor(
getConnection().getShardingContext().getExecutorEngine(), getRouteResult().getSqlStatement().getType(), preparedStatementUnits, getParameters()).executeUpdate();
getConnection().getShardingContext().getExecutorEngine(), routeResult.getSqlStatement().getType(), preparedStatementUnits, getParameters()).executeUpdate();
} finally {
clearBatch();
}
@@ -110,24 +138,24 @@ public boolean execute() throws SQLException {
try {
Collection<PreparedStatementUnit> preparedStatementUnits = route();
return new PreparedStatementExecutor(
getConnection().getShardingContext().getExecutorEngine(), getRouteResult().getSqlStatement().getType(), preparedStatementUnits, getParameters()).execute();
getConnection().getShardingContext().getExecutorEngine(), routeResult.getSqlStatement().getType(), preparedStatementUnits, getParameters()).execute();
} finally {
clearBatch();
}
}

private Collection<PreparedStatementUnit> route() throws SQLException {
Collection<PreparedStatementUnit> result = new LinkedList<>();
setRouteResult(routingEngine.route(getParameters()));
for (SQLExecutionUnit each : getRouteResult().getExecutionUnits()) {
SQLType sqlType = getRouteResult().getSqlStatement().getType();
routeResult = routingEngine.route(getParameters());
for (SQLExecutionUnit each : routeResult.getExecutionUnits()) {
SQLType sqlType = routeResult.getSqlStatement().getType();
Collection<PreparedStatement> preparedStatements;
if (SQLType.DDL == sqlType) {
preparedStatements = generatePreparedStatementForDDL(each);
} else {
preparedStatements = Collections.singletonList(generatePreparedStatement(each));
}
getRoutedStatements().addAll(preparedStatements);
routedPreparedStatements.addAll(preparedStatements);
for (PreparedStatement preparedStatement : preparedStatements) {
replaySetParameter(preparedStatement);
result.add(new PreparedStatementUnit(each, preparedStatement));
@@ -140,23 +168,30 @@ private Collection<PreparedStatement> generatePreparedStatementForDDL(final SQLE
Collection<PreparedStatement> result = new LinkedList<>();
Collection<Connection> connections = getConnection().getAllConnections(sqlExecutionUnit.getDataSource());
for (Connection each : connections) {
result.add(each.prepareStatement(sqlExecutionUnit.getSql(), getResultSetType(), getResultSetConcurrency(), getResultSetHoldability()));
result.add(each.prepareStatement(sqlExecutionUnit.getSql(), resultSetType, resultSetConcurrency, resultSetHoldability));
}
return result;
}

private PreparedStatement generatePreparedStatement(final SQLExecutionUnit sqlExecutionUnit) throws SQLException {
Optional<GeneratedKey> generatedKey = getGeneratedKey();
Connection connection = getConnection().getConnection(sqlExecutionUnit.getDataSource(), getRouteResult().getSqlStatement().getType());
if (isReturnGeneratedKeys() || isReturnGeneratedKeys() && generatedKey.isPresent()) {
Connection connection = getConnection().getConnection(sqlExecutionUnit.getDataSource(), routeResult.getSqlStatement().getType());
if (returnGeneratedKeys && generatedKey.isPresent()) {
return connection.prepareStatement(sqlExecutionUnit.getSql(), RETURN_GENERATED_KEYS);
}
return connection.prepareStatement(sqlExecutionUnit.getSql(), getResultSetType(), getResultSetConcurrency(), getResultSetHoldability());
return connection.prepareStatement(sqlExecutionUnit.getSql(), resultSetType, resultSetConcurrency, resultSetHoldability);
}

private Optional<GeneratedKey> getGeneratedKey() {
if (null != routeResult && routeResult.getSqlStatement() instanceof InsertStatement) {
return Optional.fromNullable(((InsertStatement) routeResult.getSqlStatement()).getGeneratedKey());
}
return Optional.absent();
}

@Override
public void clearBatch() throws SQLException {
setCurrentResultSet(null);
currentResultSet = null;
clearParameters();
batchStatementUnits.clear();
parameterSets.clear();
@@ -171,7 +206,7 @@ public void addBatch() throws SQLException {
}
parameterSets.add(getParameters());
} finally {
setCurrentResultSet(null);
currentResultSet = null;
clearParameters();
}
}
@@ -180,16 +215,28 @@ public void addBatch() throws SQLException {
public int[] executeBatch() throws SQLException {
try {
return new BatchPreparedStatementExecutor(getConnection().getShardingContext().getExecutorEngine(),
getConnection().getShardingContext().getDatabaseType(), getRouteResult().getSqlStatement().getType(), batchStatementUnits, parameterSets).executeBatch();
getConnection().getShardingContext().getDatabaseType(), routeResult.getSqlStatement().getType(), batchStatementUnits, parameterSets).executeBatch();
} finally {
clearBatch();
}
}

@Override
public ResultSet getGeneratedKeys() throws SQLException {
Optional<GeneratedKey> generatedKey = getGeneratedKey();
if (returnGeneratedKeys && generatedKey.isPresent()) {
return new GeneratedKeysResultSet(routeResult.getGeneratedKeys().iterator(), generatedKey.get().getColumn(), this);
}
if (1 == routedPreparedStatements.size()) {
return routedPreparedStatements.iterator().next().getGeneratedKeys();
}
return new GeneratedKeysResultSet();
}

private List<BatchPreparedStatementUnit> routeBatch() throws SQLException {
List<BatchPreparedStatementUnit> result = new ArrayList<>();
setRouteResult(routingEngine.route(getParameters()));
for (SQLExecutionUnit each : getRouteResult().getExecutionUnits()) {
routeResult = routingEngine.route(getParameters());
for (SQLExecutionUnit each : routeResult.getExecutionUnits()) {
BatchPreparedStatementUnit batchStatementUnit = getPreparedBatchStatement(each);
replaySetParameter(batchStatementUnit.getStatement());
result.add(batchStatementUnit);
@@ -212,4 +259,26 @@ public boolean apply(final BatchPreparedStatementUnit input) {
batchStatementUnits.add(result);
return result;
}

@Override
public ResultSet getResultSet() throws SQLException {
if (null != currentResultSet) {
return currentResultSet;
}
if (1 == routedPreparedStatements.size()) {
currentResultSet = routedPreparedStatements.iterator().next().getResultSet();
return currentResultSet;
}
List<ResultSet> resultSets = new ArrayList<>(routedPreparedStatements.size());
for (PreparedStatement each : routedPreparedStatements) {
resultSets.add(each.getResultSet());
}
currentResultSet = new ShardingResultSet(resultSets, new MergeEngine(resultSets, (SelectStatement) routeResult.getSqlStatement()).merge());
return currentResultSet;
}

@Override
protected Collection<PreparedStatement> getRoutedPreparedStatements() {
return routedPreparedStatements;
}
}
Original file line number Diff line number Diff line change
@@ -34,7 +34,6 @@
import com.google.common.base.Optional;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;

import java.sql.Connection;
import java.sql.ResultSet;
@@ -66,12 +65,13 @@ public class ShardingStatement extends AbstractStatementAdapter {

private final Collection<Statement> routedStatements = new LinkedList<>();

@Getter(AccessLevel.NONE)
private boolean returnGeneratedKeys;

@Setter(AccessLevel.PROTECTED)
@Getter(AccessLevel.NONE)
private SQLRouteResult routeResult;

@Setter(AccessLevel.PROTECTED)
@Getter(AccessLevel.NONE)
private ResultSet currentResultSet;

public ShardingStatement(final ShardingConnection connection) {
@@ -83,7 +83,6 @@ public ShardingStatement(final ShardingConnection connection, final int resultSe
}

public ShardingStatement(final ShardingConnection connection, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
super(Statement.class);
this.connection = connection;
this.resultSetType = resultSetType;
this.resultSetConcurrency = resultSetConcurrency;
@@ -96,11 +95,11 @@ public ResultSet executeQuery(final String sql) throws SQLException {
try {
List<ResultSet> resultSets = generateExecutor(sql).executeQuery();
result = new ShardingResultSet(
resultSets, new MergeEngine(resultSets, (SelectStatement) getRouteResult().getSqlStatement()).merge());
resultSets, new MergeEngine(resultSets, (SelectStatement) routeResult.getSqlStatement()).merge());
} finally {
setCurrentResultSet(null);
currentResultSet = null;
}
setCurrentResultSet(result);
currentResultSet = result;
return result;
}

@@ -109,39 +108,39 @@ public int executeUpdate(final String sql) throws SQLException {
try {
return generateExecutor(sql).executeUpdate();
} finally {
setCurrentResultSet(null);
currentResultSet = null;
}
}

@Override
public int executeUpdate(final String sql, final int autoGeneratedKeys) throws SQLException {
if (RETURN_GENERATED_KEYS == autoGeneratedKeys) {
markReturnGeneratedKeys();
returnGeneratedKeys = true;
}
try {
return generateExecutor(sql).executeUpdate(autoGeneratedKeys);
} finally {
setCurrentResultSet(null);
currentResultSet = null;
}
}

@Override
public int executeUpdate(final String sql, final int[] columnIndexes) throws SQLException {
markReturnGeneratedKeys();
returnGeneratedKeys = true;
try {
return generateExecutor(sql).executeUpdate(columnIndexes);
} finally {
setCurrentResultSet(null);
currentResultSet = null;
}
}

@Override
public int executeUpdate(final String sql, final String[] columnNames) throws SQLException {
markReturnGeneratedKeys();
returnGeneratedKeys = true;
try {
return generateExecutor(sql).executeUpdate(columnNames);
} finally {
setCurrentResultSet(null);
currentResultSet = null;
}
}

@@ -150,46 +149,42 @@ public boolean execute(final String sql) throws SQLException {
try {
return generateExecutor(sql).execute();
} finally {
setCurrentResultSet(null);
currentResultSet = null;
}
}

@Override
public boolean execute(final String sql, final int autoGeneratedKeys) throws SQLException {
if (RETURN_GENERATED_KEYS == autoGeneratedKeys) {
markReturnGeneratedKeys();
returnGeneratedKeys = true;
}
try {
return generateExecutor(sql).execute(autoGeneratedKeys);
} finally {
setCurrentResultSet(null);
currentResultSet = null;
}
}

@Override
public boolean execute(final String sql, final int[] columnIndexes) throws SQLException {
markReturnGeneratedKeys();
returnGeneratedKeys = true;
try {
return generateExecutor(sql).execute(columnIndexes);
} finally {
setCurrentResultSet(null);
currentResultSet = null;
}
}

@Override
public boolean execute(final String sql, final String[] columnNames) throws SQLException {
markReturnGeneratedKeys();
returnGeneratedKeys = true;
try {
return generateExecutor(sql).execute(columnNames);
} finally {
setCurrentResultSet(null);
currentResultSet = null;
}
}

protected final void markReturnGeneratedKeys() {
returnGeneratedKeys = true;
}

private StatementExecutor generateExecutor(final String sql) throws SQLException {
clearPrevious();
routeResult = new StatementRoutingEngine(connection.getShardingContext()).route(sql);
@@ -222,7 +217,7 @@ private void clearPrevious() throws SQLException {
@Override
public ResultSet getGeneratedKeys() throws SQLException {
Optional<GeneratedKey> generatedKey = getGeneratedKey();
if (generatedKey.isPresent() && returnGeneratedKeys) {
if (returnGeneratedKeys && generatedKey.isPresent()) {
return new GeneratedKeysResultSet(routeResult.getGeneratedKeys().iterator(), generatedKey.get().getColumn(), this);
}
if (1 == getRoutedStatements().size()) {
@@ -231,7 +226,7 @@ public ResultSet getGeneratedKeys() throws SQLException {
return new GeneratedKeysResultSet();
}

protected final Optional<GeneratedKey> getGeneratedKey() {
private Optional<GeneratedKey> getGeneratedKey() {
if (null != routeResult && routeResult.getSqlStatement() instanceof InsertStatement) {
return Optional.fromNullable(((InsertStatement) routeResult.getSqlStatement()).getGeneratedKey());
}
@@ -251,7 +246,7 @@ public ResultSet getResultSet() throws SQLException {
for (Statement each : routedStatements) {
resultSets.add(each.getResultSet());
}
currentResultSet = new ShardingResultSet(resultSets, new MergeEngine(resultSets, (SelectStatement) getRouteResult().getSqlStatement()).merge());
currentResultSet = new ShardingResultSet(resultSets, new MergeEngine(resultSets, (SelectStatement) routeResult.getSqlStatement()).merge());
return currentResultSet;
}
}
Original file line number Diff line number Diff line change
@@ -17,15 +17,13 @@

package com.dangdang.ddframe.rdb.sharding.jdbc.unsupported;

import com.dangdang.ddframe.rdb.sharding.jdbc.core.connection.ShardingConnection;
import com.dangdang.ddframe.rdb.sharding.jdbc.core.statement.ShardingStatement;

import java.io.Reader;
import java.sql.Array;
import java.sql.NClob;
import java.sql.ParameterMetaData;
import java.sql.PreparedStatement;
import java.sql.Ref;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.RowId;
import java.sql.SQLException;
@@ -36,11 +34,7 @@
*
* @author zhangliang
*/
public abstract class AbstractUnsupportedOperationPreparedStatement extends ShardingStatement implements PreparedStatement {

protected AbstractUnsupportedOperationPreparedStatement(final ShardingConnection shardingConnection, final int resultSetType, final int resultSetConcurrency, final int resultSetHoldability) {
super(shardingConnection, resultSetType, resultSetConcurrency, resultSetHoldability);
}
public abstract class AbstractUnsupportedOperationPreparedStatement extends AbstractUnsupportedOperationStatement implements PreparedStatement {

@Override
public final ResultSetMetaData getMetaData() throws SQLException {
@@ -96,4 +90,49 @@ public final void setRowId(final int parameterIndex, final RowId x) throws SQLEx
public final void setRef(final int parameterIndex, final Ref x) throws SQLException {
throw new SQLFeatureNotSupportedException("setRef");
}

@Override
public final ResultSet executeQuery(final String sql) throws SQLException {
throw new SQLFeatureNotSupportedException("executeQuery with SQL for PreparedStatement");
}

@Override
public final int executeUpdate(final String sql) throws SQLException {
throw new SQLFeatureNotSupportedException("executeUpdate with SQL for PreparedStatement");
}

@Override
public final int executeUpdate(final String sql, final int autoGeneratedKeys) throws SQLException {
throw new SQLFeatureNotSupportedException("executeUpdate with SQL for PreparedStatement");
}

@Override
public final int executeUpdate(final String sql, final int[] columnIndexes) throws SQLException {
throw new SQLFeatureNotSupportedException("executeUpdate with SQL for PreparedStatement");
}

@Override
public final int executeUpdate(final String sql, final String[] columnNames) throws SQLException {
throw new SQLFeatureNotSupportedException("executeUpdate with SQL for PreparedStatement");
}

@Override
public final boolean execute(final String sql) throws SQLException {
throw new SQLFeatureNotSupportedException("execute with SQL for PreparedStatement");
}

@Override
public final boolean execute(final String sql, final int autoGeneratedKeys) throws SQLException {
throw new SQLFeatureNotSupportedException("execute with SQL for PreparedStatement");
}

@Override
public final boolean execute(final String sql, final int[] columnIndexes) throws SQLException {
throw new SQLFeatureNotSupportedException("execute with SQL for PreparedStatement");
}

@Override
public final boolean execute(final String sql, final String[] columnNames) throws SQLException {
throw new SQLFeatureNotSupportedException("execute with SQL for PreparedStatement");
}
}
Original file line number Diff line number Diff line change
@@ -64,4 +64,9 @@ public final void closeOnCompletion() throws SQLException {
public final boolean isCloseOnCompletion() throws SQLException {
throw new SQLFeatureNotSupportedException("isCloseOnCompletion");
}

@Override
public final void setCursorName(final String name) throws SQLException {
throw new SQLFeatureNotSupportedException("setCursorName");
}
}
Original file line number Diff line number Diff line change
@@ -146,17 +146,6 @@ public void assertCancel() throws SQLException {
}
}

@Test
public void assertSetCursorName() throws SQLException {
for (Map.Entry<DatabaseType, Statement> each : statements.entrySet()) {
if (DatabaseType.Oracle != each.getKey()) {
each.getValue().setCursorName("cursorName");
each.getValue().executeQuery(sql);
each.getValue().setCursorName("cursorName");
}
}
}

@Test
public void assertGetUpdateCount() throws SQLException {
String sql = "DELETE FROM t_order WHERE status = 'init'";
@@ -193,15 +182,15 @@ public void assertGetUpdateCountSelect() throws SQLException {

@Test
public void assertOverMaxUpdateRow() throws SQLException {
final Statement st1 = Mockito.mock(Statement.class);
when(st1.getUpdateCount()).thenReturn(Integer.MAX_VALUE);
final Statement st2 = Mockito.mock(Statement.class);
when(st2.getUpdateCount()).thenReturn(Integer.MAX_VALUE);
AbstractStatementAdapter statement = new AbstractStatementAdapter(Statement.class) {
final Statement statement1 = Mockito.mock(Statement.class);
when(statement1.getUpdateCount()).thenReturn(Integer.MAX_VALUE);
final Statement statement2 = Mockito.mock(Statement.class);
when(statement2.getUpdateCount()).thenReturn(Integer.MAX_VALUE);
AbstractStatementAdapter statement = new AbstractStatementAdapter() {

@Override
protected Collection<? extends Statement> getRoutedStatements() {
return Lists.newArrayList(st1, st2);
protected Collection<Statement> getRoutedStatements() {
return Lists.newArrayList(statement1, statement2);
}

@Override
Original file line number Diff line number Diff line change
@@ -106,4 +106,11 @@ public void assertIsCloseOnCompletion() throws SQLException {
each.isCloseOnCompletion();
}
}

@Test(expected = SQLFeatureNotSupportedException.class)
public void assertSetCursorName() throws SQLException {
for (Statement each : statements) {
each.setCursorName("cursorName");
}
}
}

0 comments on commit 07f5b77

Please sign in to comment.