forked from google-research/google-research
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Codebase for reproducing results in "Mitigating Bias in Calibration E…
…rror" PiperOrigin-RevId: 352021384
- Loading branch information
1 parent
793615e
commit f944d3c
Showing
28 changed files
with
3,830 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# This is the list of authors for copyright purposes. | ||
# | ||
# This does not necessarily list everyone who has contributed code, since in | ||
# some cases, their employer may be the copyright holder. To see the full list | ||
# of contributors, see the revision history in source control. | ||
Google LLC | ||
Rebecca Roelofs | ||
Nicholas Cain | ||
Michael C. Mozer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Mitigating bias in calibration error estimation | ||
|
||
Code for the paper: "[Mitigating Bias in Calibration Error Estimation](https://arxiv.org/abs/2012.08668)" | ||
|
||
|
||
## Setup | ||
|
||
### Install dependencies | ||
|
||
```bash | ||
virtualenv -p python3 env3 | ||
source env3/bin/activate | ||
pip install -r caltrain/requirements.txt | ||
``` | ||
|
||
### Download data | ||
|
||
```bash | ||
source env3/bin/activate | ||
DATA_DIR='./caltrain/data' # This is the default value if omitted below | ||
python -m caltrain.download_data --data_dir=${DATA_DIR} | ||
``` | ||
|
||
## Running | ||
|
||
### Setup | ||
|
||
Plots that are generated by each script are saved in a command-line configurable variable `plot_dir=./caltrain/plots` by default. To speed up computation, some values have been precomputed and cached. Each plotting script is configured to read these data from a command-line configurable variable `data_dir='./caltrain/data` by default. Generating Figure 3 requires downloading logit data from `https://github.com/markus93/NN_calibration/tree/master/logits` into the `data_dir` as well. | ||
|
||
The following environment variables should be defined: | ||
```bash | ||
export MPLBACKEND=Agg | ||
``` | ||
|
||
### Generating figures | ||
```bash | ||
source env3/bin/activate | ||
DATA_DIR='./caltrain/data' # This is the default value if omitted below | ||
PLOT_DIR='./caltrain/plots' # This is the default value if omitted below | ||
|
||
# Figure 1a left panel | ||
python -m caltrain.plot_intro_reliability_diagram --plot_dir=${PLOT_DIR} | ||
# Figure 1a right panel | ||
python -m caltrain.plot_intro_ece_distribution --plot_dir=${PLOT_DIR} | ||
# Figure 1b (both panels) | ||
python -m caltrain.plot_tce_assumptions --plot_dir=${PLOT_DIR} | ||
# Figure 2, Figure 7, Figure 8 | ||
python -m caltrain.plot_bias_heat_map --data_dir=${DATA_DIR} --plot_dir=${PLOT_DIR} | ||
# Figure 3 | ||
python -m caltrain.plot_glm_beta_eece_sece --data_dir=${DATA_DIR} --plot_dir=${PLOT_DIR} | ||
# Figure 4 | ||
python -m caltrain.plot_calibration_errors --data_dir=${DATA_DIR} --plot_dir=${PLOT_DIR} | ||
# Figure 5 | ||
python -m caltrain.plot_ece_vs_tce --data_dir=${DATA_DIR} --plot_dir=${PLOT_DIR} | ||
``` | ||
|
||
## Citing this work | ||
|
||
``` | ||
@article{roelofs2020mitigating, | ||
title={Mitigating bias in calibration error estimation}, | ||
author={Roelofs, Rebecca and Cain, Nicholas and Shlens, Jonathon and Mozer, Michael C}, | ||
journal={arXiv preprint arXiv:2012.08668}, | ||
year={2020} | ||
} | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# coding=utf-8 | ||
# Copyright 2021 The Google Research Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Caltrain library.""" | ||
import seaborn as sns | ||
|
||
SAVEFIG_FORMAT = 'pdf' | ||
|
||
TRUE_DATASETS = [ | ||
'logistic', 'logistic_beta', 'polynomial', 'flip_polynomial', | ||
'two_param_polynomial', 'two_param_flip_polynomial', 'logistic_log_odds', | ||
'logistic_two_param_flip_polynomial' | ||
] | ||
|
||
dataset_mlmodel_imageset_map = { | ||
'resnet110_c10': ('resnet', 'c10'), | ||
'densenet40_c10': ('densenet', 'c10'), | ||
'resnet110_SD_c10': ('resnet_SD', 'c10'), | ||
'resnet_wide32_c10': ('wide_resnet', 'c10'), | ||
'resnet110_c100': ('resnet', 'c100'), | ||
'densenet40_c100': ('densenet', 'c100'), | ||
'resnet152_imgnet': ('resnet', 'imgnet'), | ||
'densenet161_imgnet': ('densenet', 'imgnet'), | ||
'resnet110_SD_c100': ('resnet_SD', 'c100'), | ||
'resnet_wide32_c100': ('wide_resnet', 'c100'), | ||
} | ||
|
||
mlmodel_linestyle_map = { | ||
'resnet': '-', | ||
'densenet': '--', | ||
'resnet_SD': '-.', | ||
'wide_resnet': ':' | ||
} | ||
|
||
mlmodel_marker_map = { | ||
'resnet': '*', | ||
'densenet': '^', | ||
'resnet_SD': 'o', | ||
'wide_resnet': 'd' | ||
} | ||
|
||
clrs = sns.color_palette('husl', n_colors=3) | ||
|
||
imageset_color_map = { | ||
'c10': clrs[0], | ||
'c100': clrs[1], | ||
'imgnet': clrs[2], | ||
} | ||
|
||
cetype_color_map = { | ||
'em_ece_bin': 'blue', | ||
'ew_ece_bin': 'navy', | ||
'em_ece_sweep': 'red', | ||
'ew_ece_sweep': 'darkred' | ||
} | ||
|
||
ce_type_paper_name_map = { | ||
'em_ece_bin': 'EM', | ||
'ew_ece_bin': 'EW', | ||
'em_ece_sweep': 'EMsweep', | ||
'ew_ece_sweep': 'EWsweep' | ||
} | ||
|
||
ml_model_name_map = { | ||
'resnet110_c10': 'ResNet', | ||
'densenet40_c10': 'DenseNet', | ||
'resnet110_SD_c10': 'ResNet_SD', | ||
'resnet_wide32_c10': 'Wide_ResNet', | ||
'resnet110_c100': 'ResNet', | ||
'densenet40_c100': 'DenseNet', | ||
'resnet152_imgnet': 'ResNet', | ||
'densenet161_imgnet': 'DenseNet', | ||
'resnet110_SD_c100': 'ResNet_SD', | ||
'resnet_wide32_c100': 'Wide_ResNet' | ||
} | ||
|
||
ml_data_name_map = { | ||
'imgnet': 'ImageNet', | ||
'c10': 'CIFAR-10', | ||
'c100': 'CIFAR-100' | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# coding=utf-8 | ||
# Copyright 2021 The Google Research Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Binning methods.""" | ||
import abc | ||
import numpy as np | ||
|
||
|
||
class BinMethod(abc.ABC): | ||
"""General interface for specifying binning method.""" | ||
|
||
def __init__(self, num_bins): | ||
self.num_bins = num_bins | ||
|
||
@abc.abstractmethod | ||
def compute_bin_indices(self, scores): | ||
"""Assign a bin index for each score. | ||
Args: | ||
scores: np.array of shape (num_examples, num_classes) containing the | ||
model's confidence scores | ||
Returns: | ||
bin_indices: np.array of shape (num_examples, num_classes) containing the | ||
bin assignment for each score | ||
""" | ||
pass | ||
|
||
|
||
class BinEqualWidth(BinMethod): | ||
"""Divide the scores into equal-width bins.""" | ||
|
||
def compute_bin_indices(self, scores): | ||
"""Assign a bin index for each score assuming equal width bins. | ||
Args: | ||
scores: np.array of shape (num_examples, num_classes) containing the | ||
model's confidence scores | ||
Returns: | ||
bin_indices: np.array of shape (num_examples, num_classes) containing the | ||
bin assignment for each score | ||
""" | ||
edges = np.linspace(0.0, 1.0, self.num_bins + 1) | ||
bin_indices = np.digitize(scores, edges, right=False) | ||
# np.digitze uses one-indexed bins, switch to using 0-indexed | ||
bin_indices = bin_indices - 1 | ||
# Put examples with score equal to 1.0 in the last bin. | ||
bin_indices = np.where(scores == 1.0, self.num_bins - 1, bin_indices) | ||
return bin_indices | ||
|
||
|
||
class BinEqualExamples(BinMethod): | ||
"""Divide the scores into bins with equal number of examples.""" | ||
|
||
def compute_bin_indices(self, scores): | ||
"""Assign a bin index for each score assumes equal num examples per bin. | ||
Args: | ||
scores: np.ndarray of shape [N, K] containing the model's confidence | ||
Returns: | ||
bin_indices: np.ndarray of shape [N, K] containing the bin assignment for | ||
each score | ||
""" | ||
num_examples = scores.shape[0] | ||
num_classes = scores.shape[1] | ||
|
||
bin_indices = np.zeros((num_examples, num_classes), dtype=int) | ||
for k in range(num_classes): | ||
sort_ix = np.argsort(scores[:, k]) | ||
bin_indices[:, k][sort_ix] = np.minimum( | ||
self.num_bins - 1, | ||
np.floor((np.arange(num_examples) / num_examples) * | ||
self.num_bins)).astype(int) | ||
return bin_indices |
Oops, something went wrong.