Skip to content

Commit

Permalink
[SYSTEMML-1876] Fix size propagation QuaternaryOp wdivmm (stratstats)
Browse files Browse the repository at this point in the history
This patch fixes the worst-case size propagation (i.e., the propagation
of upper bounds instead of exact size information) for quaternary
operators of type wdivmm basic (cellwise). There was a branching issue
that led to output sizes for wdivmm basic being overwritten by output
sizes for wdivmm right. The issue did not show up before because it only
happens in the special case where the sizes (at least the dimensions) of
inputs cannot be inferred exactly.

Furthermore, this patch also includes some minor cleanups related to the
common code of MR lop construction.
  • Loading branch information
mboehm7 committed Sep 1, 2017
1 parent 55b7342 commit ec35215
Showing 1 changed file with 33 additions and 37 deletions.
70 changes: 33 additions & 37 deletions src/main/java/org/apache/sysml/hops/QuaternaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
*/
public class QuaternaryOp extends Hop implements MultiThreadedHop
{

//config influencing mr operator selection (for testing purposes only)
public static boolean FORCE_REPLICATION = false;

Expand Down Expand Up @@ -321,7 +320,7 @@ private void constructCPLopsWeightedSquaredLoss(WeightsType wtype)
setLops( wsloss );
}

private Lop obtainlU(Hop U, Hop V, boolean cacheU, double m1Size) throws HopsException, LopsException {
private Lop constructLeftFactorMRLop(Hop U, Hop V, boolean cacheU, double m1Size) throws HopsException, LopsException {
Lop lU = null;
if (cacheU) {
// partitioning of U for read through distributed cache
Expand All @@ -331,28 +330,26 @@ private Lop obtainlU(Hop U, Hop V, boolean cacheU, double m1Size) throws HopsExc
lU = new DataPartition(lU, DataType.MATRIX, ValueType.DOUBLE,
(m1Size > OptimizerUtils.getLocalMemBudget()) ? ExecType.MR : ExecType.CP,
PDataPartitionFormat.ROW_BLOCK_WISE_N);
lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), getRowsInBlock(), getColsInBlock(),
U.getNnz());
lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), getRowsInBlock(), getColsInBlock(), U.getNnz());
setLineNumbers(lU);
}
} else {
}
else {
// replication of U for shuffle to target block
Lop offset = createOffsetLop(V, false); // ncol of t(V) -> nrow of V determines num replicates
lU = new RepMat(U.constructLops(), offset, true, V.getDataType(), V.getValueType());
lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), U.getRowsInBlock(), U.getColsInBlock(),
U.getNnz());
lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), U.getRowsInBlock(), U.getColsInBlock(), U.getNnz());
setLineNumbers(lU);

Group grpU = new Group(lU, Group.OperationTypes.Sort, DataType.MATRIX, ValueType.DOUBLE);
grpU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), U.getRowsInBlock(), U.getColsInBlock(),
-1);
grpU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), U.getRowsInBlock(), U.getColsInBlock(), -1);
setLineNumbers(grpU);
lU = grpU;
}
return lU;
}

private Lop obtainlV(Hop U, Hop V, boolean cacheV, double m2Size) throws HopsException, LopsException {
private Lop constructRightFactorMRLop(Hop U, Hop V, boolean cacheV, double m2Size) throws HopsException, LopsException {
Lop lV = null;
if (cacheV) {
// partitioning of V for read through distributed cache
Expand All @@ -362,27 +359,24 @@ private Lop obtainlV(Hop U, Hop V, boolean cacheV, double m2Size) throws HopsExc
lV = new DataPartition(lV, DataType.MATRIX, ValueType.DOUBLE,
(m2Size > OptimizerUtils.getLocalMemBudget()) ? ExecType.MR : ExecType.CP,
PDataPartitionFormat.ROW_BLOCK_WISE_N);
lV.getOutputParameters().setDimensions(V.getDim1(), V.getDim2(), getRowsInBlock(), getColsInBlock(),
V.getNnz());
lV.getOutputParameters().setDimensions(V.getDim1(), V.getDim2(), getRowsInBlock(), getColsInBlock(), V.getNnz());
setLineNumbers(lV);
}
} else {
}
else {
// replication of t(V) for shuffle to target block
Transform ltV = new Transform(V.constructLops(), HopsTransf2Lops.get(ReOrgOp.TRANSPOSE), getDataType(),
getValueType(), ExecType.MR);
ltV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(),
V.getNnz());
ltV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(), V.getNnz());
setLineNumbers(ltV);

Lop offset = createOffsetLop(U, false); // nrow of U determines num replicates
lV = new RepMat(ltV, offset, false, V.getDataType(), V.getValueType());
lV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(),
V.getNnz());
lV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(), V.getNnz());
setLineNumbers(lV);

Group grpV = new Group(lV, Group.OperationTypes.Sort, DataType.MATRIX, ValueType.DOUBLE);
grpV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(),
-1);
grpV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(), -1);
setLineNumbers(grpV);
lV = grpV;
}
Expand Down Expand Up @@ -464,8 +458,8 @@ private void constructMRLopsWeightedSquaredLoss(WeightsType wtype)
setLineNumbers(grpW);
}

Lop lU = obtainlU(U, V, cacheU, m1Size);
Lop lV = obtainlV(U, V, cacheV, m2Size);
Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size);
Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size);

//reduce-side wsloss w/ or without broadcast
Lop wsloss = new WeightedSquaredLossR(
Expand Down Expand Up @@ -613,8 +607,8 @@ private void constructMRLopsWeightedSigmoid( WSigmoidType wtype )
grpX.getOutputParameters().setDimensions(X.getDim1(), X.getDim2(), X.getRowsInBlock(), X.getColsInBlock(), X.getNnz());
setLineNumbers(grpX);

Lop lU = obtainlU(U, V, cacheU, m1Size);
Lop lV = obtainlV(U, V, cacheV, m2Size);
Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size);
Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size);

//reduce-side wsig w/ or without broadcast
Lop wsigmoid = new WeightedSigmoidR(
Expand Down Expand Up @@ -757,8 +751,8 @@ private void constructMRLopsWeightedDivMM( WDivMMType wtype )
grpX.getOutputParameters().setDimensions(X.getDim1(), X.getDim2(), X.getRowsInBlock(), X.getColsInBlock(), X.getNnz());
setLineNumbers(grpX);

Lop lU = obtainlU(U, V, cacheU, m1Size);
Lop lV = obtainlV(U, V, cacheV, m2Size);
Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size);
Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size);

//reduce-side wdivmm w/ or without broadcast
Lop wdivmm = new WeightedDivMMR( grpW, lU, lV, grpX,
Expand Down Expand Up @@ -919,8 +913,8 @@ private void constructMRLopsWeightedCeMM(WCeMMType wtype)
grpX.getOutputParameters().setDimensions(X.getDim1(), X.getDim2(), X.getRowsInBlock(), X.getColsInBlock(), -1);
setLineNumbers(grpX);

Lop lU = obtainlU(U, V, cacheU, m1Size);
Lop lV = obtainlV(U, V, cacheV, m2Size);
Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size);
Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size);

//reduce-side wcemm w/ or without broadcast
Lop wcemm = new WeightedCrossEntropyR( grpX, lU, lV, eps.constructLops(),
Expand Down Expand Up @@ -1076,8 +1070,8 @@ private void constructMRLopsWeightedUMM( WUMMType wtype )
grpX.getOutputParameters().setDimensions(X.getDim1(), X.getDim2(), X.getRowsInBlock(), X.getColsInBlock(), X.getNnz());
setLineNumbers(grpX);

Lop lU = obtainlU(U, V, cacheU, m1Size);
Lop lV = obtainlV(U, V, cacheV, m2Size);
Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size);
Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size);

//reduce-side wumm w/ or without broadcast
Lop wumm = new WeightedUnaryMMR(
Expand Down Expand Up @@ -1254,7 +1248,7 @@ protected long[] inferOutputCharacteristics( MemoTable memo )
MatrixCharacteristics mcW = memo.getAllInputStats(getInput().get(0));
ret = new long[]{mcW.getRows(), mcW.getCols(), mcW.getNonZeros()};
}
if( _baseType == 1 || _baseType == 3 ) { //left (w/ transpose or w/ epsilon)
else if( _baseType == 1 || _baseType == 3 ) { //left (w/ transpose or w/ epsilon)
MatrixCharacteristics mcV = memo.getAllInputStats(getInput().get(2));
ret = new long[]{mcV.getRows(), mcV.getCols(), -1};
}
Expand Down Expand Up @@ -1329,24 +1323,26 @@ public void refreshSizeInformation()
Hop inW = getInput().get(0);
setDim1( inW.getDim1() );
setDim2( inW.getDim2() );
setNnz( inW.getNnz() );
setNnz( inW.getNnz() );
}
else if( _baseType == 1 || _baseType == 3 ){ //left (w/ transpose or w/ epsilon)
Hop inV = getInput().get(2);
setDim1( inV.getDim1() );
setDim2( inV.getDim2() );
setDim2( inV.getDim2() );
setNnz( -1 ); //reset
}
else { //right
Hop inU = getInput().get(1);
setDim1( inU.getDim1() );
setDim2( inU.getDim2() );
setDim2( inU.getDim2() );
setNnz( -1 ); //reset
}
break;
}

default:
break;
}
}
}

@Override
Expand Down

0 comments on commit ec35215

Please sign in to comment.