Skip to content

Commit

Permalink
[SYSTEMML-2057] New builtin functions for bitwise binary operators
Browse files Browse the repository at this point in the history
Closes apache#716.
  • Loading branch information
j143 authored and mboehm7 committed Jan 11, 2018
1 parent 9dc354a commit 7dbbaaa
Show file tree
Hide file tree
Showing 30 changed files with 1,075 additions and 40 deletions.
29 changes: 14 additions & 15 deletions src/main/java/org/apache/sysml/hops/BinaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ public Lop constructLops()
break;
}
default:
constructLopsBinaryDefault();
constructLopsBinaryDefault();
}

//add reblock/checkpoint lops if necessary
Expand Down Expand Up @@ -627,12 +627,11 @@ else if( op==OpOp2.MULT && right instanceof LiteralOp && ((LiteralOp)right).getD
ot = HopsOpOp2LopsU.get(op);

Unary unary1 = new Unary(getInput().get(0).constructLops(),
getInput().get(1).constructLops(), ot, getDataType(), getValueType(), et);
getInput().get(1).constructLops(), ot, getDataType(), getValueType(), et);

setOutputDimensions(unary1);
setLineNumbers(unary1);
setLops(unary1);

}
else
{
Expand Down Expand Up @@ -1597,18 +1596,18 @@ && getInput().get(0) == that2.getInput().get(0)
&& getInput().get(1) == that2.getInput().get(1));
}

public boolean supportsMatrixScalarOperations()
{
return ( op==OpOp2.PLUS ||op==OpOp2.MINUS
||op==OpOp2.MULT ||op==OpOp2.DIV
||op==OpOp2.MODULUS ||op==OpOp2.INTDIV
||op==OpOp2.LESS ||op==OpOp2.LESSEQUAL
||op==OpOp2.GREATER ||op==OpOp2.GREATEREQUAL
||op==OpOp2.EQUAL ||op==OpOp2.NOTEQUAL
||op==OpOp2.MIN ||op==OpOp2.MAX
||op==OpOp2.AND ||op==OpOp2.OR
||op == OpOp2.XOR
||op==OpOp2.LOG ||op==OpOp2.POW );
public boolean supportsMatrixScalarOperations() {
return ( op==OpOp2.PLUS ||op==OpOp2.MINUS
||op==OpOp2.MULT ||op==OpOp2.DIV
||op==OpOp2.MODULUS ||op==OpOp2.INTDIV
||op==OpOp2.LESS ||op==OpOp2.LESSEQUAL
||op==OpOp2.GREATER ||op==OpOp2.GREATEREQUAL
||op==OpOp2.EQUAL ||op==OpOp2.NOTEQUAL
||op==OpOp2.MIN ||op==OpOp2.MAX
||op==OpOp2.LOG ||op==OpOp2.POW
||op==OpOp2.AND ||op==OpOp2.OR ||op==OpOp2.XOR
||op==OpOp2.BW_AND ||op==OpOp2.BW_OR ||op==OpOp2.BW_XOR
||op==OpOp2.BW_SHIFTL ||op==OpOp2.BW_SHIFTR);
}

public boolean isPPredOperation()
Expand Down
46 changes: 37 additions & 9 deletions src/main/java/org/apache/sysml/hops/Hop.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.Nary;
import org.apache.sysml.lops.ReBlock;
import org.apache.sysml.lops.Unary;
import org.apache.sysml.lops.UnaryCP;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
Expand Down Expand Up @@ -311,12 +312,12 @@ private void constructAndSetReblockLopIfRequired()
&& ((DataOp)this).getInputFormatType() == FileFormatTypes.CSV )
{
reblock = new CSVReBlock( input, getRowsInBlock(), getColsInBlock(),
getDataType(), getValueType(), et);
getDataType(), getValueType(), et);
}
else //TEXT / MM / BINARYBLOCK / BINARYCELL
else //TEXT / MM / BINARYBLOCK / BINARYCELL
{
reblock = new ReBlock( input, getRowsInBlock(), getColsInBlock(),
getDataType(), getValueType(), _outputEmptyBlocks, et);
getDataType(), getValueType(), _outputEmptyBlocks, et);
}
}
catch( LopsException ex ) {
Expand Down Expand Up @@ -373,10 +374,10 @@ else if( !dimsKnown(true) ) {
}

//construct checkpoint w/ right storage level
Lop input = getLops();
Lop input = getLops();
Lop chkpoint = new Checkpoint(input, getDataType(), getValueType(),
serializedStorage ? Checkpoint.getSerializeStorageLevelString() :
Checkpoint.getDefaultStorageLevelString() );
Checkpoint.getDefaultStorageLevelString() );

setOutputDimensions( chkpoint );
setLineNumbers( chkpoint );
Expand Down Expand Up @@ -415,7 +416,7 @@ && getDataType()!=DataType.SCALAR )
{
try
{
Lop compress = new Compression(getLops(), getDataType(), getValueType(), et);
Lop compress = new Compression(getLops(), getDataType(), getValueType(), et);
setOutputDimensions( compress );
setLineNumbers( compress );
setLops( compress );
Expand Down Expand Up @@ -1067,6 +1068,7 @@ public enum OpOp2 {
MINUS_NZ, //sparse-safe minus: X-(mean*ppred(X,0,!=))
LOG_NZ, //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
MINUS1_MULT, //1-X*Y
BW_AND, BW_OR, BW_XOR, BW_SHIFTL, BW_SHIFTR, //bitwise operations
}

// Operations that require 3 operands
Expand Down Expand Up @@ -1208,6 +1210,11 @@ public enum Direction {
HopsOpOp2LopsB.put(OpOp2.SOLVE, Binary.OperationTypes.SOLVE);
HopsOpOp2LopsB.put(OpOp2.POW, Binary.OperationTypes.POW);
HopsOpOp2LopsB.put(OpOp2.LOG, Binary.OperationTypes.NOTSUPPORTED);
HopsOpOp2LopsB.put(OpOp2.BW_AND, Binary.OperationTypes.BW_AND);
HopsOpOp2LopsB.put(OpOp2.BW_OR, Binary.OperationTypes.BW_OR);
HopsOpOp2LopsB.put(OpOp2.BW_XOR, Binary.OperationTypes.BW_XOR);
HopsOpOp2LopsB.put(OpOp2.BW_SHIFTL, Binary.OperationTypes.BW_SHIFTL);
HopsOpOp2LopsB.put(OpOp2.BW_SHIFTR, Binary.OperationTypes.BW_SHIFTR);
}

protected static final HashMap<Hop.OpOp2, BinaryScalar.OperationTypes> HopsOpOp2LopsBS;
Expand All @@ -1233,6 +1240,11 @@ public enum Direction {
HopsOpOp2LopsBS.put(OpOp2.LOG, BinaryScalar.OperationTypes.LOG);
HopsOpOp2LopsBS.put(OpOp2.POW, BinaryScalar.OperationTypes.POW);
HopsOpOp2LopsBS.put(OpOp2.PRINT, BinaryScalar.OperationTypes.PRINT);
HopsOpOp2LopsBS.put(OpOp2.BW_AND, BinaryScalar.OperationTypes.BW_AND);
HopsOpOp2LopsBS.put(OpOp2.BW_OR, BinaryScalar.OperationTypes.BW_OR);
HopsOpOp2LopsBS.put(OpOp2.BW_XOR, BinaryScalar.OperationTypes.BW_XOR);
HopsOpOp2LopsBS.put(OpOp2.BW_SHIFTL, BinaryScalar.OperationTypes.BW_SHIFTL);
HopsOpOp2LopsBS.put(OpOp2.BW_SHIFTR, BinaryScalar.OperationTypes.BW_SHIFTR);
}

protected static final HashMap<Hop.OpOp2, org.apache.sysml.lops.Unary.OperationTypes> HopsOpOp2LopsU;
Expand All @@ -1251,14 +1263,20 @@ public enum Direction {
HopsOpOp2LopsU.put(OpOp2.GREATER, org.apache.sysml.lops.Unary.OperationTypes.GREATER_THAN);
HopsOpOp2LopsU.put(OpOp2.EQUAL, org.apache.sysml.lops.Unary.OperationTypes.EQUALS);
HopsOpOp2LopsU.put(OpOp2.NOTEQUAL, org.apache.sysml.lops.Unary.OperationTypes.NOT_EQUALS);
HopsOpOp2LopsU.put(OpOp2.AND, org.apache.sysml.lops.Unary.OperationTypes.NOTSUPPORTED);
HopsOpOp2LopsU.put(OpOp2.OR, org.apache.sysml.lops.Unary.OperationTypes.NOTSUPPORTED);
HopsOpOp2LopsU.put(OpOp2.AND, org.apache.sysml.lops.Unary.OperationTypes.AND);
HopsOpOp2LopsU.put(OpOp2.OR, org.apache.sysml.lops.Unary.OperationTypes.OR);
HopsOpOp2LopsU.put(OpOp2.XOR, org.apache.sysml.lops.Unary.OperationTypes.XOR);
HopsOpOp2LopsU.put(OpOp2.MAX, org.apache.sysml.lops.Unary.OperationTypes.MAX);
HopsOpOp2LopsU.put(OpOp2.MIN, org.apache.sysml.lops.Unary.OperationTypes.MIN);
HopsOpOp2LopsU.put(OpOp2.LOG, org.apache.sysml.lops.Unary.OperationTypes.LOG);
HopsOpOp2LopsU.put(OpOp2.POW, org.apache.sysml.lops.Unary.OperationTypes.POW);
HopsOpOp2LopsU.put(OpOp2.MINUS_NZ, org.apache.sysml.lops.Unary.OperationTypes.SUBTRACT_NZ);
HopsOpOp2LopsU.put(OpOp2.LOG_NZ, org.apache.sysml.lops.Unary.OperationTypes.LOG_NZ);
HopsOpOp2LopsU.put(OpOp2.BW_AND, Unary.OperationTypes.BW_AND);
HopsOpOp2LopsU.put(OpOp2.BW_OR, Unary.OperationTypes.BW_OR);
HopsOpOp2LopsU.put(OpOp2.BW_XOR, Unary.OperationTypes.BW_XOR);
HopsOpOp2LopsU.put(OpOp2.BW_SHIFTL, Unary.OperationTypes.BW_SHIFTL);
HopsOpOp2LopsU.put(OpOp2.BW_SHIFTR, Unary.OperationTypes.BW_SHIFTR);
}

protected static final HashMap<Hop.OpOp1, org.apache.sysml.lops.Unary.OperationTypes> HopsOpOp1LopsU;
Expand Down Expand Up @@ -1429,6 +1447,11 @@ public enum Direction {
HopsOpOp2String.put(OpOp2.RBIND, "rbind");
HopsOpOp2String.put(OpOp2.SOLVE, "solve");
HopsOpOp2String.put(OpOp2.XOR, "xor");
HopsOpOp2String.put(OpOp2.BW_AND, "bitwAnd");
HopsOpOp2String.put(OpOp2.BW_OR, "bitwOr");
HopsOpOp2String.put(OpOp2.BW_XOR, "bitwXor");
HopsOpOp2String.put(OpOp2.BW_SHIFTL, "bitwShiftL");
HopsOpOp2String.put(OpOp2.BW_SHIFTR, "bitwShiftR");
}

public static String getBinaryOpCode(OpOp2 op) {
Expand Down Expand Up @@ -1519,8 +1542,13 @@ public static OpOp2 getOpOp2ForOuterVectorOperation(String op)
else if( "&".equals(op) ) return OpOp2.AND;
else if( "log".equals(op) ) return OpOp2.LOG;
else if( "^".equals(op) ) return OpOp2.POW;
else if("bitwAnd".equals(op) ) return OpOp2.BW_AND;
else if("bitwOr".equals(op) ) return OpOp2.BW_OR;
else if("bitwXor".equals(op) ) return OpOp2.BW_XOR;
else if("bitwShiftL".equals(op) ) return OpOp2.BW_SHIFTL;
else if("bitwShiftR".equals(op) ) return OpOp2.BW_SHIFTR;

return null;
return null;
}

/////////////////////////////////////
Expand Down
14 changes: 12 additions & 2 deletions src/main/java/org/apache/sysml/lops/Binary.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ public enum OperationTypes {
ADD, SUBTRACT, MULTIPLY, DIVIDE, MINUS1_MULTIPLY, MODULUS, INTDIV, MATMULT,
LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS,
AND, OR, XOR,
MAX, MIN, POW, SOLVE, NOTSUPPORTED
MAX, MIN, POW, SOLVE, NOTSUPPORTED,
BW_AND, BW_OR, BW_XOR, BW_SHIFTL, BW_SHIFTR, //Bitwise operations
}

private OperationTypes operation;
Expand Down Expand Up @@ -160,7 +161,16 @@ public static String getOpcode( OperationTypes op ) {
/* Binary Builtin Function */
case XOR:
return "xor";

case BW_AND:
return "bitwAnd";
case BW_OR:
return "bitwOr";
case BW_XOR:
return "bitwXor";
case BW_SHIFTL:
return "bitwShiftL";
case BW_SHIFTR:
return "bitwShiftR";

/* Builtin Functions */
case MIN:
Expand Down
16 changes: 13 additions & 3 deletions src/main/java/org/apache/sysml/lops/BinaryScalar.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ public enum OperationTypes {
ADD, SUBTRACT, MULTIPLY, DIVIDE, MODULUS, INTDIV,
LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS,
AND, OR, XOR,
LOG,POW,MAX,MIN,PRINT,
IQSIZE,
LOG,POW,MAX,MIN,PRINT,IQSIZE,
BW_AND, BW_OR, BW_XOR, BW_SHIFTL, BW_SHIFTR, //Bitwise operations
}

private final OperationTypes operation;
Expand Down Expand Up @@ -149,7 +149,17 @@ public static String getOpcode( OperationTypes op )
/* Boolean built in binary function */
case XOR:
return "xor";

case BW_AND:
return "bitwAnd";
case BW_OR:
return "bitwOr";
case BW_XOR:
return "bitwXor";
case BW_SHIFTL:
return "bitwShiftL";
case BW_SHIFTR:
return "bitwShiftR";

/* Builtin Functions */
case LOG:
return "log";
Expand Down
20 changes: 15 additions & 5 deletions src/main/java/org/apache/sysml/lops/Unary.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ public class Unary extends Lop
public enum OperationTypes {
ADD, SUBTRACT, SUBTRACTRIGHT, MULTIPLY, MULTIPLY2, DIVIDE, MODULUS, INTDIV, MINUS1_MULTIPLY,
POW, POW2, LOG, MAX, MIN, NOT, ABS, SIN, COS, TAN, ASIN, ACOS, ATAN, SINH, COSH, TANH, SIGN, SQRT, EXP, Over,
LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS,
LESS_THAN, LESS_THAN_OR_EQUALS, GREATER_THAN, GREATER_THAN_OR_EQUALS, EQUALS, NOT_EQUALS,
AND, OR, XOR, BW_AND, BW_OR, BW_XOR, BW_SHIFTL, BW_SHIFTR,
ROUND, CEIL, FLOOR, MR_IQM, INVERSE, CHOLESKY,
CUMSUM, CUMPROD, CUMMIN, CUMMAX,
SPROP, SIGMOID, SELP, SUBTRACT_NZ, LOG_NZ,
Expand Down Expand Up @@ -224,7 +225,7 @@ public static String getOpcode(OperationTypes op)

case SUBTRACT_NZ:
return "-nz";

case SUBTRACTRIGHT:
return "s-r";

Expand All @@ -244,7 +245,7 @@ public static String getOpcode(OperationTypes op)
return "%%";

case INTDIV:
return "%/%";
return "%/%";

case Over:
return "so";
Expand All @@ -253,7 +254,7 @@ public static String getOpcode(OperationTypes op)
return "^";

case POW2:
return "^2";
return "^2";

case GREATER_THAN:
return ">";
Expand Down Expand Up @@ -320,7 +321,16 @@ public static String getOpcode(OperationTypes op)

case CAST_AS_FRAME:
return UnaryCP.CAST_AS_FRAME_OPCODE;


case AND: return "&&";
case OR: return "||";
case XOR: return "xor";
case BW_AND: return "bitwAnd";
case BW_OR: return "bitwOr";
case BW_XOR: return "bitwXor";
case BW_SHIFTL: return "bitwShiftL";
case BW_SHIFTR: return "bitwShiftR";

default:
throw new LopsException(
"Instruction not defined for Unary operation: " + op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,11 @@ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<Stri
break;

case XOR:
case BITWISE_AND:
case BITWISE_OR:
case BITWISE_XOR:
case BITWISE_SHIFTL:
case BITWISE_SHIFTR:
case MIN:
case MAX:
//min(X), min(X,s), min(s,X), min(s,r), min(X,Y)
Expand Down Expand Up @@ -1344,6 +1349,11 @@ private boolean isMathFunction() {
case FLOOR:
case MEDIAN:
case XOR:
case BITWISE_AND:
case BITWISE_OR:
case BITWISE_XOR:
case BITWISE_SHIFTL:
case BITWISE_SHIFTR:
return true;
default:
return false;
Expand Down Expand Up @@ -1737,6 +1747,16 @@ else if ( functionName.equals("outer") )
bifop = Expression.BuiltinFunctionOp.OUTER;
else if ( functionName.equals("xor") )
bifop = Expression.BuiltinFunctionOp.XOR;
else if ( functionName.equals("bitwAnd") )
bifop = Expression.BuiltinFunctionOp.BITWISE_AND;
else if ( functionName.equals("bitwOr") )
bifop = Expression.BuiltinFunctionOp.BITWISE_OR;
else if ( functionName.equals("bitwXor") )
bifop = Expression.BuiltinFunctionOp.BITWISE_XOR;
else if ( functionName.equals("bitwShiftL") )
bifop = Expression.BuiltinFunctionOp.BITWISE_SHIFTL;
else if ( functionName.equals("bitwShiftR") )
bifop = Expression.BuiltinFunctionOp.BITWISE_SHIFTR;
else
return null;

Expand Down
21 changes: 21 additions & 0 deletions src/main/java/org/apache/sysml/parser/DMLTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2662,10 +2662,31 @@ else if ( sop.equalsIgnoreCase("!=") )
currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.BOOLEAN, Hop.OpOp1.CAST_AS_BOOLEAN, expr);
break;

// Boolean binary
case XOR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
target.getValueType(), Hop.OpOp2.XOR, expr, expr2);
break;
case BITWISE_AND:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
target.getValueType(), OpOp2.BW_AND, expr, expr2);
break;
case BITWISE_OR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
target.getValueType(), OpOp2.BW_OR, expr, expr2);
break;
case BITWISE_XOR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
target.getValueType(), OpOp2.BW_XOR, expr, expr2);
break;
case BITWISE_SHIFTL:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
target.getValueType(), OpOp2.BW_SHIFTL, expr, expr2);
break;
case BITWISE_SHIFTR:
currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(),
target.getValueType(), OpOp2.BW_SHIFTR, expr, expr2);
break;

case ABS:
case SIN:
Expand Down
7 changes: 6 additions & 1 deletion src/main/java/org/apache/sysml/parser/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,12 @@ public enum BuiltinFunctionOp {
TRACE,
TRANS,
VAR,
XOR
XOR,
BITWISE_AND,
BITWISE_OR,
BITWISE_XOR,
BITWISE_SHIFTL,
BITWISE_SHIFTR,
}

/**
Expand Down
Loading

0 comments on commit 7dbbaaa

Please sign in to comment.