Skip to content

Commit

Permalink
Updated slice key extractor to support passing features natively in d…
Browse files Browse the repository at this point in the history
…ict.

PiperOrigin-RevId: 272561043
  • Loading branch information
mdreves authored and tf-model-analysis-team committed Oct 3, 2019
1 parent 96ca3ff commit 51b6607
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions tensorflow_model_analysis/extractors/slice_key_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ def __init__(self, slice_spec: List[slicer.SingleSliceSpec],
self._materialize = materialize

def process(self, element: types.Extracts) -> List[types.Extracts]:
fpl = element.get(constants.FEATURES_PREDICTIONS_LABELS_KEY)
if not fpl:
raise RuntimeError('FPL missing, Please ensure Predict() was called.')
if not isinstance(fpl, types.FeaturesPredictionsLabels):
raise TypeError(
'Expected FPL to be instance of FeaturesPredictionsLabel. FPL was: '
'%s of type %s' % (str(fpl), type(fpl)))
features = fpl.features
features = None
if constants.FEATURES_PREDICTIONS_LABELS_KEY in element:
fpl = element[constants.FEATURES_PREDICTIONS_LABELS_KEY]
if not isinstance(fpl, types.FeaturesPredictionsLabels):
raise TypeError(
'Expected FPL to be instance of FeaturesPredictionsLabel. FPL was: '
'%s of type %s' % (str(fpl), type(fpl)))
features = fpl.features
elif constants.FEATURES_KEY in element:
features = element[constants.FEATURES_KEY]
if not features:
raise RuntimeError(
'Features missing, Please ensure Predict() was called.')
slices = list(
slicer.get_slices_for_features_dict(features, self._slice_spec))

Expand Down

0 comments on commit 51b6607

Please sign in to comment.