Skip to content

Commit

Permalink
Merge pull request tensorflow#4433 from casassg:casassg/elwc-example-…
Browse files Browse the repository at this point in the history
…gen-fix

PiperOrigin-RevId: 440804380
  • Loading branch information
tfx-copybara committed Apr 11, 2022
2 parents bb4125b + 96c0a60 commit 01253b4
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 20 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
blessed model at all (e.g. first run).
* Fix that the resolver with custom `ResolverStrategy` (assume correctly
packaged) fails.
* Fixed `ElwcBigQueryExampleGen` data serialization error that was causing an
assertion failure on Beam.

## Dependency Updates

Expand Down
11 changes: 1 addition & 10 deletions tfx/extensions/google_cloud_big_query/example_gen/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import apache_beam as beam

from apache_beam.options import value_provider
from google.cloud import bigquery
import tensorflow as tf

Expand Down Expand Up @@ -64,15 +63,7 @@ def _BigQueryToExample(pipeline: beam.Pipeline, exec_properties: Dict[str, Any],
Returns:
PCollection of TF examples.
"""

beam_pipeline_args = exec_properties['_beam_pipeline_args']
pipeline_options = beam.options.pipeline_options.PipelineOptions(
beam_pipeline_args)
# Try to parse the GCP project ID from the beam pipeline options.
project = pipeline_options.view_as(
beam.options.pipeline_options.GoogleCloudOptions).project
if isinstance(project, value_provider.ValueProvider):
project = project.get()
project = utils.parse_gcp_project(exec_properties['_beam_pipeline_args'])
converter = _BigQueryConverter(split_pattern, project)

return (pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,20 @@ def process(


def _ConvertContextAndExamplesToElwc(
context_feature_and_examples: Tuple[bytes, List[tf.train.Example]]
) -> input_pb2.ExampleListWithContext:
context_feature_and_examples: Tuple[bytes,
List[tf.train.Example]]) -> bytes:
"""Convert context feature and examples to ELWC."""
context_feature, examples = context_feature_and_examples
context_feature_proto = tf.train.Example()
context_feature_proto.ParseFromString(context_feature)
return input_pb2.ExampleListWithContext(
elwc_pb2 = input_pb2.ExampleListWithContext(
context=context_feature_proto, examples=examples)
return elwc_pb2.SerializeToString(deterministic=True)


@beam.ptransform_fn
@beam.typehints.with_input_types(beam.Pipeline)
@beam.typehints.with_output_types(input_pb2.ExampleListWithContext)
@beam.typehints.with_output_types(bytes)
def _BigQueryToElwc(pipeline: beam.Pipeline, exec_properties: Dict[str, Any],
split_pattern: str) -> beam.pvalue.PCollection:
"""Read from BigQuery and transform to ExampleListWithContext.
Expand All @@ -89,13 +90,13 @@ def _BigQueryToElwc(pipeline: beam.Pipeline, exec_properties: Dict[str, Any],
Raises:
RuntimeError: Context features must be included in the queried result.
"""

project = utils.parse_gcp_project(exec_properties['_beam_pipeline_args'])
custom_config = example_gen_pb2.CustomConfig()
json_format.Parse(exec_properties['custom_config'], custom_config)
elwc_config = elwc_config_pb2.ElwcConfig()
custom_config.custom_config.Unpack(elwc_config)

client = bigquery.Client()
client = bigquery.Client(project=project)
# Dummy query to get the type information for each field.
query_job = client.query('SELECT * FROM ({}) LIMIT 0'.format(split_pattern))
results = query_job.result()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,12 @@ def _MockReadFromBigQuery(pipeline, query):
return pipeline | beam.Create(mock_query_results)


def _DeserializeElwc(some_bytes):
  """Parses serialized bytes into an `ExampleListWithContext` proto.

  Args:
    some_bytes: Serialized `ExampleListWithContext` message.

  Returns:
    A new `input_pb2.ExampleListWithContext` populated from `some_bytes`.
  """
  elwc = input_pb2.ExampleListWithContext()
  elwc.ParseFromString(some_bytes)
  return elwc


class ExecutorTest(tf.test.TestCase):

def setUp(self):
Expand Down Expand Up @@ -380,7 +386,7 @@ def testBigQueryToElwc(self, mock_client):
packed_custom_config.custom_config.Pack(elwc_config)
with beam.Pipeline() as pipeline:
elwc_examples = (
pipeline | 'ToElwc' >> executor._BigQueryToElwc(
pipeline | 'ToElwcBytes' >> executor._BigQueryToElwc(
exec_properties={
'_beam_pipeline_args': [],
'custom_config':
Expand All @@ -389,7 +395,8 @@ def testBigQueryToElwc(self, mock_client):
preserving_proto_field_name=True)
},
split_pattern='SELECT context_feature_1, context_feature_2, '
'feature_id_1, feature_id_2, feature_id_3 FROM `fake`'))
'feature_id_1, feature_id_2, feature_id_3 FROM `fake`')
| 'LoadElwc' >> beam.Map(_DeserializeElwc))

expected_elwc_examples = [_ELWC_1, _ELWC_2, _ELWC_3, _ELWC_4, _ELWC_5]
util.assert_that(elwc_examples, util.equal_to(expected_elwc_examples))
Expand Down
16 changes: 14 additions & 2 deletions tfx/extensions/google_cloud_big_query/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@
Internal utilities, no backwards compatibility guarantees.
"""

from typing import Any, Dict
from typing import Any, Dict, List

import apache_beam as beam
from apache_beam.io.gcp import bigquery
from apache_beam.options import value_provider
import tensorflow as tf
from tfx.utils import telemetry_utils


@beam.ptransform_fn
@beam.typehints.with_input_types(beam.Pipeline)
@beam.typehints.with_output_types(beam.typehints.Dict[str, Any])
def ReadFromBigQuery(pipeline: beam.Pipeline,
def ReadFromBigQuery(pipeline: beam.Pipeline, # pylint: disable=invalid-name
query: str) -> beam.pvalue.PCollection:
"""Read data from BigQuery.
Expand Down Expand Up @@ -89,3 +90,14 @@ def row_to_example( # pylint: disable=invalid-name
'BigQuery column type {} is not supported.'.format(data_type))

return tf.train.Example(features=tf.train.Features(feature=feature))


def parse_gcp_project(beam_pipeline_args: List[str]) -> str:
  """Extracts the GCP project ID from Beam pipeline arguments.

  Args:
    beam_pipeline_args: Beam pipeline arguments used to build the
      `PipelineOptions` from which the project is read.

  Returns:
    The `project` value of the `GoogleCloudOptions` view. If that value is a
    `ValueProvider`, the result of calling its `get()` is returned instead.
    NOTE(review): presumably this is None when no project is supplied, despite
    the `str` annotation -- confirm against Beam's defaults.
  """
  options_module = beam.options.pipeline_options
  gcp_options = options_module.PipelineOptions(beam_pipeline_args).view_as(
      options_module.GoogleCloudOptions)
  gcp_project = gcp_options.project
  if not isinstance(gcp_project, value_provider.ValueProvider):
    return gcp_project
  # The option is wrapped in a provider; resolve it to the concrete value.
  return gcp_project.get()

0 comments on commit 01253b4

Please sign in to comment.