Skip to content

Commit

Permalink
Validate char/varchar values read in JDBC connectors
Browse files Browse the repository at this point in the history
  • Loading branch information
findepi committed Sep 19, 2020
1 parent b055601 commit 27c4839
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,10 @@ protected Optional<ColumnMapping> getForcedMappingToVarchar(JdbcTypeHandle typeH

protected static Optional<ColumnMapping> mapToUnboundedVarchar(JdbcTypeHandle typeHandle)
{
VarcharType unboundedVarcharType = createUnboundedVarcharType();
return Optional.of(ColumnMapping.sliceMapping(
createUnboundedVarcharType(),
varcharReadFunction(),
unboundedVarcharType,
varcharReadFunction(unboundedVarcharType),
(statement, index, value) -> {
throw new PrestoException(
NOT_SUPPORTED,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import com.google.common.base.CharMatcher;
import com.google.common.primitives.Shorts;
import com.google.common.primitives.SignedBytes;
import io.airlift.slice.Slice;
import io.prestosql.spi.type.CharType;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.Decimals;
import io.prestosql.spi.type.TimestampType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarcharType;
import org.joda.time.DateTimeZone;
import org.joda.time.chrono.ISOChronology;
Expand All @@ -40,6 +42,8 @@
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.io.BaseEncoding.base16;
import static io.airlift.slice.SliceUtf8.countCodePoints;
import static io.airlift.slice.Slices.utf8Slice;
import static io.airlift.slice.Slices.wrappedBuffer;
import static io.prestosql.plugin.jdbc.ColumnMapping.DISABLE_PUSHDOWN;
Expand Down Expand Up @@ -77,6 +81,7 @@
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.math.RoundingMode.UNNECESSARY;
import static java.time.ZoneOffset.UTC;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -199,12 +204,17 @@ public static SliceWriteFunction longDecimalWriteFunction(DecimalType decimalTyp
public static ColumnMapping charColumnMapping(CharType charType)
{
requireNonNull(charType, "charType is null");
return ColumnMapping.sliceMapping(charType, charReadFunction(), charWriteFunction());
return ColumnMapping.sliceMapping(charType, charReadFunction(charType), charWriteFunction());
}

public static SliceReadFunction charReadFunction()
public static SliceReadFunction charReadFunction(CharType charType)
{
return (resultSet, columnIndex) -> utf8Slice(CharMatcher.is(' ').trimTrailingFrom(resultSet.getString(columnIndex)));
requireNonNull(charType, "charType is null");
return (resultSet, columnIndex) -> {
Slice slice = utf8Slice(CharMatcher.is(' ').trimTrailingFrom(resultSet.getString(columnIndex)));
checkLengthInCodePoints(slice, charType, charType.getLength());
return slice;
};
}

public static SliceWriteFunction charWriteFunction()
Expand All @@ -216,12 +226,37 @@ public static SliceWriteFunction charWriteFunction()

public static ColumnMapping varcharColumnMapping(VarcharType varcharType)
{
return ColumnMapping.sliceMapping(varcharType, varcharReadFunction(), varcharWriteFunction());
return ColumnMapping.sliceMapping(varcharType, varcharReadFunction(varcharType), varcharWriteFunction());
}

public static SliceReadFunction varcharReadFunction()
public static SliceReadFunction varcharReadFunction(VarcharType varcharType)
{
return (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex));
requireNonNull(varcharType, "varcharType is null");
if (varcharType.isUnbounded()) {
return (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex));
}
return (resultSet, columnIndex) -> {
Slice slice = utf8Slice(resultSet.getString(columnIndex));
checkLengthInCodePoints(slice, varcharType, varcharType.getBoundedLength());
return slice;
};
}

private static void checkLengthInCodePoints(Slice value, Type characterDataType, int lengthLimit)
{
// Quick check in bytes
if (value.length() <= lengthLimit) {
return;
}
// Actual check
if (countCodePoints(value) <= lengthLimit) {
return;
}
throw new IllegalStateException(format(
"Illegal value for type %s: '%s' [%s]",
characterDataType,
value.toStringUtf8(),
base16().encode(value.getBytes())));
}

public static SliceWriteFunction varcharWriteFunction()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ public Optional<ColumnMapping> toPrestoType(ConnectorSession session, Connection
VarcharType varcharType = (columnSize <= VarcharType.MAX_LENGTH) ? createVarcharType(columnSize) : createUnboundedVarcharType();
// Remote database can be case insensitive.
PredicatePushdownController predicatePushdownController = PUSHDOWN_AND_KEEP;
return Optional.of(ColumnMapping.sliceMapping(varcharType, varcharReadFunction(), varcharWriteFunction(), predicatePushdownController));
return Optional.of(ColumnMapping.sliceMapping(varcharType, varcharReadFunction(varcharType), varcharWriteFunction(), predicatePushdownController));

case Types.DECIMAL:
int precision = columnSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ else if (precision > Decimals.MAX_PRECISION || columnSize <= 0) {
CharType charType = createCharType(columnSize);
return Optional.of(ColumnMapping.sliceMapping(
charType,
charReadFunction(),
charReadFunction(charType),
oracleCharWriteFunction(charType),
OracleClient::fullPushdownIfSupported));

Expand Down

0 comments on commit 27c4839

Please sign in to comment.