
Commit

criccomini committed Jul 31, 2017
2 parents c616eaa + 1932ccc commit c08842f
Showing 4 changed files with 91 additions and 51 deletions.
16 changes: 14 additions & 2 deletions airflow/contrib/operators/cloudml_operator.py
@@ -341,6 +341,12 @@ class CloudMLVersionOperator(BaseOperator):
If it is None, the only `operation` possible would be `list`.
:type version: dict
:param version_name: A name to use for the version being operated upon. If
not None and the `version` argument is None or does not have a value for
the `name` key, then this will be populated in the payload for the
`name` key.
:type version_name: string
:param gcp_conn_id: The connection ID to use when fetching connection info.
:type gcp_conn_id: string
@@ -372,13 +378,15 @@ class CloudMLVersionOperator(BaseOperator):
template_fields = [
'_model_name',
'_version',
'_version_name',
]

@apply_defaults
def __init__(self,
model_name,
project_id,
version,
version=None,
version_name=None,
gcp_conn_id='google_cloud_default',
operation='create',
delegate_to=None,
@@ -387,13 +395,17 @@ def __init__(self,

super(CloudMLVersionOperator, self).__init__(*args, **kwargs)
self._model_name = model_name
self._version = version
self._version = version or {}
self._version_name = version_name
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
self._project_id = project_id
self._operation = operation

def execute(self, context):
if 'name' not in self._version:
self._version['name'] = self._version_name

hook = CloudMLHook(
gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)

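The net effect of this hunk: `version` becomes optional, and a bare `version_name` is folded into the payload's `name` key at execute time. A minimal usage sketch under that assumption — the DAG, task id, project, model, and deployment URI below are hypothetical, not from this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.cloudml_operator import CloudMLVersionOperator

    dag = DAG('cloudml_version_example',                    # hypothetical DAG
              start_date=datetime(2017, 7, 31),
              schedule_interval=None)

    # Before this change a full `version` dict (including 'name') was needed.
    # Now execute() copies `version_name` into version['name'] when that key
    # is absent, so the name can be templated separately from the payload.
    create_version = CloudMLVersionOperator(
        task_id='create-version',                           # hypothetical task id
        project_id='my-project',                            # hypothetical project
        model_name='my_model',                              # hypothetical model
        version={'deploymentUri': 'gs://my-bucket/export'}, # illustrative payload, no 'name' key
        version_name='v1',                                  # merged into version['name'] at execute()
        operation='create',
        dag=dag)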
84 changes: 53 additions & 31 deletions airflow/contrib/operators/cloudml_operator_utils.py
@@ -18,31 +18,26 @@
import json
import os
import re
try: # python 2
from urlparse import urlsplit
except ImportError: # python 3
from urllib.parse import urlsplit

import dill

from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
from airflow.contrib.operators.cloudml_operator import _normalize_cloudml_job_id
from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
from airflow.exceptions import AirflowException
from airflow.operators.python_operator import PythonOperator

from six.moves.urllib.parse import urlsplit

def create_evaluate_ops(task_prefix,
project_id,
job_id,
region,
data_format,
input_paths,
prediction_path,
metric_fn_and_keys,
validate_fn,
dataflow_options,
batch_prediction_job_id=None,
project_id=None,
region=None,
dataflow_options=None,
model_uri=None,
model_name=None,
version_name=None,
@@ -114,22 +109,6 @@ def validate_err_and_count(summary):
job name, which doesn't allow other characters.
:type task_prefix: string
:param model_uri: GCS path of the model exported by Tensorflow using
tensorflow.estimator.export_savedmodel(). It cannot be used with
model_name or version_name below. See CloudMLBatchPredictionOperator for
more detail.
:type model_uri: string
:param model_name: Used to indicate a model to use for prediction. Can be
used in combination with version_name, but cannot be used together with
model_uri. See CloudMLBatchPredictionOperator for more detail.
:type model_name: string
:param version_name: Used to indicate a model version to use for prediction,
in combination with model_name. Cannot be used together with model_uri.
See CloudMLBatchPredictionOperator for more detail.
:type version_name: string
:param data_format: one of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
:type data_format: string
@@ -149,9 +128,46 @@ def validate_err_and_count(summary):
good enough to push the model.
:type validate_fn: function
:param dataflow_options: options to run Dataflow jobs.
:param batch_prediction_job_id: the id to use for the Cloud ML Batch
prediction job. Passed directly to the CloudMLBatchPredictionOperator as
the job_id argument.
:type batch_prediction_job_id: string
:param project_id: the Google Cloud Platform project id in which to execute
Cloud ML Batch Prediction and Dataflow jobs. If None, then the `dag`'s
`default_args['project_id']` will be used.
:type project_id: string
:param region: the Google Cloud Platform region in which to execute Cloud ML
Batch Prediction and Dataflow jobs. If None, then the `dag`'s
`default_args['region']` will be used.
:type region: string
:param dataflow_options: options to run Dataflow jobs. If None, then the
`dag`'s `default_args['dataflow_default_options']` will be used.
:type dataflow_options: dictionary
:param model_uri: GCS path of the model exported by Tensorflow using
tensorflow.estimator.export_savedmodel(). It cannot be used with
model_name or version_name below. See CloudMLBatchPredictionOperator for
more detail.
:type model_uri: string
:param model_name: Used to indicate a model to use for prediction. Can be
used in combination with version_name, but cannot be used together with
model_uri. See CloudMLBatchPredictionOperator for more detail. If None,
then the `dag`'s `default_args['model_name']` will be used.
:type model_name: string
:param version_name: Used to indicate a model version to use for prediction,
in combination with model_name. Cannot be used together with model_uri.
See CloudMLBatchPredictionOperator for more detail. If None, then the
`dag`'s `default_args['version_name']` will be used.
:type version_name: string
:param dag: The `DAG` to use for all Operators.
:type dag: airflow.DAG
:returns: a tuple of three operators, (prediction, summary, validation)
:rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
PythonOperator)
@@ -170,10 +186,19 @@ def validate_err_and_count(summary):
if not callable(validate_fn):
raise AirflowException("`validate_fn` param must be callable.")

if dag is not None and dag.default_args is not None:
default_args = dag.default_args
project_id = project_id or default_args.get('project_id')
region = region or default_args.get('region')
model_name = model_name or default_args.get('model_name')
version_name = version_name or default_args.get('version_name')
dataflow_options = dataflow_options or \
default_args.get('dataflow_default_options')

evaluate_prediction = CloudMLBatchPredictionOperator(
task_id=(task_prefix + "-prediction"),
project_id=project_id,
job_id=_normalize_cloudml_job_id(job_id),
job_id=batch_prediction_job_id,
region=region,
data_format=data_format,
input_paths=input_paths,
Expand All @@ -195,9 +220,6 @@ def validate_err_and_count(summary):
"metric_keys": ','.join(metric_keys)
},
dag=dag)
# TODO: "options" is not template_field of DataFlowPythonOperator (not sure
# if intended or by mistake); consider fixing in the DataFlowPythonOperator.
evaluate_summary.template_fields.append("options")
evaluate_summary.set_upstream(evaluate_prediction)

def apply_validate_fn(*args, **kwargs):
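Taken together, the reordered signature and the new `default_args` fallbacks let a caller centralize project, region, model, and Dataflow settings on the DAG instead of repeating them per call. A sketch under that assumption — the bucket paths, DAG name, and toy metric/validate functions below are illustrative only (the lambdas mirror the ones used in the tests):

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.cloudml_operator_utils import create_evaluate_ops

    dag = DAG(
        'eval_example',                                  # hypothetical DAG
        default_args={
            'owner': 'airflow',
            'start_date': datetime(2017, 7, 31),
            # These are picked up by the new fallback block above:
            'project_id': 'my-project',
            'region': 'us-east1',
            'model_name': 'my_model',
            'version_name': 'v1',
            'dataflow_default_options': {'project': 'my-project'},
        },
        schedule_interval='@daily')

    # project_id/region/model_name/version_name/dataflow_options are omitted
    # here; create_evaluate_ops now fills them from dag.default_args.
    pred, summary, validate = create_evaluate_ops(
        task_prefix='eval-test',
        batch_prediction_job_id='eval-test-prediction',
        data_format='TEXT',
        input_paths=['gs://my-bucket/eval-input/*'],     # hypothetical paths
        prediction_path='gs://my-bucket/eval-output',
        metric_fn_and_keys=(lambda inst: (0.1,), ['err']),
        validate_fn=lambda scores: 'err=%.1f' % scores['err'],
        dag=dag)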
2 changes: 2 additions & 0 deletions airflow/contrib/operators/dataflow_operator.py
@@ -123,6 +123,8 @@ def execute(self, context):

class DataFlowPythonOperator(BaseOperator):

template_fields = ['options', 'dataflow_default_options']

@apply_defaults
def __init__(
self,
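Declaring `options` and `dataflow_default_options` as template fields means Airflow renders their string values through Jinja at runtime, which is what lets `create_evaluate_ops` drop its old `template_fields.append("options")` workaround. A hedged sketch — the DAG, pipeline file, and bucket names are illustrative, not from this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator

    dag = DAG('dataflow_template_example',               # hypothetical DAG
              start_date=datetime(2017, 7, 31),
              schedule_interval='@daily')

    # With 'options' in template_fields, the string values below are rendered
    # with Jinja per task run instead of being passed through verbatim.
    summarize = DataFlowPythonOperator(
        task_id='summarize',
        py_file='gs://my-bucket/dataflow/summary.py',    # hypothetical pipeline file
        options={'prediction_path': 'gs://my-bucket/{{ ds }}/output'},
        dag=dag)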
40 changes: 22 additions & 18 deletions tests/contrib/operators/test_cloudml_operator_utils.py
@@ -40,6 +40,7 @@ class CreateEvaluateOpsTest(unittest.TestCase):
'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
'outputPath': 'gs://legal-bucket/fake-output-path',
'region': 'us-east1',
'versionName': 'projects/test-project/models/test_model/versions/test_version',
}
SUCCESS_MESSAGE_MISSING_INPUT = {
'jobId': 'eval_test_prediction',
@@ -61,30 +62,27 @@ def setUp(self):
'owner': 'airflow',
'start_date': DEFAULT_DATE,
'end_date': DEFAULT_DATE,
'project_id': 'test-project',
'region': 'us-east1',
'model_name': 'test_model',
'version_name': 'test_version',
},
schedule_interval='@daily')
self.metric_fn = lambda x: (0.1,)
self.metric_fn_encoded = cloudml_operator_utils.base64.b64encode(
cloudml_operator_utils.dill.dumps(self.metric_fn, recurse=True))


def testSuccessfulRun(self):
input_with_model = self.INPUT_MISSING_ORIGIN.copy()
input_with_model['modelName'] = (
'projects/test-project/models/test_model')

pred, summary, validate = create_evaluate_ops(
task_prefix='eval-test',
project_id='test-project',
job_id='eval-test-prediction',
region=input_with_model['region'],
batch_prediction_job_id='eval-test-prediction',
data_format=input_with_model['dataFormat'],
input_paths=input_with_model['inputPaths'],
prediction_path=input_with_model['outputPath'],
model_name=input_with_model['modelName'].split('/')[-1],
metric_fn_and_keys=(self.metric_fn, ['err']),
validate_fn=(lambda x: 'err=%.1f' % x['err']),
dataflow_options=None,
dag=self.dag)

with patch('airflow.contrib.operators.cloudml_operator.'
@@ -100,8 +98,9 @@ def testSuccessfulRun(self):
'test-project',
{
'jobId': 'eval_test_prediction',
'predictionInput': input_with_model
}, ANY)
'predictionInput': input_with_model,
},
ANY)
self.assertEqual(success_message['predictionOutput'], result)

with patch('airflow.contrib.operators.dataflow_operator.'
@@ -133,22 +132,27 @@ def testSuccessfulRun(self):
self.assertEqual('err=0.9', result)

def testFailures(self):
input_with_model = self.INPUT_MISSING_ORIGIN.copy()
input_with_model['modelName'] = (
'projects/test-project/models/test_model')
dag = DAG(
'test_dag',
default_args={
'owner': 'airflow',
'start_date': DEFAULT_DATE,
'end_date': DEFAULT_DATE,
'project_id': 'test-project',
'region': 'us-east1',
},
schedule_interval='@daily')

input_with_model = self.INPUT_MISSING_ORIGIN.copy()
other_params_but_models = {
'task_prefix': 'eval-test',
'project_id': 'test-project',
'job_id': 'eval-test-prediction',
'region': input_with_model['region'],
'batch_prediction_job_id': 'eval-test-prediction',
'data_format': input_with_model['dataFormat'],
'input_paths': input_with_model['inputPaths'],
'prediction_path': input_with_model['outputPath'],
'metric_fn_and_keys': (self.metric_fn, ['err']),
'validate_fn': (lambda x: 'err=%.1f' % x['err']),
'dataflow_options': None,
'dag': self.dag,
'dag': dag,
}

with self.assertRaisesRegexp(ValueError, 'Missing model origin'):
