Skip to content

Commit

Permalink
Prepare for multi-class and multi-label confusion matrix.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 276198840
  • Loading branch information
tf-model-analysis-team committed Oct 23, 2019
1 parent 653e4a0 commit aab632b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tensorflow_model_analysis/frontend/lib/constants.js
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ const PlotTypes = {
GAIN_CHART: 'gainChart',
MACRO_PRECISION_RECALL_CURVE: 'macroPrecisionRecallCurve',
MICRO_PRECISION_RECALL_CURVE: 'microPrecisionRecallCurve',
MULTI_CLASS_CONFUSION_MATRIX: 'multiClassConfusionMatrix',
MULTI_LABEL_CONFUSION_MATRIX: 'multiLabelConfusionMatrix',
PREDICTION_DISTRIBUTION: 'predictionDistribution',
PRECISION_RECALL_CURVE: 'precisionRecallCurve',
RESIDUAL_PLOT: 'residualPlot',
Expand All @@ -92,6 +94,8 @@ const PlotDataFieldNames = {
CONFUSION_MATRICES: 'matrices',
MACRO_PRECISION_RECALL_CURVE_DATA: 'macroValuesByThreshold',
MICRO_PRECISION_RECALL_CURVE_DATA: 'microValuesByThreshold',
MULTI_CLASS_CONFUSION_MATRIX_DATA: 'multiClassConfusionMatrixAtThresholds',
MULTI_LABEL_CONFUSION_MATRIX_DATA: 'multiLabelConfusionMatrixAtThresholds',
PRECISION_RECALL_CURVE_DATA: 'binaryClassificationByThreshold',
WEIGHTED_PRECISION_RECALL_CURVE_DATA: 'weightedValuesByThreshold',
};
Expand Down Expand Up @@ -199,6 +203,12 @@ goog.exportSymbol(
goog.exportSymbol(
'tfma.PlotDataFieldNames.MICRO_PRECISION_RECALL_CURVE_DATA',
PlotDataFieldNames.MICRO_PRECISION_RECALL_CURVE_DATA);
goog.exportSymbol(
'tfma.PlotDataFieldNames.MULTI_CLASS_CONFUSION_MATRIX_DATA',
PlotDataFieldNames.MULTI_CLASS_CONFUSION_MATRIX_DATA);
goog.exportSymbol(
'tfma.PlotDataFieldNames.MULTI_LABEL_CONFUSION_MATRIX_DATA',
PlotDataFieldNames.MULTI_LABEL_CONFUSION_MATRIX_DATA);
goog.exportSymbol(
'tfma.PlotDataFieldNames.PRECISION_RECALL_CURVE_DATA',
PlotDataFieldNames.PRECISION_RECALL_CURVE_DATA);
Expand All @@ -225,6 +235,12 @@ goog.exportSymbol(
goog.exportSymbol(
'tfma.PlotTypes.MICRO_PRECISION_RECALL_CURVE',
PlotTypes.MICRO_PRECISION_RECALL_CURVE);
goog.exportSymbol(
'tfma.PlotTypes.MULTI_CLASS_CONFUSION_MATRIX',
PlotTypes.MULTI_CLASS_CONFUSION_MATRIX);
goog.exportSymbol(
'tfma.PlotTypes.MULTI_LABEL_CONFUSION_MATRIX',
PlotTypes.MULTI_LABEL_CONFUSION_MATRIX);
goog.exportSymbol(
'tfma.PlotTypes.PREDICTION_DISTRIBUTION',
PlotTypes.PREDICTION_DISTRIBUTION);
Expand Down
4 changes: 4 additions & 0 deletions tensorflow_model_analysis/frontend/lib/data.js
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,10 @@ function getAvailablePlotTypes(plotMap) {
},
{
type: Constants.PlotTypes.WEIGHTED_PRECISION_RECALL_CURVE,
}, {
type: Constants.PlotTypes.MULTI_CLASS_CONFUSION_MATRIX,
}, {
type: Constants.PlotTypes.MULTI_LABEL_CONFUSION_MATRIX,
}
];

Expand Down
4 changes: 4 additions & 0 deletions tensorflow_model_analysis/frontend/lib/externs.js
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ tfma.PlotDataFieldNames = {
CONFUSION_MATRICES: '',
MACRO_PRECISION_RECALL_CURVE_DATA: '',
MICRO_PRECISION_RECALL_CURVE_DATA: '',
MULTI_CLASS_CONFUSION_MATRIX_DATA: '',
MULTI_LABEL_CONFUSION_MATRIX_DATA: '',
PRECISION_RECALL_CURVE_DATA: '',
WEIGHTED_PRECISION_RECALL_CURVE_DATA: '',
};
Expand Down Expand Up @@ -342,6 +344,8 @@ tfma.PlotTypes = {
PREDICTION_DISTRIBUTION: '',
MACRO_PRECISION_RECALL_CURVE: '',
MICRO_PRECISION_RECALL_CURVE: '',
MULTI_CLASS_CONFUSION_MATRIX: '',
MULTI_LABEL_CONFUSION_MATRIX: '',
PRECISION_RECALL_CURVE: '',
RESIDUAL_PLOT: '',
ROC_CURVE: '',
Expand Down

0 comments on commit aab632b

Please sign in to comment.