Skip to content

Commit

Permalink
Add Batched Slice Key extractor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 356611172
  • Loading branch information
huanmingf authored and tf-model-analysis-team committed Feb 9, 2021
1 parent e08c0c3 commit 47e6297
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tensorflow_model_analysis/extractors/slice_key_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ def process(self, element: types.Extracts) -> List[types.Extracts]:
slicer.get_slices_for_features_dicts(
features_dicts, util.get_features_from_extracts(element),
self._slice_spec))

# If SLICE_KEY_TYPES_KEY already exists, the SqlSliceKeyExtractor has
# already generated some slice keys for this element. Add them to the
# current slice_keys list so they are deduplicated with the rest.
if (constants.SLICE_KEY_TYPES_KEY in element and
element[constants.SLICE_KEY_TYPES_KEY]):
slice_keys.extend(element[constants.SLICE_KEY_TYPES_KEY])

unique_slice_keys = list(set(slice_keys))
if len(slice_keys) != len(unique_slice_keys):
self._duplicate_slice_keys_counter.inc()
Expand Down
35 changes: 35 additions & 0 deletions tensorflow_model_analysis/extractors/slice_key_extractor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,41 @@ class SliceTest(testutil.TensorflowModelAnalysisTest, parameterized.TestCase):
}], [slicer.SingleSliceSpec(columns=['interest'])], [[
(('interest', 'boats'),), (('interest', 'planes'),)
], [(('interest', 'planes'),), (('interest', 'trains'),)]]),
('features_with_batched_slices_keys', [''], [{
constants.FEATURES_KEY:
make_features_dict({
'gender': ['m'],
'age': [10],
'interest': ['cars']
}),
constants.SLICE_KEY_TYPES_KEY: [(
('age', '10'),
('interest', 'cars'),
)]
}, {
constants.FEATURES_KEY:
make_features_dict({
'gender': ['f'],
'age': [12],
'interest': ['cars']
}),
constants.SLICE_KEY_TYPES_KEY: [(
('age', '12'),
('interest', 'cars'),
)]
}], [slicer.SingleSliceSpec(columns=['gender'])], [[
(
('age', '10'),
('interest', 'cars'),
),
(('gender', 'm'),),
], [
(
('age', '12'),
('interest', 'cars'),
),
(('gender', 'f'),),
]]),
)
def testSliceKeys(self, model_names, extracts, slice_specs, expected_slices):
eval_config = config.EvalConfig(
Expand Down
53 changes: 53 additions & 0 deletions tensorflow_model_analysis/proto/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,59 @@ message SlicingSpec {
// converted to ints and floats respectively and will be compared against both
// the string versions and int or float versions of the associated features.
map<string, string> feature_values = 2;

// This config is an alternative to the config above.
// It must have the pattern:
// "SELECT STRUCT({feature_name} [AS {slice_key}])
// [FROM example.feature_name [, example.feature_name, ... ]
// [WHERE ... ]]"
//
// The “example.feature_name” inside the FROM statement is used to flatten the
// repeated fields. For non-repeated fields, you can directly write the config
// as follows:
// “SELECT STRUCT(non_repeated_feature_a, non_repeated_feature_b)”.
//
// When executing, this SQL expression will be further wrapped as:
// “SELECT ARRAY({slice_keys_sql}) as slices FROM Examples as example”.
//
// The resulting output of the query will have the same number of rows as the
// input dataset. Each row will have only one column, named "slices", whose
// value is a list. Each element in that list will be a list of
// ('key', 'value') tuples representing a slice. For example, a single row
// could be:
// [[('gender', 'male'), ('country', 'USA')], [('zip_code', '123456')]]
//
// In the user's SQL statement, "example" is a keyword that binds to each
// input "row". The semantics of this variable will depend on the decoding of
// the input data to the Arrow representation (e.g., for tf.Example, each key
// is decoded to a separate column). Thus, structured data can be readily
// accessed by iterating/unnesting the fields of the "example" variable.
//
// Example 1:
// slice_keys_sql="SELECT STRUCT(gender) FROM example.gender"
// - This is equivalent to the config: feature_keys=[gender]
// - the slice key and value will be: (gender, {gender_value})
//
// Example 2:
// slice_keys_sql =
// "SELECT STRUCT(gender, country)
// FROM example.gender, example.country
// WHERE country = 'USA'"
// - This is equivalent to the config:
// feature_keys=[gender], feature_values={country:'USA'}
// - the slice key and value will be:
// (gender_x_country, {gender_value}_x_USA)
//
// Example 3 (background positive subgroup negative):
// slice_keys_sql=
// "SELECT STRUCT('male' as bpsn)
// FROM example
// WHERE ('male' not in UNNEST(example.gender) and 1 in
// UNNEST(example.label)) or
// ('male' in UNNEST(example.gender) and 0 in
// UNNEST(example.label))"
// - the slice key and value will be: (bpsn, male)
string slice_keys_sql = 3;
}

// Cross slicing specification.
Expand Down

0 comments on commit 47e6297

Please sign in to comment.