Skip to content

Commit

Permalink
Handle corner cases in auto slicer
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 327384122
  • Loading branch information
paulgc authored and tf-model-analysis-team committed Aug 19, 2020
1 parent e480471 commit 9d68707
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 10 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -124,7 +124,8 @@ def process(self, element: types.Extracts) -> List[types.Extracts]:
element_copy = copy.deepcopy(element)
features = util.get_features_from_extracts(element_copy)
for feature_name, boundaries in self._bucket_boundaries.items():
if feature_name in features:
if (feature_name in features and features[feature_name] is not None and
features[feature_name].size > 0):
transformed_values = []
for value in features[feature_name]:
transformed_values.append(_bin_value(value, boundaries))
Expand Down
45 changes: 37 additions & 8 deletions tensorflow_model_analysis/slicer/auto_slicing_util.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,6 +24,7 @@
import math
import operator
from typing import Dict, List, NamedTuple, Optional, Text, Tuple
from absl import logging
import numpy as np
import pandas as pd
from scipy import stats
Expand Down Expand Up @@ -100,14 +101,32 @@ def _is_significant_slice(slice_metric: float, slice_std_dev: float,
comparison_type: Text,
alpha: float) -> Tuple[bool, float]:
"""Perform statistical significance testing."""
_, p_value_two_sided = stats.ttest_ind_from_stats(
slice_metric,
slice_std_dev,
slice_weight,
base_metric,
base_std_dev,
base_weight,
equal_var=False)
assert base_std_dev > 0, ('base_std_dev must be positive, but got '
'{}.'.format(base_std_dev))
assert slice_std_dev > 0, ('slice_std_dev must be positive, but got '
'{}.'.format(slice_std_dev))
assert base_weight > 1, ('base_weight must be greater than 1, but got '
'{}.'.format(base_weight))
assert slice_weight > 1, ('slice_weight must be greater than 1, but got '
'{}.'.format(slice_weight))

try:
_, p_value_two_sided = stats.ttest_ind_from_stats(
slice_metric,
slice_std_dev,
slice_weight,
base_metric,
base_std_dev,
base_weight,
equal_var=False)
except ZeroDivisionError:
raise ZeroDivisionError(
'invalid ttest for params: slice_metric={}, '
'slice_std_dev={}, slice_weight={}, '
'base_metric={}, base_std_dev={}, base_weight={}, '.format(
slice_metric, slice_std_dev, slice_weight, base_metric,
base_std_dev, base_weight))

metric_diff = slice_metric - base_metric
one_sided_p_value = _two_sided_to_one_sided_pvalue(
p_value_two_sided, metric_diff, comparison_type=comparison_type)
Expand Down Expand Up @@ -292,6 +311,16 @@ def partition_slices(
# Prune non-interesting slices.
if np.isnan(slice_metrics_dict[metric_key].unsampled_value):
continue
if slice_metrics_dict[metric_key].sample_standard_deviation == 0:
logging.warning('Ignoring slice: %s with standard deviation: %s ',
slice_key,
slice_metrics_dict[metric_key].sample_standard_deviation)
continue
# TODO(pachristopher): Should we use weighted example count?
if slice_metrics_dict['example_count'].unsampled_value <= 1:
logging.warning('Ignoring slice: %s with example count: %s ', slice_key,
slice_metrics_dict['example_count'].unsampled_value)
continue
# Only consider statistically significant slices.
is_significant, p_value = _is_significant_slice(
slice_metrics_dict[metric_key].unsampled_value,
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_model_analysis/slicer/slice_accessor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, features_dict: Union[types.DictOfTensorValue,
self._features_dict = features_dict

def has_key(self, key: Text):
return key in self._features_dict
return key in self._features_dict and self._features_dict[key] is not None

def get(self, key: Text) -> List[Union[int, bytes, float]]:
"""Get the values of the feature with the given key.
Expand Down

0 comments on commit 9d68707

Please sign in to comment.