[SYSTEMML-1879] Performance parfor remote spark (reuse shared inputs)
In parfor remote spark jobs, each worker is initialized with its own
deserialized symbol table, which causes redundant reads of shared inputs
in each parfor worker and is unnecessarily memory-inefficient. This
patch introduces a principled approach to reusing shared inputs, where
we reuse all variables except for result variables and partitioned
matrices. By simply using common instances of matrix objects, the
sharing happens automatically through the bufferpool similar to local
parfor execution and without additional pinned memory requirements. On
the perftest scenario MSVM 1M x 1K, sparse with 150 classes and 25
iterations, the end-to-end runtime (including read and spark context
creation) improved from 94s to 72s.
mboehm7 committed Sep 2, 2017
1 parent 912c655 commit 2c57cf7
Showing 9 changed files with 356 additions and 337 deletions.
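The core mechanism, as a minimal standalone sketch (the names here, e.g. SharedInputCache and reuse, are illustrative only; the actual implementation is the new CachedReuseVariables class further down): the first parfor worker of a job in a JVM publishes its deserialized non-result, non-partitioned variables into a job-scoped soft-reference cache, and all later workers of the same job in that JVM are re-pointed to the common instances.

import java.lang.ref.SoftReference;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

class SharedInputCache {
    //one entry per parfor job id; SoftReference lets the GC evict the
    //shared map under memory pressure (no additional pinned memory)
    private final Map<Long, SoftReference<Map<String, Object>>> cache = new HashMap<>();

    synchronized void reuse(long jobId, Map<String, Object> vars, Set<String> exclude) {
        SoftReference<Map<String, Object>> ref = cache.get(jobId);
        Map<String, Object> shared = (ref != null) ? ref.get() : null;
        if (shared == null) { //first worker, or entry was evicted: publish
            shared = new HashMap<>(vars);
            shared.keySet().removeAll(exclude); //result vars, partitioned inputs
            cache.put(jobId, new SoftReference<>(shared));
        }
        else { //later workers: adopt the common instances
            vars.putAll(shared);
        }
    }
}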
14 changes: 5 additions & 9 deletions src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -617,12 +617,11 @@ else if (sb instanceof ForStatementBlock)
 ForProgramBlock rtpb = null;
 IterablePredicate iterPred = fsb.getIterPredicate();
 
-if( sb instanceof ParForStatementBlock )
-{
+if( sb instanceof ParForStatementBlock ) {
 	sbName = "ParForStatementBlock";
-	rtpb = new ParForProgramBlock(prog, iterPred.getIterVar().getName(), iterPred.getParForParams());
+	rtpb = new ParForProgramBlock(prog, iterPred.getIterVar().getName(),
+		iterPred.getParForParams(), ((ParForStatementBlock)sb).getResultVariables());
 	ParForProgramBlock pfrtpb = (ParForProgramBlock)rtpb;
-	pfrtpb.setResultVariables( ((ParForStatementBlock)sb).getResultVariables() );
 	pfrtpb.setStatementBlock((ParForStatementBlock)sb); //used for optimization and creating unscoped variables
 }
 else {//ForStatementBlock
@@ -636,8 +635,8 @@ else if (sb instanceof ForStatementBlock)

 // process the body of the for statement block
 if (fsb.getNumStatements() > 1){
-	LOG.error(fsb.printBlockErrorLocation() + " " + sbName + " should have 1 statement" );
-	throw new LopsException(fsb.printBlockErrorLocation() + " " + sbName + " should have 1 statement" );
+	LOG.error(fsb.printBlockErrorLocation() + " " + sbName + " should have 1 statement" );
+	throw new LopsException(fsb.printBlockErrorLocation() + " " + sbName + " should have 1 statement" );
 }
 ForStatement fs = (ForStatement)fsb.getStatement(0);
 for (StatementBlock sblock : fs.getBody()){
@@ -653,9 +652,6 @@ else if (sb instanceof ForStatementBlock)

 retPB = rtpb;
 
-//post processing for generating missing instructions
-//retPB = verifyAndCorrectProgramBlock(sb.liveIn(), sb.liveOut(), sb._kill, retPB);
-
 // add statement block
 retPB.setStatementBlock(sb);
 
src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java

@@ -91,6 +91,11 @@ public void removeAll() {
 	localMap.clear();
 }
 
+public void removeAllIn(Set<String> blacklist) {
+	localMap.entrySet().removeIf(
+		e -> blacklist.contains(e.getKey()));
+}
+
 public void removeAllNotIn(Set<String> blacklist) {
 	localMap.entrySet().removeIf(
 		e -> !blacklist.contains(e.getKey()));
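As a quick reference for the new removeAllIn next to the existing removeAllNotIn, a self-contained sketch of the two filter semantics on plain java.util maps (hypothetical demo class, not part of this commit):

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class RemoveIfDemo {
    public static void main(String[] args) {
        Set<String> blacklist = new HashSet<>(Arrays.asList("b"));

        Map<String, Integer> m1 = new HashMap<>();
        m1.put("a", 1); m1.put("b", 2); m1.put("c", 3);
        //removeAllIn semantics: drop exactly the blacklisted keys
        m1.entrySet().removeIf(e -> blacklist.contains(e.getKey()));
        System.out.println(m1); //a and c remain

        Map<String, Integer> m2 = new HashMap<>();
        m2.put("a", 1); m2.put("b", 2); m2.put("c", 3);
        //removeAllNotIn semantics: keep only the blacklisted keys
        m2.entrySet().removeIf(e -> !blacklist.contains(e.getKey()));
        System.out.println(m2); //only b remains
    }
}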

Large diffs are not rendered by default.

src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java

@@ -544,23 +544,18 @@ public HashMap<String,Boolean> pinVariables(ArrayList<String> varList)
 for( String var : varList )
 {
 	Data dat = _variables.get(var);
-	if( dat instanceof MatrixObject )
-	{
+	if( dat instanceof MatrixObject ) {
 		MatrixObject mo = (MatrixObject)dat;
 		varsState.put( var, mo.isCleanupEnabled() );
-		//System.out.println("pre-pin "+var+" ("+mo.isCleanupEnabled()+")");
 	}
 }
 
 //step 2) pin variables
-for( String var : varList )
-{
+for( String var : varList ) {
 	Data dat = _variables.get(var);
-	if( dat instanceof MatrixObject )
-	{
+	if( dat instanceof MatrixObject ) {
 		MatrixObject mo = (MatrixObject)dat;
 		mo.enableCleanup(false);
-		//System.out.println("pin "+var);
 	}
 }
 
@@ -583,11 +578,8 @@ public HashMap<String,Boolean> pinVariables(ArrayList<String> varList)
 * @param varList variable list
 * @param varsState variable state
 */
-public void unpinVariables(ArrayList<String> varList, HashMap<String,Boolean> varsState)
-{
-	for( String var : varList)
-	{
-		//System.out.println("unpin "+var+" ("+varsState.get(var)+")");
+public void unpinVariables(ArrayList<String> varList, HashMap<String,Boolean> varsState) {
+	for( String var : varList) {
 		Data dat = _variables.get(var);
 		if( dat instanceof MatrixObject )
 			((MatrixObject)dat).enableCleanup(varsState.get(var));
@@ -597,15 +589,28 @@ public void unpinVariables(ArrayList<String> varList, HashMap<String,Boolean> varsState)
 /**
  * NOTE: No order guaranteed, so keep same list for pin and unpin.
  *
- * @return variable list as strings
+ * @return list of all variable names.
  */
-public ArrayList<String> getVarList()
-{
-	ArrayList<String> varlist = new ArrayList<String>();
-	varlist.addAll(_variables.keySet());
-	return varlist;
+public ArrayList<String> getVarList() {
+	return new ArrayList<>(_variables.keySet());
 }
 
+/**
+ * NOTE: No order guaranteed, so keep same list for pin and unpin.
+ *
+ * @return list of all variable names of partitioned matrices.
+ */
+public ArrayList<String> getVarListPartitioned() {
+	ArrayList<String> ret = new ArrayList<>();
+	for( String var : _variables.keySet() ) {
+		Data dat = _variables.get(var);
+		if( dat instanceof MatrixObject
+			&& ((MatrixObject)dat).isPartitioned() )
+			ret.add(var);
+	}
+	return ret;
+}
+
 public void cleanupMatrixObject(MatrixObject mo)
 	throws DMLRuntimeException
 {
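The pin/unpin pair implements a save-and-restore protocol: pinVariables records each matrix's current cleanup flag and disables cleanup, and unpinVariables restores exactly the recorded flags, which is why both javadocs insist on passing the same list. A hedged usage sketch (the helper class and its Runnable-based shape are hypothetical, not SystemML API):

import java.util.ArrayList;
import java.util.HashMap;

import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;

public class PinScope {
    public static void withPinnedVariables(ExecutionContext ec, Runnable work)
        throws DMLRuntimeException
    {
        ArrayList<String> vars = ec.getVarList();
        HashMap<String, Boolean> state = ec.pinVariables(vars); //disable cleanup, remember flags
        try {
            work.run(); //operate on inputs without risk of buffer-pool cleanup
        }
        finally {
            ec.unpinVariables(vars, state); //restore the remembered flags
        }
    }
}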
src/main/java/org/apache/sysml/runtime/controlprogram/parfor/CachedReuseVariables.java (new file)
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysml.runtime.controlprogram.parfor;

import java.lang.ref.SoftReference;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;

import org.apache.sysml.runtime.controlprogram.LocalVariableMap;

public class CachedReuseVariables
{
private final HashMap<Long, SoftReference<LocalVariableMap>> _data;

public CachedReuseVariables() {
_data = new HashMap<>();
}

public synchronized void reuseVariables(long pfid, LocalVariableMap vars, Collection<String> blacklist) {
//check for existing reuse map
LocalVariableMap tmp = null;
if( _data.containsKey(pfid) )
tmp = _data.get(pfid).get();

//build reuse map if not created yet or evicted
if( tmp == null ) {
tmp = new LocalVariableMap(vars);
tmp.removeAllIn(new HashSet<>(blacklist));
_data.put(pfid, new SoftReference<>(tmp));
}
//reuse existing reuse map
else {
for( String varName : tmp.keySet() )
vars.put(varName, tmp.get(varName));
}
}

public synchronized void clearVariables(long pfid) {
_data.remove(pfid);
}
}
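Design note: keying the cache by parfor job id and holding the shared map only via SoftReference is what keeps this scheme free of extra pinned memory — under memory pressure the GC may drop an entry, and the next worker simply rebuilds it from its own deserialized variables. A hedged fragment of the intended behavior (vars1 and vars2 stand for the deserialized LocalVariableMap symbol tables of two tasks of the same job, "X" for a shared input, "R" for a result variable):

CachedReuseVariables cache = new CachedReuseVariables();
Collection<String> blacklist = Arrays.asList("R"); //result vars must stay worker-local
cache.reuseVariables(42, vars1, blacklist); //first task publishes its shared inputs
cache.reuseVariables(42, vars2, blacklist); //second task adopts the common instances
assert vars1.get("X") == vars2.get("X"); //same object, hence read once via the bufferpool
assert vars1.get("R") != vars2.get("R"); //result variables remain distinct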
src/main/java/org/apache/sysml/runtime/controlprogram/parfor/ProgramConverter.java

@@ -328,16 +328,15 @@ public static ParForProgramBlock createDeepCopyParForProgramBlock(ParForProgramB
 ParForProgramBlock tmpPB = null;
 
 if( IDPrefix == -1 ) //still on master node
-	tmpPB = new ParForProgramBlock(prog,pfpb.getIterVar(), pfpb.getParForParams());
+	tmpPB = new ParForProgramBlock(prog,pfpb.getIterVar(), pfpb.getParForParams(), pfpb.getResultVariables());
 else //child of remote ParWorker at any level
-	tmpPB = new ParForProgramBlock(IDPrefix, prog, pfpb.getIterVar(), pfpb.getParForParams());
+	tmpPB = new ParForProgramBlock(IDPrefix, prog, pfpb.getIterVar(), pfpb.getParForParams(), pfpb.getResultVariables());
 
 tmpPB.setStatementBlock( createForStatementBlockCopy( (ForStatementBlock) pfpb.getStatementBlock(), pid, plain, forceDeepCopy) );
 tmpPB.setThreadID(pid);
 
 tmpPB.disableOptimization(); //already done in top-level parfor
 tmpPB.disableMonitorReport(); //already done in top-level parfor
-tmpPB.setResultVariables( pfpb.getResultVariables() );
 
 tmpPB.setFromInstructions( createDeepCopyInstructionSet(pfpb.getFromInstructions(), pid, IDPrefix, prog, fnStack, fnCreated, plain, true) );
 tmpPB.setToInstructions( createDeepCopyInstructionSet(pfpb.getToInstructions(), pid, IDPrefix, prog, fnStack, fnCreated, plain, true) );
@@ -1514,9 +1513,8 @@ private static ParForProgramBlock rParseParForProgramBlock( String in, Program p
 //program blocks //reset id to preinit state, replaced during exec
 ArrayList<ProgramBlock> pbs = rParseProgramBlocks(st.nextToken(), prog, 0);
 
-ParForProgramBlock pfpb = new ParForProgramBlock(id, prog, iterVar, params);
+ParForProgramBlock pfpb = new ParForProgramBlock(id, prog, iterVar, params, resultVars);
 pfpb.disableOptimization(); //already done in top-level parfor
-pfpb.setResultVariables(resultVars);
 pfpb.setFromInstructions(from);
 pfpb.setToInstructions(to);
 pfpb.setIncrementInstructions(incr);
src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSpark.java

@@ -34,6 +34,8 @@
 import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysml.utils.Statistics;
 
 /**
@@ -51,12 +53,14 @@
 */
 public class RemoteParForSpark
 {
-
 	protected static final Log LOG = LogFactory.getLog(RemoteParForSpark.class.getName());
 
-	public static RemoteParForJobReturn runJob(long pfid, String program, HashMap<String, byte[]> clsMap,
+	//globally unique id for parfor spark job instances (unique across spark contexts)
+	private static final IDSequence _jobID = new IDSequence();
+
+	public static RemoteParForJobReturn runJob(long pfid, String prog, HashMap<String, byte[]> clsMap,
 		List<Task> tasks, ExecutionContext ec, boolean cpCaching, int numMappers)
-		throws DMLRuntimeException
+		throws DMLRuntimeException
 	{
 		String jobname = "ParFor-ESP";
 		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
@@ -68,13 +72,16 @@ public static RemoteParForJobReturn runJob(long pfid, String program, HashMap<String, byte[]> clsMap,
 LongAccumulator aTasks = sc.sc().longAccumulator("tasks");
 LongAccumulator aIters = sc.sc().longAccumulator("iterations");
 
+//reset cached shared inputs for correctness in local mode
+long jobid = _jobID.getNextID();
+if( InfrastructureAnalyzer.isLocalMode() )
+	RemoteParForSparkWorker.cleanupCachedVariables(jobid);
+
 //run remote_spark parfor job
 //(w/o lazy evaluation to fit existing parfor framework, e.g., result merge)
-RemoteParForSparkWorker func = new RemoteParForSparkWorker(program, clsMap, cpCaching, aTasks, aIters);
-List<Tuple2<Long,String>> out = sc
-	.parallelize(tasks, tasks.size()) //create rdd of parfor tasks
-	.flatMapToPair(func) //execute parfor tasks
-	.collect(); //get output handles
+List<Tuple2<Long,String>> out = sc.parallelize(tasks, tasks.size()) //create rdd of parfor tasks
+	.flatMapToPair(new RemoteParForSparkWorker(jobid, prog, clsMap, cpCaching, aTasks, aIters))
+	.collect(); //execute and get output handles
 
 //de-serialize results
 LocalVariableMap[] results = RemoteParForUtils.getResults(out, LOG);
@@ -85,11 +92,10 @@ public static RemoteParForJobReturn runJob(long pfid, String program, HashMap<String, byte[]> clsMap,
 RemoteParForJobReturn ret = new RemoteParForJobReturn(true, numTasks, numIters, results);
 
 //maintain statistics
-Statistics.incrementNoOfCompiledSPInst();
-Statistics.incrementNoOfExecutedSPInst();
-if( DMLScript.STATISTICS ){
+Statistics.incrementNoOfCompiledSPInst();
+Statistics.incrementNoOfExecutedSPInst();
+if( DMLScript.STATISTICS )
 	Statistics.maintainCPHeavyHitters(jobname, System.nanoTime()-t0);
-}
 
 return ret;
 }
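The driver-side job is a plain parallelize/flatMapToPair/collect pipeline over the task list, with one partition per task. A minimal runnable sketch of that structure (trivial String tasks and a stub worker function in place of SystemML's Task and RemoteParForSparkWorker):

import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFlatMapFunction;

import scala.Tuple2;

public class ParForJobSketch {
    public static void main(String[] args) {
        try( JavaSparkContext sc = new JavaSparkContext("local[2]", "parfor-sketch") ) {
            List<String> tasks = Arrays.asList("t1", "t2", "t3");
            //stub worker: emits one (workerId, result) pair per task
            PairFlatMapFunction<String, Long, String> worker =
                t -> Arrays.asList(new Tuple2<>(1L, "result-of-" + t)).iterator();
            List<Tuple2<Long, String>> out = sc
                .parallelize(tasks, tasks.size()) //one partition per parfor task
                .flatMapToPair(worker)
                .collect(); //fetch output handles to the driver
            System.out.println(out);
        }
    }
}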
src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java

@@ -21,10 +21,12 @@

 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map.Entry;
 
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.spark.TaskContext;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.util.LongAccumulator;
@@ -40,18 +42,22 @@
 public class RemoteParForSparkWorker extends ParWorker implements PairFlatMapFunction<Task, Long, String>
 {
 	private static final long serialVersionUID = -3254950138084272296L;
 
-	private final String _prog;
 
+	private static final CachedReuseVariables reuseVars = new CachedReuseVariables();
+
+	private final long _jobid;
+	private final String _prog;
 	private final HashMap<String, byte[]> _clsMap;
 	private boolean _initialized = false;
 	private boolean _caching = true;
 
 	private final LongAccumulator _aTasks;
 	private final LongAccumulator _aIters;
 
-	public RemoteParForSparkWorker(String program, HashMap<String, byte[]> clsMap, boolean cpCaching, LongAccumulator atasks, LongAccumulator aiters)
+	public RemoteParForSparkWorker(long jobid, String program, HashMap<String, byte[]> clsMap, boolean cpCaching, LongAccumulator atasks, LongAccumulator aiters)
 		throws DMLRuntimeException
 	{
+		_jobid = jobid;
 		_prog = program;
 		_clsMap = clsMap;
 		_initialized = false;
@@ -68,7 +74,7 @@ public Iterator<Tuple2<Long, String>> call(Task arg0)
 {
 	//lazy parworker initialization
 	if( !_initialized )
-		configureWorker( TaskContext.get().taskAttemptId() );
+		configureWorker(TaskContext.get().taskAttemptId());
 
 	//execute a single task
 	long numIter = getExecutedIterations();
@@ -88,10 +94,11 @@ public Iterator<Tuple2<Long, String>> call(Task arg0)
 	return ret.iterator();
 }
 
-private void configureWorker( long ID )
+@SuppressWarnings("unchecked")
+private void configureWorker(long taskID)
 	throws DMLRuntimeException, IOException
 {
-	_workerID = ID;
+	_workerID = taskID;
 
 	//initialize codegen class cache (before program parsing)
 	synchronized( CodegenUtils.class ) {
@@ -106,7 +113,13 @@ private void configureWorker( long ID )
 	_resultVars = body.getResultVarNames();
 	_numTasks = 0;
 	_numIters = 0;
-
+	
+	//reuse shared inputs (to read shared inputs once per process instead of once per core;
+	//we reuse everything except result variables and partitioned input matrices)
+	_ec.pinVariables(_ec.getVarList()); //avoid cleanup of shared inputs
+	Collection<String> blacklist = CollectionUtils.union(_resultVars, _ec.getVarListPartitioned());
+	reuseVars.reuseVariables(_jobid, _ec.getVariables(), blacklist);
+	
 	//init and register-cleanup of buffer pool (in parfor spark, multiple tasks might
 	//share the process-local, i.e., per executor, buffer pool; hence we synchronize
 	//the initialization and immediately register the created directory for cleanup
@@ -121,7 +134,7 @@ private void configureWorker( long ID )
 			CacheableData.cacheEvictionLocalFilePrefix +"_" + _workerID;
 		//register entire working dir for delete on shutdown
 		RemoteParForUtils.cleanupWorkingDirectoriesOnShutdown();
-		}
+	}
 	}
 
 	//ensure that resultvar files are not removed
 	...
 	//mark as initialized
 	_initialized = true;
 }

+public static void cleanupCachedVariables(long pfid) {
+	reuseVars.clearVariables(pfid);
+}
 }
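One note on the new @SuppressWarnings("unchecked") on configureWorker: commons-collections 3.x predates generics, so CollectionUtils.union returns a raw Collection and assigning it to Collection<String> is an unchecked conversion. A small sketch of the blacklist construction (hypothetical demo class):

import java.util.Arrays;
import java.util.Collection;

import org.apache.commons.collections.CollectionUtils;

public class BlacklistSketch {
    @SuppressWarnings("unchecked")
    public static void main(String[] args) {
        Collection<String> resultVars = Arrays.asList("R"); //worker-local outputs
        Collection<String> partitioned = Arrays.asList("X_part"); //partitioned inputs
        //set union of both name lists, raw-typed in commons-collections 3.x
        Collection<String> blacklist = CollectionUtils.union(resultVars, partitioned);
        System.out.println(blacklist); //contains R and X_part
    }
}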