Skip to content

Commit

Permalink
for apache#675 initial succeed
Browse files Browse the repository at this point in the history
  • Loading branch information
tuohai666 committed Apr 20, 2018
1 parent 3a09e21 commit 5c021b8
Show file tree
Hide file tree
Showing 13 changed files with 225 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,27 @@ public final class StatementExecuteBackendHandler implements BackendHandler {

private final PreparedStatementRoutingEngine routingEngine;

private List<Connection> connections;

private List<ResultSet> resultSets;

private MergedResult mergedResult;

private int currentSequenceId;

private int columnCount;

private final List<ColumnType> columnTypes;

private boolean noMoreValues;

public StatementExecuteBackendHandler(final List<PreparedStatementParameter> preparedStatementParameters, final int statementId, final DatabaseType databaseType, final boolean showSQL) {
this.preparedStatementParameters = preparedStatementParameters;
routingEngine = new PreparedStatementRoutingEngine(PreparedStatementRegistry.getInstance().getSQL(statementId), ShardingRuleRegistry.getInstance().getShardingRule(), databaseType, showSQL);
connections = new ArrayList<>(1024);
resultSets = new ArrayList<>(1024);
columnTypes = new ArrayList<>(32);
noMoreValues = false;
}

@Override
Expand All @@ -77,24 +95,23 @@ public CommandResponsePackets execute() {
if (routeResult.getExecutionUnits().isEmpty()) {
return new CommandResponsePackets(new OKPacket(1, 0, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue(), 0, ""));
}
List<ColumnType> columnTypes = new ArrayList<>(32);
List<CommandResponsePackets> result = new LinkedList<>();
for (SQLExecutionUnit each : routeResult.getExecutionUnits()) {
// TODO multiple threads
result.add(execute(routeResult.getSqlStatement(), each, columnTypes));
result.add(execute(routeResult.getSqlStatement(), each));
}
return merge(routeResult.getSqlStatement(), result, columnTypes);
return merge(routeResult.getSqlStatement(), result);
}

private CommandResponsePackets execute(final SQLStatement sqlStatement, final SQLExecutionUnit sqlExecutionUnit, final List<ColumnType> columnTypes) {
private CommandResponsePackets execute(final SQLStatement sqlStatement, final SQLExecutionUnit sqlExecutionUnit) {
switch (sqlStatement.getType()) {
case DQL:
return executeQuery(ShardingRuleRegistry.getInstance().getDataSourceMap().get(sqlExecutionUnit.getDataSource()), sqlExecutionUnit.getSql(), columnTypes);
return executeQuery(ShardingRuleRegistry.getInstance().getDataSourceMap().get(sqlExecutionUnit.getDataSource()), sqlExecutionUnit.getSql());
case DML:
case DDL:
return executeUpdate(ShardingRuleRegistry.getInstance().getDataSourceMap().get(sqlExecutionUnit.getDataSource()), sqlExecutionUnit.getSql(), sqlStatement);
default:
return executeCommon(ShardingRuleRegistry.getInstance().getDataSourceMap().get(sqlExecutionUnit.getDataSource()), sqlExecutionUnit.getSql(), columnTypes);
return executeCommon(ShardingRuleRegistry.getInstance().getDataSourceMap().get(sqlExecutionUnit.getDataSource()), sqlExecutionUnit.getSql());
}
}

Expand All @@ -112,15 +129,25 @@ private void setJDBCPreparedStatementParameters(final PreparedStatement prepared
}
}

private CommandResponsePackets executeQuery(final DataSource dataSource, final String sql, final List<ColumnType> columnTypes) {
try (
Connection connection = dataSource.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
private CommandResponsePackets executeQuery(final DataSource dataSource, final String sql) {
PreparedStatement preparedStatement = null;
try {
Connection connection = dataSource.getConnection();
connections.add(connection);
preparedStatement = connection.prepareStatement(sql);
preparedStatement.setFetchSize(Integer.MIN_VALUE);
setJDBCPreparedStatementParameters(preparedStatement);
ResultSet resultSet = preparedStatement.executeQuery();
return getDatabaseProtocolPackets(resultSet, columnTypes);
resultSets.add(preparedStatement.executeQuery());
return getDatabaseProtocolPackets();
} catch (final SQLException ex) {
return new CommandResponsePackets(new ErrPacket(1, ex.getErrorCode(), "", ex.getSQLState(), ex.getMessage()));
} finally {
// if (preparedStatement != null) {
// try {
// preparedStatement.close();
// } catch (SQLException ignore) {
// }
// }
}
}

Expand Down Expand Up @@ -151,17 +178,17 @@ private CommandResponsePackets executeUpdate(final DataSource dataSource, final
}
}
}

}

private CommandResponsePackets executeCommon(final DataSource dataSource, final String sql, final List<ColumnType> columnTypes) {
private CommandResponsePackets executeCommon(final DataSource dataSource, final String sql) {
try (
Connection connection = dataSource.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
setJDBCPreparedStatementParameters(preparedStatement);
boolean hasResultSet = preparedStatement.execute();
if (hasResultSet) {
return getDatabaseProtocolPackets(preparedStatement.getResultSet(), columnTypes);
resultSets.add(preparedStatement.getResultSet());
return getDatabaseProtocolPackets();
} else {
return new CommandResponsePackets(new OKPacket(1, preparedStatement.getUpdateCount(), 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue(), 0, ""));
}
Expand All @@ -170,11 +197,11 @@ private CommandResponsePackets executeCommon(final DataSource dataSource, final
}
}

private CommandResponsePackets getDatabaseProtocolPackets(final ResultSet resultSet, final List<ColumnType> columnTypes) throws SQLException {
private CommandResponsePackets getDatabaseProtocolPackets() throws SQLException {
CommandResponsePackets result = new CommandResponsePackets();
int currentSequenceId = 0;
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
int columnCount = resultSetMetaData.getColumnCount();
ResultSetMetaData resultSetMetaData = resultSets.get(resultSets.size() - 1).getMetaData();
columnCount = resultSetMetaData.getColumnCount();
if (0 == columnCount) {
result.addPacket(new OKPacket(++currentSequenceId, 0, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue(), 0, ""));
return result;
Expand All @@ -188,14 +215,6 @@ private CommandResponsePackets getDatabaseProtocolPackets(final ResultSet result
columnTypes.add(columnType);
}
result.addPacket(new EofPacket(++currentSequenceId, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue()));
while (resultSet.next()) {
List<Object> data = new ArrayList<>(columnCount);
for (int i = 1; i <= columnCount; i++) {
data.add(resultSet.getObject(i));
}
result.addPacket(new BinaryResultSetRowPacket(++currentSequenceId, columnCount, data, columnTypes));
}
result.addPacket(new EofPacket(++currentSequenceId, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue()));
return result;
}

Expand All @@ -208,7 +227,7 @@ private long getGeneratedKey(final PreparedStatement preparedStatement) throws S
return result;
}

private CommandResponsePackets merge(final SQLStatement sqlStatement, final List<CommandResponsePackets> packets, final List<ColumnType> columnTypes) {
private CommandResponsePackets merge(final SQLStatement sqlStatement, final List<CommandResponsePackets> packets) {
if (1 == packets.size()) {
return packets.iterator().next();
}
Expand All @@ -225,7 +244,7 @@ private CommandResponsePackets merge(final SQLStatement sqlStatement, final List
return mergeDML(headPackets);
}
if (SQLType.DQL == sqlStatement.getType() || SQLType.DAL == sqlStatement.getType()) {
return mergeDQLorDAL(sqlStatement, packets, columnTypes);
return mergeDQLorDAL(sqlStatement, packets);
}
return packets.get(0);
}
Expand All @@ -241,44 +260,71 @@ private CommandResponsePackets mergeDML(final CommandResponsePackets firstPacket
return new CommandResponsePackets(new OKPacket(1, affectedRows, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue(), 0, ""));
}

private CommandResponsePackets mergeDQLorDAL(final SQLStatement sqlStatement, final List<CommandResponsePackets> packets, final List<ColumnType> columnTypes) {
private CommandResponsePackets mergeDQLorDAL(final SQLStatement sqlStatement, final List<CommandResponsePackets> packets) {
List<QueryResult> queryResults = new ArrayList<>(packets.size());
for (CommandResponsePackets each : packets) {
// for (CommandResponsePackets each : packets) {
// // TODO replace to a common PacketQueryResult
// queryResults.add(new MySQLPacketStatementExecuteQueryResult(each, resultSet, columnTypes));
// }
for (int i = 0; i < packets.size(); i++) {
// TODO replace to a common PacketQueryResult
queryResults.add(new MySQLPacketStatementExecuteQueryResult(each));
queryResults.add(new MySQLPacketStatementExecuteQueryResult(packets.get(i), resultSets.get(i), columnTypes));
}
MergedResult mergedResult;
try {
mergedResult = MergeEngineFactory.newInstance(ShardingRuleRegistry.getInstance().getShardingRule(), queryResults, sqlStatement).merge();
} catch (final SQLException ex) {
return new CommandResponsePackets(new ErrPacket(1, ex.getErrorCode(), "", ex.getSQLState(), ex.getMessage()));
}
return buildPackets(packets, mergedResult, columnTypes);
return buildPackets(packets);
}

private CommandResponsePackets buildPackets(final List<CommandResponsePackets> packets, final MergedResult mergedResult, final List<ColumnType> columnTypes) {
private CommandResponsePackets buildPackets(final List<CommandResponsePackets> packets) {
CommandResponsePackets result = new CommandResponsePackets();
Iterator<DatabaseProtocolPacket> databaseProtocolPacketsSampling = packets.iterator().next().getDatabaseProtocolPackets().iterator();
FieldCountPacket fieldCountPacketSampling = (FieldCountPacket) databaseProtocolPacketsSampling.next();
result.addPacket(fieldCountPacketSampling);
++currentSequenceId;
int columnCount = fieldCountPacketSampling.getColumnCount();
for (int i = 0; i < columnCount; i++) {
result.addPacket(databaseProtocolPacketsSampling.next());
++currentSequenceId;
}
result.addPacket(databaseProtocolPacketsSampling.next());
int currentSequenceId = result.size();
try {
while (mergedResult.next()) {
List<Object> data = new ArrayList<>(columnCount);
for (int i = 1; i <= columnCount; i++) {
data.add(mergedResult.getValue(i, Object.class));
++currentSequenceId;
return result;
}

public boolean hasMoreResultValue() throws SQLException {
if (noMoreValues) {

return false;
}
if (!mergedResult.next()) {
noMoreValues = true;
for (Connection each : connections) {
if (null != each) {
try {
each.close();
} catch (SQLException ignore) {
}
}
result.addPacket(new BinaryResultSetRowPacket(++currentSequenceId, columnCount, data, columnTypes));
}
}
return true;
}

public DatabaseProtocolPacket getResultValue() {
if (noMoreValues) {
return new EofPacket(++currentSequenceId, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue());
}
try {
List<Object> data = new ArrayList<>(columnCount);
for (int i = 1; i <= columnCount; i++) {
data.add(mergedResult.getValue(i, Object.class));
}
return new BinaryResultSetRowPacket(++currentSequenceId, columnCount, data, columnTypes);
} catch (final SQLException ex) {
return new CommandResponsePackets(new ErrPacket(1, ex.getErrorCode(), "", ex.getSQLState(), ex.getMessage()));
return new ErrPacket(1, ex.getErrorCode(), "", ex.getSQLState(), ex.getMessage());
}
result.addPacket(new EofPacket(++currentSequenceId, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue()));
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,21 @@

import io.shardingjdbc.core.merger.QueryResult;
import io.shardingjdbc.proxy.transport.common.packet.DatabaseProtocolPacket;
import io.shardingjdbc.proxy.transport.mysql.constant.ColumnType;
import io.shardingjdbc.proxy.transport.mysql.packet.command.CommandResponsePackets;
import io.shardingjdbc.proxy.transport.mysql.packet.command.statement.execute.BinaryResultSetRowPacket;
import io.shardingjdbc.proxy.transport.mysql.packet.command.text.query.ColumnDefinition41Packet;
import io.shardingjdbc.proxy.transport.mysql.packet.command.text.query.FieldCountPacket;
import lombok.RequiredArgsConstructor;

import java.io.InputStream;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
Expand All @@ -45,11 +50,15 @@ public final class MySQLPacketStatementExecuteQueryResult implements QueryResult

private final Map<String, Integer> columnLabelAndIndexMap;

private final Iterator<DatabaseProtocolPacket> data;
private final ResultSet resultSet;

private final List<ColumnType> columnTypes;

private int currentSequenceId;

private BinaryResultSetRowPacket currentRow;

public MySQLPacketStatementExecuteQueryResult(final CommandResponsePackets packets) {
public MySQLPacketStatementExecuteQueryResult(final CommandResponsePackets packets, final ResultSet resultSet, final List<ColumnType> columnTypes) {
Iterator<DatabaseProtocolPacket> packetIterator = packets.getDatabaseProtocolPackets().iterator();
columnCount = ((FieldCountPacket) packetIterator.next()).getColumnCount();
columnIndexAndLabelMap = new HashMap<>(columnCount, 1);
Expand All @@ -60,14 +69,18 @@ public MySQLPacketStatementExecuteQueryResult(final CommandResponsePackets packe
columnLabelAndIndexMap.put(columnDefinition41Packet.getName(), i);
}
packetIterator.next();
data = packetIterator;
this.resultSet = resultSet;
this.columnTypes = columnTypes;
}

@Override
public boolean next() {
DatabaseProtocolPacket databaseProtocolPacket = data.next();
if (databaseProtocolPacket instanceof BinaryResultSetRowPacket) {
currentRow = (BinaryResultSetRowPacket) databaseProtocolPacket;
public boolean next() throws SQLException {
if (resultSet.next()) {
List<Object> data = new ArrayList<>(columnCount);
for (int i = 1; i <= columnCount; i++) {
data.add(resultSet.getObject(i));
}
currentRow = new BinaryResultSetRowPacket(++currentSequenceId, columnCount, data, columnTypes);
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ public void run() {
int sequenceId = mysqlPacketPayload.readInt1();
CommandPacket commandPacket = CommandPacketFactory.getCommandPacket(sequenceId, mysqlPacketPayload);
for (DatabaseProtocolPacket each : commandPacket.execute().getDatabaseProtocolPackets()) {
context.write(each);
context.writeAndFlush(each);
}
while (commandPacket.hasMoreResultValue()) {
context.writeAndFlush(commandPacket.getResultValue());
}
context.flush();
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
* @see <a href="https://dev.mysql.com/doc/internals/en/binary-protocol-value.html">binary protocol value</a>
*
* @author zhangliang
* @author zhangyonglun
*/
@RequiredArgsConstructor
@Getter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package io.shardingjdbc.proxy.transport.mysql.packet.command;

import io.shardingjdbc.proxy.transport.common.packet.DatabaseProtocolPacket;
import io.shardingjdbc.proxy.transport.mysql.packet.MySQLPacket;

/**
Expand All @@ -36,4 +37,18 @@ public CommandPacket(final int sequenceId) {
* @return result packets to be sent
*/
public abstract CommandResponsePackets execute();

/**
* Has more result value.
*
* @return has more result value
*/
public abstract boolean hasMoreResultValue();

/**
* Get result value.
*
* @return result to be sent
*/
public abstract DatabaseProtocolPacket getResultValue();
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package io.shardingjdbc.proxy.transport.mysql.packet.command;

import io.shardingjdbc.proxy.transport.common.packet.DatabaseProtocolPacket;
import io.shardingjdbc.proxy.transport.mysql.packet.MySQLPacketPayload;
import io.shardingjdbc.proxy.transport.mysql.packet.generic.ErrPacket;

Expand Down Expand Up @@ -50,4 +51,14 @@ public CommandResponsePackets execute() {
@Override
public void write(final MySQLPacketPayload mysqlPacketPayload) {
}

@Override
public boolean hasMoreResultValue() {
return false;
}

@Override
public DatabaseProtocolPacket getResultValue() {
return null;
}
}
Loading

0 comments on commit 5c021b8

Please sign in to comment.