From ec0440761a5825e1f0ba1c411928c25e3753f4c5 Mon Sep 17 00:00:00 2001
From: Max Pumperla
Date: Wed, 15 Aug 2018 17:00:23 +0200
Subject: [PATCH] persist elephas model

---
 .gitignore                |  3 ++-
 elephas/spark_model.py    | 33 +++++++++++++++++++++++++++++++--
 tests/test_mllib_model.py |  8 +++++++-
 tests/test_spark_model.py | 10 +++++++---
 4 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1592491..9218d94 100644
--- a/.gitignore
+++ b/.gitignore
@@ -68,4 +68,5 @@ examples/*.csv
 venv/
 
 train.csv
-test.csv
\ No newline at end of file
+test.csv
+test.h5
\ No newline at end of file
diff --git a/elephas/spark_model.py b/elephas/spark_model.py
index 9db0005..58316c6 100644
--- a/elephas/spark_model.py
+++ b/elephas/spark_model.py
@@ -2,7 +2,10 @@
 from __future__ import print_function
 
 import pyspark
+import h5py
+import json
 from keras.optimizers import serialize as serialize_optimizer
+from keras.models import load_model
 
 from .utils import lp_to_simple_rdd
 from .utils import model_to_dict
@@ -77,12 +80,25 @@ def get_train_config(epochs, batch_size, verbose, validation_split):
                 'validation_split': validation_split}
 
     def get_config(self):
-        return {'model': self.master_network.get_config(),
-                'optimizer': self.optimizer.get_config(),
+        return {'parameter_server_mode': self.parameter_server_mode,
+                'elephas_optimizer': self.optimizer.get_config(),
                 'mode': self.mode,
                 'frequency': self.frequency,
                 'num_workers': self.num_workers}
 
+    def save(self, file_name):
+        model = self.master_network
+        model.save(file_name)
+        f = h5py.File(file_name, mode='a')
+
+        f.attrs['distributed_config'] = json.dumps({
+            'class_name': self.__class__.__name__,
+            'config': self.get_config()
+        }).encode('utf8')
+
+        f.flush()
+        f.close()
+
     @property
     def master_network(self):
         return self._master_network
@@ -168,6 +184,19 @@ def _fit(self, rdd, epochs, batch_size, verbose, validation_split):
         self.stop_server()
 
 
+def load_spark_model(file_name):
+    model = load_model(file_name)
+    f = h5py.File(file_name, mode='r')
+
+    elephas_conf = json.loads(f.attrs.get('distributed_config'))
+    class_name = elephas_conf.get('class_name')
+    config = elephas_conf.get('config')
+    if class_name == "SparkModel":
+        return SparkModel(model=model, **config)
+    elif class_name == "SparkMLlibModel":
+        return SparkMLlibModel(model=model, **config)
+
+
 class SparkMLlibModel(SparkModel):
 
     def __init__(self, model, mode='asynchronous', frequency='epoch', parameter_server_mode='http',
diff --git a/tests/test_mllib_model.py b/tests/test_mllib_model.py
index 8a142bb..0f0e604 100644
--- a/tests/test_mllib_model.py
+++ b/tests/test_mllib_model.py
@@ -4,7 +4,7 @@
 from keras.optimizers import RMSprop
 from keras.utils import np_utils
 
-from elephas.spark_model import SparkMLlibModel
+from elephas.spark_model import SparkMLlibModel, load_spark_model
 from elephas.utils.rdd_utils import to_labeled_point
 
 import pytest
@@ -46,6 +46,12 @@
 model.compile(rms, 'categorical_crossentropy', ['acc'])
 
 
+def test_serialization():
+    spark_model = SparkMLlibModel(model, frequency='epoch', mode='synchronous', num_workers=2)
+    spark_model.save("test.h5")
+    recov = load_spark_model("test.h5")
+
+
 def test_mllib_model(spark_context):
     # Build RDD from numpy features and labels
     lp_rdd = to_labeled_point(spark_context, x_train, y_train, categorical=True)
diff --git a/tests/test_spark_model.py b/tests/test_spark_model.py
index 60d8d3a..37e5639 100644
--- a/tests/test_spark_model.py
+++ b/tests/test_spark_model.py
@@ -5,12 +5,10 @@
 from keras.datasets import mnist
 from keras.models import Sequential
 from keras.layers.core import Dense, Dropout, Activation
-from keras.optimizers import SGD
 from keras.utils import np_utils
 
-from elephas.spark_model import SparkModel
+from elephas.spark_model import SparkModel, load_spark_model
 from elephas.utils.rdd_utils import to_simple_rdd
-from elephas import optimizers as elephas_optimizers
 
 
 # Define basic parameters
@@ -51,6 +49,12 @@
 model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["acc"])
 
 
+def test_serialization():
+    spark_model = SparkModel(model, frequency='epoch', mode='synchronous', num_workers=2)
+    spark_model.save("test.h5")
+    recov = load_spark_model("test.h5")
+
+
 def test_spark_model_end_to_end(spark_context):
     rdd = to_simple_rdd(spark_context, x_train, y_train)
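
Usage note: a minimal sketch of the save/load roundtrip this patch introduces. The compiled Keras `model` and the file name "trained.h5" are assumptions standing in for whatever the caller already has; only `SparkModel.save` and `load_spark_model` come from the patch itself.

    from elephas.spark_model import SparkModel, load_spark_model

    # `model` is assumed to be a compiled Keras model.
    spark_model = SparkModel(model, frequency='epoch', mode='synchronous',
                             num_workers=2)

    # Persist the Keras architecture/weights plus the elephas distribution
    # settings into a single HDF5 file.
    spark_model.save("trained.h5")

    # Rebuild a SparkModel (or SparkMLlibModel, depending on the stored
    # class_name) with the same distribution settings.
    recovered = load_spark_model("trained.h5")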
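
Because `save` writes the elephas settings as a single HDF5 attribute next to the standard Keras payload, the file stays loadable by plain `keras.models.load_model`, and the attribute can be inspected directly. A sketch, assuming a file produced by `save` under the hypothetical name "trained.h5":

    import json
    import h5py

    with h5py.File("trained.h5", mode="r") as f:
        # 'distributed_config' is the attribute written by SparkModel.save.
        conf = json.loads(f.attrs["distributed_config"])
        print(conf["class_name"])  # e.g. "SparkModel"
        print(conf["config"])      # mode, frequency, num_workers, ...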