Skip to content

Commit

Permalink
Model Comparison Support in tensorboard_plugin_fairness_indicators
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 332568796
  • Loading branch information
jindalshivam09 authored and tf-model-analysis-team committed Sep 19, 2020
1 parent d496c91 commit 3f0dc47
Show file tree
Hide file tree
Showing 9 changed files with 1,010 additions and 771 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ export class FairnessMetricsTable extends PolymerElement {
*/
headerRow_: {
type: Array,
computed: 'populateHeaderRow_(metrics, evalName, evalNameCompare)'
computed:
'populateHeaderRow_(data, dataCompare, metrics, evalName, evalNameCompare)',
notify: true,
},

/**
Expand All @@ -139,13 +141,15 @@ export class FairnessMetricsTable extends PolymerElement {

/**
* Populate header row
* @param {!Array} data
* @param {!Array} dataCompare
* @param {!Array<string>} metrics
* @param {string} evalName
* @param {string} evalCompareName
* @return {!Array<string>}
* @private
*/
populateHeaderRow_(metrics, evalName, evalCompareName) {
populateHeaderRow_(data, dataCompare, metrics, evalName, evalCompareName) {
if (!metrics) {
return [];
}
Expand Down Expand Up @@ -258,6 +262,8 @@ export class FairnessMetricsTable extends PolymerElement {
return [[]];
}

this.headerRow_ = this.populateHeaderRow_(
data, dataCompare, metrics, evalName, evalNameCompare);
let tableRows = this.populateTableRows_(metrics, data, dataCompare);
return [this.headerRow_].concat(tableRows);
}
Expand Down Expand Up @@ -357,6 +363,10 @@ export class FairnessMetricsTable extends PolymerElement {
* @private
*/
getExampleCount_(rowNum, exampleCounts) {
if (!exampleCounts) {
return '';
}

// We skip the first row, since it is a header row which does not correspond
// to a slice.
let value = exampleCounts[parseFloat(rowNum) - 1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ template.innerHTML = `
#metrics-and-slice-selector {
height: 100%
}
#run-selector > paper-dropdown-menu {
margin-left: 16px;
}
.flex-row {
display: flex;
flex-direction: row;
Expand All @@ -50,6 +47,21 @@ template.innerHTML = `
width: 90%;
margin-right: 16px;
}
#div-to-compare {
margin-left: 16px;
width: 90%;
margin-right: 16px;
}
#drop-down-to-compare {
width: 100%;
}
#model-comparison {
max-width: 290px;
overflow: hidden;
text-overflow: ellipsis;
display: block;
margin-top: 5px;
}
</style>
<div class="flex-row">
Expand All @@ -66,6 +78,27 @@ template.innerHTML = `
</template>
</paper-listbox>
</paper-dropdown-menu>
<paper-item>
<paper-checkbox id="model-comparison" checked="{{modelComparisonEnabled_}}">
<span class="model-comparison" title$="Enable Model Comparison">
Enable Model Comparison
</span>
</paper-checkbox>
</paper-item>
<div id="div-to-compare" hidden$="[[!modelComparisonEnabled_]]">
<paper-dropdown-menu id="drop-down-to-compare"
label="Select evaluation run to compare:"
title$="[[selectedEvaluationRunToCompare]]">
<paper-listbox selected="{{selectedEvaluationRunToCompare}}" attr-for-selected="run"
class="dropdown-content" slot="dropdown-content" title$="">
<template is="dom-repeat" items="[[availableEvaluationRuns]]">
<paper-item run="[[item]]">
<span class="evaluation-run" title$="[[item]]">[[item]]</span>
</paper-item>
</template>
</paper-listbox>
</paper-dropdown-menu>
</div>
</paper-card>
<paper-card id="metrics-and-slice-selector">
<fairness-metric-and-slice-selector available-metrics="[[selectableMetrics_]]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import {template} from './fairness-nb-container-template.html.js';

import {SelectEventMixin} from '../../../../frontend/tfma-nb-event-mixin/tfma-nb-event-mixin.js';

import '@polymer/paper-checkbox/paper-checkbox.js';
import '@polymer/paper-item/paper-item.js';
import '@polymer/paper-card/paper-card.js';
import '@polymer/iron-flex-layout/iron-flex-layout.js';
import '../fairness-metrics-board/fairness-metrics-board.js';
Expand Down Expand Up @@ -149,6 +151,13 @@ export class FairnessNbContainer extends SelectEventMixin

/** @type {string} */
selectedEvaluationRun: {type: String, notify: true},

/** @type {string} */
selectedEvaluationRunToCompare: {type: String, notify: true},

/** @type {boolean} */
modelComparisonEnabled_:
{type: Boolean, notify: true, observer: 'onModelComparisonEnabled_'}
};
}

Expand All @@ -161,6 +170,8 @@ export class FairnessNbContainer extends SelectEventMixin
if (slicingMetrics) {
tfma.Data.flattenMetrics(slicingMetrics, 'metrics');
this.flattenSlicingMetrics_ = slicingMetrics;
} else {
this.flattenSlicingMetrics_ = [];
}
this.availableMetricsNames_ =
this.computeAvailableMetricsNames_(slicingMetrics);
Expand All @@ -175,6 +186,8 @@ export class FairnessNbContainer extends SelectEventMixin
if (slicingMetricsCompare) {
tfma.Data.flattenMetrics(slicingMetricsCompare, 'metrics');
this.flattenSlicingMetricsCompare_ = slicingMetricsCompare;
} else {
this.flattenSlicingMetricsCompare_ = [];
}
}

Expand Down Expand Up @@ -216,6 +229,10 @@ export class FairnessNbContainer extends SelectEventMixin
* @private
*/
updateSelectableMetrics_(availableMetricsNames_) {
if (!availableMetricsNames_) {
this.selectableMetrics_ = [];
this.selectedMetrics_ = [];
}
const thresholdedMetrics = new Set();
const otherMetrics = new Set();
availableMetricsNames_.forEach(metricName => {
Expand Down Expand Up @@ -244,6 +261,20 @@ export class FairnessNbContainer extends SelectEventMixin
hideRunSelector_(hideSelectEvalRunDropDown, availableEvaluationRuns) {
return hideSelectEvalRunDropDown || !availableEvaluationRuns.length;
}

/**
* Handler listening to any changes in model comparison check box.
* @param {boolean} modelComparisonEnabled
* @private
*/
onModelComparisonEnabled_(modelComparisonEnabled) {
// If model comparison is turned off, set slicing metric to empty array.
if (!modelComparisonEnabled) {
this.slicingMetricsCompare = [];
this.flattenSlicingMetricsCompare_ = [];
this.selectedEvaluationRunToCompare = '';
}
}
};

customElements.define('fairness-nb-container', FairnessNbContainer);
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,22 @@ suite('fairness-nb-container tests', () => {
};
setTimeout(fillData, 0);
});

test('testModelComparisonFlow', done => {
const checkCheckbox = () => {
let checkbox =
fairnessContainer.shadowRoot.querySelector('#model-comparison');
assert.equal(checkbox.checked, false);

let runSelector =
fairnessContainer.shadowRoot.querySelector('#div-to-compare');
assert.equal(runSelector.hidden, true);

checkbox.checked = true;
assert.equal(runSelector.hidden, false);
done();
};

setTimeout(checkCheckbox, 0);
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<!DOCTYPE html>
<!--
Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<meta charset="utf-8">
<html>
<head>
</head>
<body>
<fairness-tensorboard-container></fairness-tensorboard-container>
</body>
</html>
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/**
* Copyright 2019 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

(() => {
const createSliceMetrics = () => {
let result = {
'post_export_metrics/example_count':
{'doubleValue': Math.floor(Math.random() * 200 + 1000)},
'totalWeightedExamples':
{'doubleValue': Math.floor(Math.random() * 200 + 1000)},
};
const metrics = [
'accuracy',
'post_export_metrics/[email protected]',
'post_export_metrics/[email protected]',
'post_export_metrics/[email protected]',
'post_export_metrics/[email protected]',
'post_export_metrics/[email protected]',
'post_export_metrics/[email protected]',
];
metrics.forEach(metric => {
let value = Math.random() * 0.7 + 0.15;
result[metric] = {
'boundedValue': {
'lowerBound': value - 0.1,
'upperBound': value + 0.1,
'value': value,
'methodology': 'POISSON_BOOTSTRAP'
}
};
});
return result;
};

const SLICES = [
'Overall',
'Sex:Male',
'Sex:Female',
'Sex:Transgender',
'Race:asian',
'Race:latino',
'Race:black',
'Race:white',
'Religion:atheist',
'Religion:buddhist',
'Religion:christian',
'Religion:hindu',
'Religion:jewish',
];

const input1 = SLICES.map((slice) => {
return {
'slice': slice,
'sliceValue': slice.split(':')[1] || 'Overall',
'metrics': createSliceMetrics(),
};
});

const input2 = SLICES.map((slice) => {
return {
'slice': slice,
'sliceValue': slice.split(':')[1] || 'Overall',
'metrics': createSliceMetrics(),
};
});

const element =
document.getElementsByTagName('fairness-tensorboard-container')[0];
element.selectedEvaluationRun_ = '1';
element.selectedEvaluationRunCompare_ = '2';
element.evaluationRuns_ = ['1', '2', '3'];
element.hideSelectEvalRunDropDown_ = false;

element.slicingMetrics_ = input1;
element.slicingMetricsCompare_ = input2;

let container = element.shadowRoot.querySelector('fairness-nb-container');
container.shadowRoot.querySelector('#model-comparison').checked = true;
container.evalName = 'base';
container.evalNameCompare = 'compare';
})();
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
<div class="container">
<fairness-nb-container available-evaluation-runs="[[evaluationRuns_]]"
selected-evaluation-run="{{selectedEvaluationRun_}}"
selected-evaluation-run-to-compare="{{selectedEvaluationRunCompare_}}"
slicing-metrics="[[slicingMetrics_]]"
slicing-metrics-compare="[[slicingMetricsCompare_]]"
hide-select-eval-run-drop-down="[[hideSelectEvalRunDropDown_]]">
</fairness-nb-container>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ export class FairnessTensorboardContainer extends SelectEventMixin
*/
selectedEvaluationRun_: {type: String, observer: 'runChanged_'},

/**
* Evaluation run selected by the user.
* @private {string}
*/
selectedEvaluationRunCompare_:
{type: String, observer: 'compareRunChanged_'},

/**
* The slicing metrics evaluation result. It's a list of dict with key
* "slice" and "metrics". For example:
Expand All @@ -130,17 +137,41 @@ export class FairnessTensorboardContainer extends SelectEventMixin
* @private {!Array<!Object>}
*/
slicingMetrics_: {type: Array, notify: true, value: []},

/**
* @private {!Array<!Object>}
*/
slicingMetricsCompare_: {type: Array, notify: true, value: []},
};
}

runChanged_(run) {
this.slicingMetrics_ = [];
fetch(`${GET_EVAL_RESULTS_ENDPOINT}?run=${run}`)
.then(res => res.json())
.then(slicingMetrics => {
this.slicingMetrics_ = slicingMetrics;
});
}

compareRunChanged_(run) {
this.slicingMetricsCompare_ = [];
fetch(`${GET_EVAL_RESULTS_ENDPOINT}?run=${run}`)
.then(res => res.json())
.then(slicingMetricsCompare => {
this.slicingMetricsCompare_ = slicingMetricsCompare;
let nbContainer = this.shadowRoot.querySelector('fairness-nb-container');
nbContainer.evalName = 'base';
if (this.slicingMetricsCompare_) {
nbContainer.evalNameCompare = 'compare';
} else {
nbContainer.evalNameCompare = '';
}
});


}

evaluationOutputPathChanged_(path) {
fetch(
`${GET_EVAL_RESULTS_ENDPOINT_FROM_URL}?evaluation_output_path=${path}`)
Expand Down
Loading

0 comments on commit 3f0dc47

Please sign in to comment.