Skip to content

Commit

Permalink
[SPARK-13043][SQL] Implement remaining catalyst types in ColumnarBatch.
Browse files Browse the repository at this point in the history
This includes: float, boolean, short, decimal and calendar interval.

Decimal is mapped to long or byte array depending on the size and calendar
interval is mapped to a struct of int and long.

The only remaining type is map. The schema mapping is straightforward but
we might want to revisit how we deal with this in the rest of the execution
engine.

Author: Nong Li <[email protected]>

Closes apache#10961 from nongli/spark-13043.
  • Loading branch information
nongli authored and rxin committed Feb 1, 2016
1 parent c9b89a0 commit 064b029
Show file tree
Hide file tree
Showing 8 changed files with 484 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,28 @@ object DecimalType extends AbstractDataType {
}
}

/**
* Returns if dt is a DecimalType that fits inside a long
*/
def is64BitDecimalType(dt: DataType): Boolean = {
dt match {
case t: DecimalType =>
t.precision <= Decimal.MAX_LONG_DIGITS
case _ => false
}
}

/**
* Returns if dt is a DecimalType that doesn't fit inside a long
*/
def isByteArrayDecimalType(dt: DataType): Boolean = {
dt match {
case t: DecimalType =>
t.precision > Decimal.MAX_LONG_DIGITS
case _ => false
}
}

def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]

def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
*/
package org.apache.spark.sql.execution.vectorized;

import java.math.BigDecimal;
import java.math.BigInteger;

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
Expand Down Expand Up @@ -102,18 +105,36 @@ public Object[] array() {
DataType dt = data.dataType();
Object[] list = new Object[length];

if (dt instanceof ByteType) {
if (dt instanceof BooleanType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getBoolean(offset + i);
}
}
} else if (dt instanceof ByteType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getByte(offset + i);
}
}
} else if (dt instanceof ShortType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getShort(offset + i);
}
}
} else if (dt instanceof IntegerType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getInt(offset + i);
}
}
} else if (dt instanceof FloatType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = data.getFloat(offset + i);
}
}
} else if (dt instanceof DoubleType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
Expand All @@ -126,12 +147,25 @@ public Object[] array() {
list[i] = data.getLong(offset + i);
}
}
} else if (dt instanceof DecimalType) {
DecimalType decType = (DecimalType)dt;
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = getDecimal(i, decType.precision(), decType.scale());
}
}
} else if (dt instanceof StringType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i));
}
}
} else if (dt instanceof CalendarIntervalType) {
for (int i = 0; i < length; i++) {
if (!data.getIsNull(offset + i)) {
list[i] = getInterval(i);
}
}
} else {
throw new NotImplementedException("Type " + dt);
}
Expand Down Expand Up @@ -170,7 +204,14 @@ public float getFloat(int ordinal) {

@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
throw new NotImplementedException();
if (precision <= Decimal.MAX_LONG_DIGITS()) {
return Decimal.apply(getLong(ordinal), precision, scale);
} else {
byte[] bytes = getBinary(ordinal);
BigInteger bigInteger = new BigInteger(bytes);
BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
return Decimal.apply(javaDecimal, precision, scale);
}
}

@Override
Expand All @@ -181,17 +222,22 @@ public UTF8String getUTF8String(int ordinal) {

@Override
public byte[] getBinary(int ordinal) {
throw new NotImplementedException();
ColumnVector.Array array = data.getByteArray(offset + ordinal);
byte[] bytes = new byte[array.length];
System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
return bytes;
}

@Override
public CalendarInterval getInterval(int ordinal) {
throw new NotImplementedException();
int month = data.getChildColumn(0).getInt(offset + ordinal);
long microseconds = data.getChildColumn(1).getLong(offset + ordinal);
return new CalendarInterval(month, microseconds);
}

@Override
public InternalRow getStruct(int ordinal, int numFields) {
throw new NotImplementedException();
return data.getStruct(offset + ordinal);
}

@Override
Expand Down Expand Up @@ -279,6 +325,21 @@ public void reset() {
*/
public abstract boolean getIsNull(int rowId);

/**
* Sets the value at rowId to `value`.
*/
public abstract void putBoolean(int rowId, boolean value);

/**
* Sets values from [rowId, rowId + count) to value.
*/
public abstract void putBooleans(int rowId, int count, boolean value);

/**
* Returns the value for rowId.
*/
public abstract boolean getBoolean(int rowId);

/**
* Sets the value at rowId to `value`.
*/
Expand All @@ -299,6 +360,26 @@ public void reset() {
*/
public abstract byte getByte(int rowId);

/**
* Sets the value at rowId to `value`.
*/
public abstract void putShort(int rowId, short value);

/**
* Sets values from [rowId, rowId + count) to value.
*/
public abstract void putShorts(int rowId, int count, short value);

/**
* Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
*/
public abstract void putShorts(int rowId, int count, short[] src, int srcIndex);

/**
* Returns the value for rowId.
*/
public abstract short getShort(int rowId);

/**
* Sets the value at rowId to `value`.
*/
Expand Down Expand Up @@ -351,6 +432,33 @@ public void reset() {
*/
public abstract long getLong(int rowId);

/**
* Sets the value at rowId to `value`.
*/
public abstract void putFloat(int rowId, float value);

/**
* Sets values from [rowId, rowId + count) to value.
*/
public abstract void putFloats(int rowId, int count, float value);

/**
* Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
* src should contain `count` doubles written as ieee format.
*/
public abstract void putFloats(int rowId, int count, float[] src, int srcIndex);

/**
* Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
* The data in src must be ieee formatted floats.
*/
public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex);

/**
* Returns the value for rowId.
*/
public abstract float getFloat(int rowId);

/**
* Sets the value at rowId to `value`.
*/
Expand All @@ -369,7 +477,7 @@ public void reset() {

/**
* Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
* The data in src must be ieee formated doubles.
* The data in src must be ieee formatted doubles.
*/
public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex);

Expand Down Expand Up @@ -469,6 +577,20 @@ public final int appendNotNulls(int count) {
return result;
}

public final int appendBoolean(boolean v) {
reserve(elementsAppended + 1);
putBoolean(elementsAppended, v);
return elementsAppended++;
}

public final int appendBooleans(int count, boolean v) {
reserve(elementsAppended + count);
int result = elementsAppended;
putBooleans(elementsAppended, count, v);
elementsAppended += count;
return result;
}

public final int appendByte(byte v) {
reserve(elementsAppended + 1);
putByte(elementsAppended, v);
Expand All @@ -491,6 +613,28 @@ public final int appendBytes(int length, byte[] src, int offset) {
return result;
}

public final int appendShort(short v) {
reserve(elementsAppended + 1);
putShort(elementsAppended, v);
return elementsAppended++;
}

public final int appendShorts(int count, short v) {
reserve(elementsAppended + count);
int result = elementsAppended;
putShorts(elementsAppended, count, v);
elementsAppended += count;
return result;
}

public final int appendShorts(int length, short[] src, int offset) {
reserve(elementsAppended + length);
int result = elementsAppended;
putShorts(elementsAppended, length, src, offset);
elementsAppended += length;
return result;
}

public final int appendInt(int v) {
reserve(elementsAppended + 1);
putInt(elementsAppended, v);
Expand Down Expand Up @@ -535,6 +679,20 @@ public final int appendLongs(int length, long[] src, int offset) {
return result;
}

public final int appendFloat(float v) {
reserve(elementsAppended + 1);
putFloat(elementsAppended, v);
return elementsAppended++;
}

public final int appendFloats(int count, float v) {
reserve(elementsAppended + count);
int result = elementsAppended;
putFloats(elementsAppended, count, v);
elementsAppended += count;
return result;
}

public final int appendDouble(double v) {
reserve(elementsAppended + 1);
putDouble(elementsAppended, v);
Expand Down Expand Up @@ -661,7 +819,8 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) {
this.capacity = capacity;
this.type = type;

if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) {
if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType
|| DecimalType.isByteArrayDecimalType(type)) {
DataType childType;
int childCapacity = capacity;
if (type instanceof ArrayType) {
Expand All @@ -682,6 +841,13 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) {
}
this.resultArray = null;
this.resultStruct = new ColumnarBatch.Row(this.childColumns);
} else if (type instanceof CalendarIntervalType) {
// Two columns. Months as int. Microseconds as Long.
this.childColumns = new ColumnVector[2];
this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode);
this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode);
this.resultArray = null;
this.resultStruct = new ColumnarBatch.Row(this.childColumns);
} else {
this.childColumns = null;
this.resultArray = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
*/
package org.apache.spark.sql.execution.vectorized;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;

import org.apache.commons.lang.NotImplementedException;

Expand Down Expand Up @@ -59,19 +62,44 @@ public static Object toPrimitiveJavaArray(ColumnVector.Array array) {

private static void appendValue(ColumnVector dst, DataType t, Object o) {
if (o == null) {
dst.appendNull();
if (t instanceof CalendarIntervalType) {
dst.appendStruct(true);
} else {
dst.appendNull();
}
} else {
if (t == DataTypes.ByteType) {
dst.appendByte(((Byte)o).byteValue());
if (t == DataTypes.BooleanType) {
dst.appendBoolean(((Boolean)o).booleanValue());
} else if (t == DataTypes.ByteType) {
dst.appendByte(((Byte) o).byteValue());
} else if (t == DataTypes.ShortType) {
dst.appendShort(((Short)o).shortValue());
} else if (t == DataTypes.IntegerType) {
dst.appendInt(((Integer)o).intValue());
} else if (t == DataTypes.LongType) {
dst.appendLong(((Long)o).longValue());
} else if (t == DataTypes.FloatType) {
dst.appendFloat(((Float)o).floatValue());
} else if (t == DataTypes.DoubleType) {
dst.appendDouble(((Double)o).doubleValue());
} else if (t == DataTypes.StringType) {
byte[] b =((String)o).getBytes();
dst.appendByteArray(b, 0, b.length);
} else if (t instanceof DecimalType) {
DecimalType dt = (DecimalType)t;
Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale());
if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
dst.appendLong(d.toUnscaledLong());
} else {
final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
byte[] bytes = integer.toByteArray();
dst.appendByteArray(bytes, 0, bytes.length);
}
} else if (t instanceof CalendarIntervalType) {
CalendarInterval c = (CalendarInterval)o;
dst.appendStruct(false);
dst.getChildColumn(0).appendInt(c.months);
dst.getChildColumn(1).appendLong(c.microseconds);
} else {
throw new NotImplementedException("Type " + t);
}
Expand Down
Loading

0 comments on commit 064b029

Please sign in to comment.