Skip to content

Commit

Permalink
[MINOR] Federated Modifications
Browse files Browse the repository at this point in the history
Major reduction in federated tests, by redusing startup time of
federated tests with multiple workers.

Furthermore a timeout is added to funtions tests allowing only 60 minutes
of execution time before being forcefully terminated.
This reduce the waiting time for feedback of tests that anyway would
timeout after 6 hours.

Isolate Function Test in workflows,
and stabilize negative federated test,
and reduce Federated Kmeans Tests

Privacy monitor added a null pointer check that happens if the object on
the federated site becomes null. This error would result in stack traces
that were hard to debug.

Fix 🐛 in federated right indexing if the indexing aligns to a split
between locations.
  • Loading branch information
Baunsgaard committed Nov 14, 2020
1 parent 809d53f commit 914b8f8
Show file tree
Hide file tree
Showing 34 changed files with 399 additions and 171 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/functionsTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ on:
jobs:
applicationsTests:
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
fail-fast: false
matrix:
tests: [
"**.functions.aggregate.**,**.functions.append.**,**.functions.binary.frame.**,**.functions.binary.matrix.**,**.functions.binary.scalar.**,**.functions.binary.tensor.**",
"**.functions.blocks.**,**.functions.compress.**,**.functions.countDistinct.**,**.functions.data.misc.**,**.functions.data.rand.**,**.functions.data.tensor.**,**.functions.codegenalg.parttwo.**,**.functions.codegen.**,**.functions.caching.**",
"**.functions.federated.**,**.functions.binary.matrix_full_cellwise.**,**.functions.binary.matrix_full_other.**",
"**.functions.binary.matrix_full_cellwise.**,**.functions.binary.matrix_full_other.**",
"**.functions.federated.**",
"**.functions.codegenalg.partone.**",
"**.functions.builtin.**",
"**.functions.frame.**,**.functions.indexing.**,**.functions.io.**,**.functions.jmlc.**,**.functions.lineage.**",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ private void rightIndexing(ExecutionContext ec) {
curFedRange.setBeginDim(0, Math.max(rs - ixrange.rowStart, 0));
curFedRange.setBeginDim(1, Math.max(cs - ixrange.colStart, 0));
curFedRange.setEndDim(0,
(ixrange.rowEnd > re ? re - ixrange.rowStart : ixrange.rowEnd - ixrange.rowStart + 1));
(ixrange.rowEnd >= re ? re - ixrange.rowStart : ixrange.rowEnd - ixrange.rowStart + 1));
curFedRange.setEndDim(1,
(ixrange.colEnd > ce ? ce - ixrange.colStart : ixrange.colEnd - ixrange.colStart + 1));
(ixrange.colEnd >= ce ? ce - ixrange.colStart : ixrange.colEnd - ixrange.colStart + 1));
if(LOG.isDebugEnabled()) {
LOG.debug("Fed Mapping After : " + curFedRange);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ public static void setCheckPrivacy(boolean checkPrivacyParam){
* @return data object or data object with privacy constraint removed in case the privacy level was none.
*/
public static Data handlePrivacy(Data dataObject){
if(dataObject == null)
return null;
PrivacyConstraint privacyConstraint = dataObject.getPrivacyConstraint();
if (privacyConstraint != null){
PrivacyLevel privacyLevel = privacyConstraint.getPrivacyLevel();
Expand Down
15 changes: 14 additions & 1 deletion src/test/java/org/apache/sysds/test/AutomatedTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,19 @@ protected Process startLocalFedWorker(int port) {
* @return the thread associated with the worker.
*/
protected Thread startLocalFedWorkerThread(int port) {
return startLocalFedWorkerThread(port, FED_WORKER_WAIT);
}

/**
* Start a thread for a worker. This will share the same JVM, so all static variables will be shared.!
*
* Also when using the local Fed Worker thread the statistics printing, and clearing from the worker is disabled.
*
* @param port Port to use
* @param sleep The amount of time to wait for the worker startup. in Milliseconds
* @return the thread associated with the worker.
*/
protected Thread startLocalFedWorkerThread(int port, int sleep) {
Thread t = null;
String[] fedWorkArgs = {"-w", Integer.toString(port)};
ArrayList<String> args = new ArrayList<>();
Expand All @@ -1443,7 +1456,7 @@ protected Thread startLocalFedWorkerThread(int port) {
}
});
t.start();
java.util.concurrent.TimeUnit.MILLISECONDS.sleep(FED_WORKER_WAIT);
java.util.concurrent.TimeUnit.MILLISECONDS.sleep(sleep);
}
catch(InterruptedException e) {
e.printStackTrace();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ public void federatedL2SVM(Types.ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t2 = startLocalFedWorkerThread(port2);
Thread t3 = startLocalFedWorkerThread(port3);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2, 10);
Thread t3 = startLocalFedWorkerThread(port3, 10);
Thread t4 = startLocalFedWorkerThread(port4);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ private void runAggregateOperationTest(ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t2 = startLocalFedWorkerThread(port2);
Thread t3 = startLocalFedWorkerThread(port3);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2, 10);
Thread t3 = startLocalFedWorkerThread(port3, 10);
Thread t4 = startLocalFedWorkerThread(port4);

rtplatform = execMode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public void federatedGLM(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@

package org.apache.sysds.test.functions.federated.algorithms;

import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.Collection;

import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecMode;
Expand All @@ -33,9 +31,11 @@
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;

import java.util.Arrays;
import java.util.Collection;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
Expand Down Expand Up @@ -64,9 +64,10 @@ public void setUp() {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// rows have to be even and > 1
return Arrays.asList(new Object[][] {{10000, 10, 1, 1},
return Arrays.asList(new Object[][] {
// {10000, 10, 1, 1},
// {2000, 50, 1, 1}, {1000, 100, 1, 1},
{10000, 10, 2, 1},
// {10000, 10, 2, 1},
// {2000, 50, 2, 1}, {1000, 100, 2, 1}, //concurrent requests
{10000, 10, 2, 2}, // repeated exec
// TODO more runs e.g., 16 -> but requires rework RPC framework first
Expand All @@ -80,6 +81,7 @@ public void federatedKmeansSinglenode() {
}

@Test
@Ignore
public void federatedKmeansHybrid() {
federatedKmeans(Types.ExecMode.HYBRID);
}
Expand All @@ -102,7 +104,7 @@ public void federatedKmeans(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public void federatedL2SVM(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public void federatedLogReg(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ public void federatedL2SVM(Types.ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t2 = startLocalFedWorkerThread(port2);
Thread t3 = startLocalFedWorkerThread(port3);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2, 10);
Thread t3 = startLocalFedWorkerThread(port3, 10);
Thread t4 = startLocalFedWorkerThread(port4);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ public void federatedL2SVM(Types.ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t2 = startLocalFedWorkerThread(port2);
Thread t3 = startLocalFedWorkerThread(port3);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2, 10);
Thread t3 = startLocalFedWorkerThread(port3, 10);
Thread t4 = startLocalFedWorkerThread(port4);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ private void runAggregateOperationTest(ExecMode execMode) {
int port2 = getRandomAvailablePort();
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t2 = startLocalFedWorkerThread(port2);
Thread t3 = startLocalFedWorkerThread(port3);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2, 10);
Thread t3 = startLocalFedWorkerThread(port3, 10);
Thread t4 = startLocalFedWorkerThread(port4);

rtplatform = execMode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public void federatedL2SVM(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void federatedRead(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2);
String host = "localhost";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public void federatedRead(Types.ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2);
String host = "localhost";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public void federatedWrite(ExecMode execMode) {
fullDMLScriptName = "";
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2);

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;


@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedParamservTest extends AutomatedTestBase {
Expand All @@ -60,15 +59,12 @@ public class FederatedParamservTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> parameters() {
return Arrays.asList(new Object[][] {
//Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update type, update frequency
{"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
{"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
{"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
{"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
{"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
{"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
{"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
{"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
// Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update
// type, update frequency
{"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
{"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
{"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"}, {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
{"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"}, {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
{"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "BATCH"},
// {"TwoNN", 5, 1000, 200, 2, 0.01, "ASP", "BATCH"},
// {"TwoNN", 5, 1000, 200, 2, 0.01, "BSP", "EPOCH"},
Expand All @@ -80,7 +76,8 @@ public static Collection<Object[]> parameters() {
});
}

public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size, int epochs, double eta, String utype, String freq) {
public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size,
int epochs, double eta, String utype, String freq) {
_networkType = networkType;
_numFederatedWorkers = numFederatedWorkers;
_examplesPerWorker = examplesPerWorker;
Expand All @@ -101,31 +98,30 @@ public void setUp() {
public void federatedParamservSingleNode() {
federatedParamserv(ExecMode.SINGLE_NODE);
}

@Test
public void federatedParamservHybrid() {
federatedParamserv(ExecMode.HYBRID);
}

private void federatedParamserv(ExecMode mode) {
// config
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
setOutputBuffering(true);

int C = 1, Hin = 28, Win = 28;
int numFeatures = C*Hin*Win;
int numFeatures = C * Hin * Win;
int numLabels = 10;

ExecMode platformOld = setExecMode(mode);

try {

// dml name
fullDMLScriptName = HOME + TEST_NAME + ".dml";
// generate program args
List<String> programArgsList = new ArrayList<>(Arrays.asList(
"-stats",
List<String> programArgsList = new ArrayList<>(Arrays.asList("-stats",
"-nvargs",
"examples_per_worker=" + _examplesPerWorker,
"num_features=" + numFeatures,
Expand All @@ -138,28 +134,39 @@ private void federatedParamserv(ExecMode mode) {
"network_type=" + _networkType,
"channels=" + C,
"hin=" + Hin,
"win=" + Win
));

"win=" + Win));

// for each worker
List<Integer> ports = new ArrayList<>();
List<Thread> threads = new ArrayList<>();
for(int i = 0; i < _numFederatedWorkers; i++) {
// write row partitioned features to disk
writeInputMatrixWithMTD("X" + i, generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), false,
new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize, _examplesPerWorker * numFeatures));
writeInputMatrixWithMTD("X" + i,
generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win),
false,
new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize,
_examplesPerWorker * numFeatures));
// write row partitioned labels to disk
writeInputMatrixWithMTD("y" + i, generateDummyMNISTLabels(_examplesPerWorker, numLabels), false,
new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize, _examplesPerWorker * numLabels));

writeInputMatrixWithMTD("y" + i,
generateDummyMNISTLabels(_examplesPerWorker, numLabels),
false,
new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize,
_examplesPerWorker * numLabels));

// start worker
ports.add(getRandomAvailablePort());
threads.add(startLocalFedWorkerThread(ports.get(i)));
threads.add(startLocalFedWorkerThread(ports.get(i), 10));

// add worker to program args
programArgsList.add("X" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("X" + i)));
programArgsList.add("y" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("y" + i)));
}
try {
Thread.sleep(1000);
}
catch(InterruptedException e) {
e.printStackTrace();
}

programArgs = programArgsList.toArray(new String[0]);
LOG.debug(runTest(null));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public void federatedMultiply(Types.ExecMode execMode) {

int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void federatedMultiply(Types.ExecMode execMode) {

int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
Thread t1 = startLocalFedWorkerThread(port1);
Thread t1 = startLocalFedWorkerThread(port1, 10);
Thread t2 = startLocalFedWorkerThread(port2);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
Expand Down
Loading

0 comments on commit 914b8f8

Please sign in to comment.