persist elephas model
maxpumperla committed Aug 15, 2018
1 parent e0105e8 commit ec04407
Showing 4 changed files with 47 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -68,4 +68,5 @@ examples/*.csv
 venv/

 train.csv
-test.csv
+test.csv
+test.h5
33 changes: 31 additions & 2 deletions elephas/spark_model.py
@@ -2,7 +2,10 @@
 from __future__ import print_function

 import pyspark
+import h5py
+import json
 from keras.optimizers import serialize as serialize_optimizer
+from keras.models import load_model

 from .utils import lp_to_simple_rdd
 from .utils import model_to_dict
@@ -77,12 +80,25 @@ def get_train_config(epochs, batch_size, verbose, validation_split):
                 'validation_split': validation_split}

     def get_config(self):
-        return {'model': self.master_network.get_config(),
-                'optimizer': self.optimizer.get_config(),
+        return {'parameter_server_mode': self.parameter_server_mode,
+                'elephas_optimizer': self.optimizer.get_config(),
                 'mode': self.mode,
                 'frequency': self.frequency,
                 'num_workers': self.num_workers}

+    def save(self, file_name):
+        model = self.master_network
+        model.save(file_name)
+        f = h5py.File(file_name, mode='a')
+
+        f.attrs['distributed_config'] = json.dumps({
+            'class_name': self.__class__.__name__,
+            'config': self.get_config()
+        }).encode('utf8')
+
+        f.flush()
+        f.close()
+
     @property
     def master_network(self):
         return self._master_network
@@ -168,6 +184,19 @@ def _fit(self, rdd, epochs, batch_size, verbose, validation_split):
         self.stop_server()


+def load_spark_model(file_name):
+    model = load_model(file_name)
+    f = h5py.File(file_name, mode='r')
+
+    elephas_conf = json.loads(f.attrs.get('distributed_config'))
+    class_name = elephas_conf.get('class_name')
+    config = elephas_conf.get('config')
+    if class_name == "SparkModel":
+        return SparkModel(model=model, **config)
+    elif class_name == "SparkMLlibModel":
+        return SparkMLlibModel(model=model, **config)
+
+
 class SparkMLlibModel(SparkModel):

     def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_server_mode='http',
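Taken together, save and load_spark_model form a persistence round trip: save writes the Keras model to HDF5 and appends the elephas wrapper settings as a JSON 'distributed_config' attribute on the same file, and load_spark_model reads both back and reconstructs the matching wrapper class. A minimal usage sketch, assuming a compiled Keras model named model; the file name here is arbitrary:

    from elephas.spark_model import SparkModel, load_spark_model

    spark_model = SparkModel(model, frequency='epoch', mode='synchronous', num_workers=2)
    spark_model.save('my_model.h5')             # Keras weights/architecture plus elephas config
    restored = load_spark_model('my_model.h5')  # rebuilt as a SparkModel with the same settings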
8 changes: 7 additions & 1 deletion tests/test_mllib_model.py
@@ -4,7 +4,7 @@
 from keras.optimizers import RMSprop
 from keras.utils import np_utils

-from elephas.spark_model import SparkMLlibModel
+from elephas.spark_model import SparkMLlibModel, load_spark_model
 from elephas.utils.rdd_utils import to_labeled_point

 import pytest
@@ -46,6 +46,12 @@
 model.compile(rms, 'categorical_crossentropy', ['acc'])


+def test_serialization():
+    spark_model = SparkMLlibModel(model, frequency='epoch', mode='synchronous', num_workers=2)
+    spark_model.save("test.h5")
+    recov = load_spark_model("test.h5")
+
+
 def test_mllib_model(spark_context):
     # Build RDD from numpy features and labels
     lp_rdd = to_labeled_point(spark_context, x_train, y_train, categorical=True)
10 changes: 7 additions & 3 deletions tests/test_spark_model.py
@@ -5,12 +5,10 @@
 from keras.datasets import mnist
 from keras.models import Sequential
 from keras.layers.core import Dense, Dropout, Activation
-from keras.optimizers import SGD
 from keras.utils import np_utils

-from elephas.spark_model import SparkModel
+from elephas.spark_model import SparkModel, load_spark_model
 from elephas.utils.rdd_utils import to_simple_rdd
-from elephas import optimizers as elephas_optimizers


 # Define basic parameters
@@ -51,6 +49,12 @@
 model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["acc"])


+def test_serialization():
+    spark_model = SparkModel(model, frequency='epoch', mode='synchronous', num_workers=2)
+    spark_model.save("test.h5")
+    recov = load_spark_model("test.h5")
+
+
 def test_spark_model_end_to_end(spark_context):
     rdd = to_simple_rdd(spark_context, x_train, y_train)
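Both new test_serialization functions are smoke tests: they pass as long as save and load_spark_model run without raising, which is also why test.h5 joins .gitignore above. A slightly stronger round-trip check, sketched here as a hypothetical extension rather than part of this commit, could assert that the restored wrapper carries the persisted settings:

    def test_serialization_roundtrip():
        spark_model = SparkModel(model, frequency='epoch', mode='synchronous', num_workers=2)
        spark_model.save('test.h5')
        recov = load_spark_model('test.h5')
        # These attributes back get_config(), so they should survive the round trip.
        assert recov.mode == 'synchronous'
        assert recov.frequency == 'epoch'
        assert recov.num_workers == 2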
