Skip to content

Commit

Permalink
[SYSTEMML-2419] Paramserv spark function shipping and worker setup
Browse files Browse the repository at this point in the history
Closes apache#799.
  • Loading branch information
EdgarLGB authored and mboehm7 committed Jul 16, 2018
1 parent 614adec commit cffefca
Show file tree
Hide file tree
Showing 40 changed files with 808 additions and 565 deletions.
4 changes: 2 additions & 2 deletions src/main/java/org/apache/sysml/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.net.URI;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
Expand Down Expand Up @@ -76,7 +77,6 @@
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDHandler;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
Expand Down Expand Up @@ -853,7 +853,7 @@ private static void checkSecuritySetup(DMLConfig config)

LOG.debug("SystemML security check: "
+ "local.user.name = " + userName + ", "
+ "local.user.groups = " + ProgramConverter.serializeStringCollection(groupNames) + ", "
+ "local.user.groups = " + Arrays.toString(groupNames.toArray()) + ", "
+ MRConfigurationNames.MR_JOBTRACKER_ADDRESS + " = " + job.get(MRConfigurationNames.MR_JOBTRACKER_ADDRESS) + ", "
+ MRConfigurationNames.MR_TASKTRACKER_TASKCONTROLLER + " = " + taskController + ","
+ MRConfigurationNames.MR_TASKTRACKER_GROUP + " = " + ttGroupName + ", "
Expand Down
5 changes: 2 additions & 3 deletions src/main/java/org/apache/sysml/hops/DataGenOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
Expand Down Expand Up @@ -108,8 +107,8 @@ public DataGenOp(DataGenMethod mthd, DataIdentifier id, HashMap<String, Hop> inp

//generate base dir
String scratch = ConfigurationManager.getScratchSpace();
_baseDir = scratch + Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + DMLScript.getUUID() + Lop.FILE_SEPARATOR +
Lop.FILE_SEPARATOR + ProgramConverter.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR;
_baseDir = scratch + Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + DMLScript.getUUID() + Lop.FILE_SEPARATOR
+ Lop.FILE_SEPARATOR + Lop.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR;

//compute unknown dims and nnz
refreshSizeInformation();
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/org/apache/sysml/hops/OptimizerUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.functionobjects.IntegerDivide;
import org.apache.sysml.runtime.functionobjects.Modulus;
Expand Down Expand Up @@ -927,7 +926,7 @@ public static boolean exceedsCachingThreshold(long dim2, double outMem) {
public static String getUniqueTempFileName() {
return ConfigurationManager.getScratchSpace()
+ Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + DMLScript.getUUID()
+ Lop.FILE_SEPARATOR + ProgramConverter.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR
+ Lop.FILE_SEPARATOR + Lop.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR
+ Dag.getNextUniqueFilenameSuffix();
}

Expand Down
5 changes: 2 additions & 3 deletions src/main/java/org/apache/sysml/hops/recompile/Recompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.util.ProgramConverter;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptTreeConverter;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.instructions.Instruction;
Expand Down Expand Up @@ -589,8 +589,7 @@ public static void rUpdateFunctionNames( Hop hop, long pid )
//update function names
if( hop instanceof FunctionOp && ((FunctionOp)hop).getFunctionType() != FunctionType.MULTIRETURN_BUILTIN) {
FunctionOp fop = (FunctionOp) hop;
fop.setFunctionName( fop.getFunctionName() +
ProgramConverter.CP_CHILD_THREAD + pid);
fop.setFunctionName( fop.getFunctionName() + Lop.CP_CHILD_THREAD + pid);
}

if( hop.getInput() != null )
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/apache/sysml/lops/Lop.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ public enum VisitStatus {

public static final String FILE_SEPARATOR = "/";
public static final String PROCESS_PREFIX = "_p";
public static final String CP_ROOT_THREAD_ID = "_t0";
public static final String CP_CHILD_THREAD = "_t";

//special delimiters w/ extended ASCII characters to avoid collisions
public static final String INSTRUCTION_DELIMITOR = "\u2021";
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/org/apache/sysml/lops/compile/Dag.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.instructions.CPInstructionParser;
import org.apache.sysml.runtime.instructions.Instruction;
Expand Down Expand Up @@ -198,7 +197,7 @@ private String getFilePath() {
scratchFilePath = scratch + Lop.FILE_SEPARATOR
+ Lop.PROCESS_PREFIX + DMLScript.getUUID()
+ Lop.FILE_SEPARATOR + Lop.FILE_SEPARATOR
+ ProgramConverter.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR;
+ Lop.CP_ROOT_THREAD_ID + Lop.FILE_SEPARATOR;
}
return scratchFilePath;
}
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/org/apache/sysml/parser/DMLTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.instructions.Instruction;


Expand Down Expand Up @@ -613,7 +612,7 @@ else if (sb instanceof FunctionStatementBlock){
buff.append(Lop.PROCESS_PREFIX);
buff.append(DMLScript.getUUID());
buff.append(Lop.FILE_SEPARATOR);
buff.append(ProgramConverter.CP_ROOT_THREAD_ID);
buff.append(Lop.CP_ROOT_THREAD_ID);
buff.append(Lop.FILE_SEPARATOR);
buff.append("PackageSupport");
buff.append(Lop.FILE_SEPARATOR);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.util.ProgramConverter;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.utils.Statistics;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
import org.apache.sysml.runtime.controlprogram.parfor.LocalParWorker;
import org.apache.sysml.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysml.runtime.controlprogram.parfor.ParForBody;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.util.ProgramConverter;
import org.apache.sysml.runtime.controlprogram.parfor.RemoteDPParForMR;
import org.apache.sysml.runtime.controlprogram.parfor.RemoteDPParForSpark;
import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForJobReturn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ protected ExecutionContext( boolean allocateVariableMap, Program prog )
/**
 * Returns the program currently attached to this execution context.
 *
 * @return the program held by this context
 */
public Program getProgram() {
	return this._prog;
}

/**
 * Replaces the program attached to this execution context.
 *
 * @param prog the program to hold from now on
 */
public void setProgram(Program prog) {
	this._prog = prog;
}

public LocalVariableMap getVariables() {
return _variables;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,25 @@
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;

@SuppressWarnings("unused")
public abstract class PSWorker {

protected final int _workerID;
protected final int _epochs;
protected final long _batchSize;
protected final ExecutionContext _ec;
protected final ParamServer _ps;
protected final DataIdentifier _output;
protected final FunctionCallCPInstruction _inst;
protected int _workerID;
protected int _epochs;
protected long _batchSize;
protected ExecutionContext _ec;
protected ParamServer _ps;
protected DataIdentifier _output;
protected FunctionCallCPInstruction _inst;
protected MatrixObject _features;
protected MatrixObject _labels;

private MatrixObject _valFeatures;
private MatrixObject _valLabels;
private final String _updFunc;
protected final Statement.PSFrequency _freq;

protected MatrixObject _valFeatures;
protected MatrixObject _valLabels;
protected String _updFunc;
protected Statement.PSFrequency _freq;

protected PSWorker() {

}

protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkAggregator;
import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkMapper;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ListObject;
Expand All @@ -68,6 +67,7 @@
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.util.ProgramConverter;

import scala.Tuple2;

Expand Down Expand Up @@ -175,9 +175,8 @@ public static String[] getCompleteFuncName(String funcName, String prefix) {
new String[]{ns, name} : new String[]{ns, name};
}

public static List<ExecutionContext> createExecutionContexts(ExecutionContext ec, LocalVariableMap varsMap,
String updFunc, String aggFunc, int workerNum, int k) {

public static ExecutionContext createExecutionContext(ExecutionContext ec, LocalVariableMap varsMap, String updFunc,
String aggFunc, int k) {
FunctionProgramBlock updPB = getFunctionBlock(ec, updFunc);
FunctionProgramBlock aggPB = getFunctionBlock(ec, aggFunc);

Expand All @@ -188,27 +187,21 @@ public static List<ExecutionContext> createExecutionContexts(ExecutionContext ec
// 2. Recompile the imported function blocks
prog.getFunctionProgramBlocks().forEach((fname, fvalue) -> recompileProgramBlocks(k, fvalue.getChildBlocks()));

// 3. Copy function for workers
List<ExecutionContext> workerECs = IntStream.range(0, workerNum)
.mapToObj(i -> {
FunctionProgramBlock newUpdFunc = copyFunction(updFunc, updPB);
FunctionProgramBlock newAggFunc = copyFunction(aggFunc, aggPB);
Program newProg = new Program();
putFunction(newProg, newUpdFunc);
putFunction(newProg, newAggFunc);
return ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg);
})
.collect(Collectors.toList());

// 4. Copy function for agg service
// 3. Copy function
FunctionProgramBlock newUpdFunc = copyFunction(updFunc, updPB);
FunctionProgramBlock newAggFunc = copyFunction(aggFunc, aggPB);
Program newProg = new Program();
putFunction(newProg, newUpdFunc);
putFunction(newProg, newAggFunc);
ExecutionContext aggEC = ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg);
return ExecutionContextFactory.createContext(new LocalVariableMap(varsMap), newProg);
}

List<ExecutionContext> result = new ArrayList<>(workerECs);
result.add(aggEC);
return result;
/**
 * Creates {@code num} independent execution contexts derived from the given one.
 * Each result gets its own new {@link Program}, populated with a copy (via
 * {@code copyFunction}) of every function program block registered in the
 * program of {@code ec}, and its own {@link LocalVariableMap} constructed from
 * the variables of {@code ec}.
 *
 * @param ec  the execution context used as the template
 * @param num the number of contexts to create
 * @return a list with {@code num} newly created execution contexts
 */
public static List<ExecutionContext> copyExecutionContext(ExecutionContext ec, int num) {
	return IntStream.range(0, num)
		.mapToObj(idx -> {
			// every copy owns a separate program holding copied function blocks
			Program copiedProg = new Program();
			ec.getProgram().getFunctionProgramBlocks()
				.forEach((fname, fpb) -> putFunction(copiedProg, copyFunction(fname, fpb)));
			// the variable map is wrapped after the functions have been copied
			LocalVariableMap copiedVars = new LocalVariableMap(ec.getVariables());
			return ExecutionContextFactory.createContext(copiedVars, copiedProg);
		})
		.collect(Collectors.toList());
}

private static FunctionProgramBlock copyFunction(String funcName, FunctionProgramBlock fpb) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysml.runtime.controlprogram.paramserv.spark;

import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;

/**
 * Holder for the state needed to launch a remote Spark parameter-server
 * worker. At present it carries only the {@link ExecutionContext}.
 */
public class SparkPSBody {

	/** Execution context carried by this body; stays {@code null} until one is supplied. */
	private ExecutionContext _ec;

	/**
	 * Creates an empty body without an execution context.
	 */
	public SparkPSBody() {
		this(null);
	}

	/**
	 * Creates a body wrapping the given execution context.
	 *
	 * @param ec the execution context to carry
	 */
	public SparkPSBody(ExecutionContext ec) {
		_ec = ec;
	}

	/**
	 * @return the execution context carried by this body, possibly {@code null}
	 */
	public ExecutionContext getEc() {
		return _ec;
	}

	/**
	 * Replaces the carried execution context.
	 *
	 * @param ec the new execution context
	 */
	public void setEc(ExecutionContext ec) {
		_ec = ec;
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,59 @@

package org.apache.sysml.runtime.controlprogram.paramserv.spark;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.api.java.function.VoidFunction;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysml.runtime.codegen.CodegenUtils;
import org.apache.sysml.runtime.controlprogram.paramserv.PSWorker;
import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.ProgramConverter;

import scala.Tuple2;

public class SparkPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable {
public class SparkPSWorker extends PSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable {

private static final long serialVersionUID = -8674739573419648732L;

public SparkPSWorker() {
private String _program;
private HashMap<String, byte[]> _clsMap;

protected SparkPSWorker() {
// No-args constructor used for deserialization
}

public SparkPSWorker(String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) {
public SparkPSWorker(String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap) {
_updFunc = updFunc;
_freq = freq;
_epochs = epochs;
_batchSize = batchSize;
_program = program;
_clsMap = clsMap;
}

@Override
public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
configureWorker(input);
}

private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException {
_workerID = input._1;

// Initialize codegen class cache (before program parsing)
for (Map.Entry<String, byte[]> e : _clsMap.entrySet()) {
CodegenUtils.getClassSync(e.getKey(), e.getValue());
}

// Deserialize the body to initialize the execution context
SparkPSBody body = ProgramConverter.parseSparkPSBody(_program, _workerID);
_ec = body.getEc();

// Initialize the buffer pool and register it in the JVM shutdown hook so that it is cleaned up at the end
RemoteParForUtils.setupBufferPool(_workerID);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import org.apache.sysml.runtime.matrix.mapred.MRConfigurationNames;
import org.apache.sysml.runtime.matrix.mapred.MRJobConfiguration;
import org.apache.sysml.runtime.util.MapReduceTool;
import org.apache.sysml.runtime.util.ProgramConverter;
import org.apache.sysml.utils.Statistics;
import org.apache.sysml.yarn.DMLAppMasterUtils;

Expand Down
Loading

0 comments on commit cffefca

Please sign in to comment.