
Commit

Fix spark.load_model not to delete the DFS path (mlflow#335)
aarondav authored Aug 21, 2018
1 parent 3fa4fc8 · commit a97e06b
Showing 2 changed files with 23 additions and 22 deletions.
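In short: log_model, save_model, and load_model now take dfs_tmpdir=None and resolve the default (DFS_TMP, i.e. /tmp/mlflow) inside each function, and load_model no longer deletes the temporary DFS copy it creates, since Spark may still read from that path after the call returns. A minimal usage sketch of the changed behavior (the artifact path, run ID, and DataFrame below are placeholders, not part of this commit):

    import mlflow.spark

    # Load a previously logged Spark ML model. The copy staged under dfs_tmpdir
    # (default: /tmp/mlflow) is intentionally left in place so Spark can keep
    # reading from it; cleaning it up is now the caller's responsibility.
    model = mlflow.spark.load_model("spark-model",                  # artifact path (placeholder)
                                    run_id="<run-id-placeholder>",
                                    dfs_tmpdir="/tmp/mlflow-example")
    predictions = model.transform(test_df)  # test_df: a Spark DataFrame you supply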
28 changes: 15 additions & 13 deletions mlflow/spark.py
@@ -20,14 +20,15 @@
import mlflow
from mlflow import pyfunc
from mlflow.models import Model
+ from mlflow.utils.logging_utils import eprint

FLAVOR_NAME = "spark"

# Default temporary directory on DFS. Used to write / read from Spark ML models.
DFS_TMP = "/tmp/mlflow"


- def log_model(spark_model, artifact_path, conda_env=None, jars=None, dfs_tmpdir=DFS_TMP):
+ def log_model(spark_model, artifact_path, conda_env=None, jars=None, dfs_tmpdir=None):
"""
Log a Spark MLlib model as an MLflow artifact for the current run.
@@ -42,7 +43,7 @@ def log_model(spark_model, artifact_path, conda_env=None, jars=None, dfs_tmpdir=
destination and then copied into the model's artifact directory. This is
necessary as Spark ML models read / write from / to DFS if running on a
cluster. All temporary files created on the DFS will be removed if this
- operation completes successfully.
+ operation completes successfully. Defaults to /tmp/mlflow.
>>> from pyspark.ml import Pipeline
>>> from pyspark.ml.classification import LogisticRegression
@@ -117,7 +118,7 @@ def delete(cls, path):


def save_model(spark_model, path, mlflow_model=Model(), conda_env=None, jars=None,
- dfs_tmpdir=DFS_TMP):
+ dfs_tmpdir=None):
"""
Save Spark MLlib PipelineModel at given local path.
@@ -133,7 +134,7 @@ def save_model(spark_model, path, mlflow_model=Model(), conda_env=None, jars=Non
destination and then copied to the requested local path. This is necessary
as Spark ML models read / write from / to DFS if running on a cluster. All
temporary files created on the DFS will be removed if this operation
- completes successfully.
+ completes successfully. Defaults to /tmp/mlflow.
>>> from mlflow import spark
@@ -143,6 +144,7 @@ def save_model(spark_model, path, mlflow_model=Model(), conda_env=None, jars=Non
>>> model = ...
>>> mlflow.spark.save_model(model, "spark-model")
"""
+ dfs_tmpdir = dfs_tmpdir if dfs_tmpdir is not None else DFS_TMP
if jars:
raise Exception("jar dependencies are not implemented")
if not isinstance(spark_model, Transformer):
@@ -169,7 +171,7 @@ def save_model(spark_model, path, mlflow_model=Model(), conda_env=None, jars=Non
mlflow_model.save(os.path.join(path, "MLmodel"))


- def load_model(path, run_id=None, dfs_tmpdir=DFS_TMP):
+ def load_model(path, run_id=None, dfs_tmpdir=None):
"""
Load the Spark MLlib model from the path.
@@ -190,6 +192,7 @@ def load_model(path, run_id=None, dfs_tmpdir=DFS_TMP):
>>> prediction = model.transform(test)
"""
+ dfs_tmpdir = dfs_tmpdir if dfs_tmpdir is not None else DFS_TMP
if run_id is not None:
path = mlflow.tracking.utils._get_model_log_dir(model_name=path, run_id=run_id)
m = Model.load(os.path.join(path, 'MLmodel'))
@@ -198,14 +201,13 @@ def load_model(path, run_id=None, dfs_tmpdir=DFS_TMP):
conf = m.flavors[FLAVOR_NAME]
model_path = os.path.join(path, conf['model_data'])
tmp_path = _tmp_path(dfs_tmpdir)
- try:
-     # Spark ML expects the model to be stored on DFS
-     # Copy the model to a temp DFS location first.
-     _HadoopFileSystem.copy_from_local_file(model_path, tmp_path, removeSrc=False)
-     pipeline_model = PipelineModel.load(tmp_path)
-     return pipeline_model
- finally:
-     _HadoopFileSystem.delete(tmp_path)
+ # Spark ML expects the model to be stored on DFS
+ # Copy the model to a temp DFS location first. We cannot delete this file, as
+ # Spark may read from it at any point.
+ _HadoopFileSystem.copy_from_local_file(model_path, tmp_path, removeSrc=False)
+ pipeline_model = PipelineModel.load(tmp_path)
+ eprint("Copied SparkML model to %s" % tmp_path)
+ return pipeline_model


def load_pyfunc(path):
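The substantive change above is in load_model: the try/finally that removed tmp_path is gone, because PipelineModel.load may defer reading parts of the model until it is actually used, so deleting the DFS copy right away could break later transforms. Cleanup is therefore left to the caller. A sketch of what that can look like when Spark's "DFS" is the local filesystem (as in the tests below); shutil.rmtree would not apply to a real HDFS path:

    import shutil

    import mlflow.spark as sparkm

    model = sparkm.load_model("spark-model", run_id="<run-id-placeholder>")
    preds = model.transform(spark_df).collect()  # spark_df is a placeholder DataFrame
    # Only once all reads of the model are finished is it safe to clear the
    # temp copies that load_model left under the default temp root.
    shutil.rmtree(sparkm.DFS_TMP)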
17 changes: 8 additions & 9 deletions tests/spark/test_spark_model_export.py
@@ -92,7 +92,8 @@ def test_model_export(tmpdir):
assert preds1 == preds3
assert os.path.exists(sparkm.DFS_TMP)
print(os.listdir(sparkm.DFS_TMP))
- assert not os.listdir(sparkm.DFS_TMP)
+ # We expect not to delete the DFS tempdir.
+ assert os.listdir(sparkm.DFS_TMP)


@pytest.mark.large
@@ -129,28 +130,26 @@ def test_model_log(tmpdir):
mlflow.start_run()
artifact_path = "model%d" % cnt
cnt += 1
- if dfs_tmp_dir:
-     sparkm.log_model(artifact_path=artifact_path, spark_model=model,
-                      dfs_tmpdir=dfs_tmp_dir)
- else:
-     sparkm.log_model(artifact_path=artifact_path, spark_model=model)
+ sparkm.log_model(artifact_path=artifact_path, spark_model=model,
+                  dfs_tmpdir=dfs_tmp_dir)
run_id = active_run().info.run_uuid
# test pyfunc
x = pyfunc.load_pyfunc(artifact_path, run_id=run_id)
preds2 = x.predict(pandas_df)
assert preds1 == preds2
# test load model
- reloaded_model = sparkm.load_model(artifact_path, run_id=run_id)
+ reloaded_model = sparkm.load_model(artifact_path, run_id=run_id,
+                                    dfs_tmpdir=dfs_tmp_dir)
preds_df_1 = reloaded_model.transform(spark_df)
preds3 = [x.prediction for x in preds_df_1.select("prediction").collect()]
assert preds1 == preds3
# test spar_udf
preds4 = score_model_as_udf(artifact_path, run_id, pandas_df)
assert preds1 == preds4
- # make sure we did not leave any temp files behind
+ # We expect not to delete the DFS tempdir.
x = dfs_tmp_dir or sparkm.DFS_TMP
assert os.path.exists(x)
- assert not os.listdir(x)
+ assert os.listdir(x)
shutil.rmtree(x)
finally:
mlflow.end_run()
