diff --git a/src/main/java/org/apache/sysml/hops/QuaternaryOp.java b/src/main/java/org/apache/sysml/hops/QuaternaryOp.java index bfbaae7ec00..6395d6a2fc2 100644 --- a/src/main/java/org/apache/sysml/hops/QuaternaryOp.java +++ b/src/main/java/org/apache/sysml/hops/QuaternaryOp.java @@ -59,7 +59,6 @@ */ public class QuaternaryOp extends Hop implements MultiThreadedHop { - //config influencing mr operator selection (for testing purposes only) public static boolean FORCE_REPLICATION = false; @@ -321,7 +320,7 @@ private void constructCPLopsWeightedSquaredLoss(WeightsType wtype) setLops( wsloss ); } - private Lop obtainlU(Hop U, Hop V, boolean cacheU, double m1Size) throws HopsException, LopsException { + private Lop constructLeftFactorMRLop(Hop U, Hop V, boolean cacheU, double m1Size) throws HopsException, LopsException { Lop lU = null; if (cacheU) { // partitioning of U for read through distributed cache @@ -331,28 +330,26 @@ private Lop obtainlU(Hop U, Hop V, boolean cacheU, double m1Size) throws HopsExc lU = new DataPartition(lU, DataType.MATRIX, ValueType.DOUBLE, (m1Size > OptimizerUtils.getLocalMemBudget()) ? ExecType.MR : ExecType.CP, PDataPartitionFormat.ROW_BLOCK_WISE_N); - lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), getRowsInBlock(), getColsInBlock(), - U.getNnz()); + lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), getRowsInBlock(), getColsInBlock(), U.getNnz()); setLineNumbers(lU); } - } else { + } + else { // replication of U for shuffle to target block Lop offset = createOffsetLop(V, false); // ncol of t(V) -> nrow of V determines num replicates lU = new RepMat(U.constructLops(), offset, true, V.getDataType(), V.getValueType()); - lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), U.getRowsInBlock(), U.getColsInBlock(), - U.getNnz()); + lU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), U.getRowsInBlock(), U.getColsInBlock(), U.getNnz()); setLineNumbers(lU); - + Group grpU = new Group(lU, Group.OperationTypes.Sort, DataType.MATRIX, ValueType.DOUBLE); - grpU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), U.getRowsInBlock(), U.getColsInBlock(), - -1); + grpU.getOutputParameters().setDimensions(U.getDim1(), U.getDim2(), U.getRowsInBlock(), U.getColsInBlock(), -1); setLineNumbers(grpU); lU = grpU; } return lU; } - private Lop obtainlV(Hop U, Hop V, boolean cacheV, double m2Size) throws HopsException, LopsException { + private Lop constructRightFactorMRLop(Hop U, Hop V, boolean cacheV, double m2Size) throws HopsException, LopsException { Lop lV = null; if (cacheV) { // partitioning of V for read through distributed cache @@ -362,27 +359,24 @@ private Lop obtainlV(Hop U, Hop V, boolean cacheV, double m2Size) throws HopsExc lV = new DataPartition(lV, DataType.MATRIX, ValueType.DOUBLE, (m2Size > OptimizerUtils.getLocalMemBudget()) ? ExecType.MR : ExecType.CP, PDataPartitionFormat.ROW_BLOCK_WISE_N); - lV.getOutputParameters().setDimensions(V.getDim1(), V.getDim2(), getRowsInBlock(), getColsInBlock(), - V.getNnz()); + lV.getOutputParameters().setDimensions(V.getDim1(), V.getDim2(), getRowsInBlock(), getColsInBlock(), V.getNnz()); setLineNumbers(lV); } - } else { + } + else { // replication of t(V) for shuffle to target block Transform ltV = new Transform(V.constructLops(), HopsTransf2Lops.get(ReOrgOp.TRANSPOSE), getDataType(), getValueType(), ExecType.MR); - ltV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(), - V.getNnz()); + ltV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(), V.getNnz()); setLineNumbers(ltV); - + Lop offset = createOffsetLop(U, false); // nrow of U determines num replicates lV = new RepMat(ltV, offset, false, V.getDataType(), V.getValueType()); - lV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(), - V.getNnz()); + lV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(), V.getNnz()); setLineNumbers(lV); - + Group grpV = new Group(lV, Group.OperationTypes.Sort, DataType.MATRIX, ValueType.DOUBLE); - grpV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(), - -1); + grpV.getOutputParameters().setDimensions(V.getDim2(), V.getDim1(), V.getColsInBlock(), V.getRowsInBlock(), -1); setLineNumbers(grpV); lV = grpV; } @@ -464,8 +458,8 @@ private void constructMRLopsWeightedSquaredLoss(WeightsType wtype) setLineNumbers(grpW); } - Lop lU = obtainlU(U, V, cacheU, m1Size); - Lop lV = obtainlV(U, V, cacheV, m2Size); + Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size); + Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size); //reduce-side wsloss w/ or without broadcast Lop wsloss = new WeightedSquaredLossR( @@ -613,8 +607,8 @@ private void constructMRLopsWeightedSigmoid( WSigmoidType wtype ) grpX.getOutputParameters().setDimensions(X.getDim1(), X.getDim2(), X.getRowsInBlock(), X.getColsInBlock(), X.getNnz()); setLineNumbers(grpX); - Lop lU = obtainlU(U, V, cacheU, m1Size); - Lop lV = obtainlV(U, V, cacheV, m2Size); + Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size); + Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size); //reduce-side wsig w/ or without broadcast Lop wsigmoid = new WeightedSigmoidR( @@ -757,8 +751,8 @@ private void constructMRLopsWeightedDivMM( WDivMMType wtype ) grpX.getOutputParameters().setDimensions(X.getDim1(), X.getDim2(), X.getRowsInBlock(), X.getColsInBlock(), X.getNnz()); setLineNumbers(grpX); - Lop lU = obtainlU(U, V, cacheU, m1Size); - Lop lV = obtainlV(U, V, cacheV, m2Size); + Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size); + Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size); //reduce-side wdivmm w/ or without broadcast Lop wdivmm = new WeightedDivMMR( grpW, lU, lV, grpX, @@ -919,8 +913,8 @@ private void constructMRLopsWeightedCeMM(WCeMMType wtype) grpX.getOutputParameters().setDimensions(X.getDim1(), X.getDim2(), X.getRowsInBlock(), X.getColsInBlock(), -1); setLineNumbers(grpX); - Lop lU = obtainlU(U, V, cacheU, m1Size); - Lop lV = obtainlV(U, V, cacheV, m2Size); + Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size); + Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size); //reduce-side wcemm w/ or without broadcast Lop wcemm = new WeightedCrossEntropyR( grpX, lU, lV, eps.constructLops(), @@ -1076,8 +1070,8 @@ private void constructMRLopsWeightedUMM( WUMMType wtype ) grpX.getOutputParameters().setDimensions(X.getDim1(), X.getDim2(), X.getRowsInBlock(), X.getColsInBlock(), X.getNnz()); setLineNumbers(grpX); - Lop lU = obtainlU(U, V, cacheU, m1Size); - Lop lV = obtainlV(U, V, cacheV, m2Size); + Lop lU = constructLeftFactorMRLop(U, V, cacheU, m1Size); + Lop lV = constructRightFactorMRLop(U, V, cacheV, m2Size); //reduce-side wumm w/ or without broadcast Lop wumm = new WeightedUnaryMMR( @@ -1254,7 +1248,7 @@ protected long[] inferOutputCharacteristics( MemoTable memo ) MatrixCharacteristics mcW = memo.getAllInputStats(getInput().get(0)); ret = new long[]{mcW.getRows(), mcW.getCols(), mcW.getNonZeros()}; } - if( _baseType == 1 || _baseType == 3 ) { //left (w/ transpose or w/ epsilon) + else if( _baseType == 1 || _baseType == 3 ) { //left (w/ transpose or w/ epsilon) MatrixCharacteristics mcV = memo.getAllInputStats(getInput().get(2)); ret = new long[]{mcV.getRows(), mcV.getCols(), -1}; } @@ -1329,24 +1323,26 @@ public void refreshSizeInformation() Hop inW = getInput().get(0); setDim1( inW.getDim1() ); setDim2( inW.getDim2() ); - setNnz( inW.getNnz() ); + setNnz( inW.getNnz() ); } else if( _baseType == 1 || _baseType == 3 ){ //left (w/ transpose or w/ epsilon) Hop inV = getInput().get(2); setDim1( inV.getDim1() ); - setDim2( inV.getDim2() ); + setDim2( inV.getDim2() ); + setNnz( -1 ); //reset } else { //right Hop inU = getInput().get(1); setDim1( inU.getDim1() ); - setDim2( inU.getDim2() ); + setDim2( inU.getDim2() ); + setNnz( -1 ); //reset } break; } default: break; - } + } } @Override