Skip to content

Commit

Permalink
supports precision recall curve
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Nov 10, 2017
1 parent 1620342 commit 8f5387c
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 2 deletions.
1 change: 1 addition & 0 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
writer.add_text('markdown Text', '''a|b\n-|-\nc|d''', n_iter)
for name, param in resnet18.named_parameters():
writer.add_histogram(name, param, n_iter)
writer.add_pr_curve('xoxo', np.random.randint(2, size=100), np.random.rand(100), n_iter) #needs tensorboard 0.4RC or later

# export scalar data to JSON for external processing
writer.export_scalars_to_json("./all_scalars.json")
Expand Down
25 changes: 25 additions & 0 deletions tensorboardX/src/plugin_pr_curve.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package tensorboard;

// Plugin metadata serialized into the summary metadata of a PR-curve summary
// (written under the plugin name 'pr_curves').
message PrCurvePluginData {
  // Version `0` is the only supported version.
  int32 version = 1;

  // Number of thresholds the accompanying curve data was computed with.
  uint32 num_thresholds = 2;
}
76 changes: 76 additions & 0 deletions tensorboardX/src/plugin_pr_curve_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 44 additions & 1 deletion tensorboardX/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from .src.summary_pb2 import SummaryMetadata
from .src.tensor_pb2 import TensorProto
from .src.tensor_shape_pb2 import TensorShapeProto
from .src.plugin_pr_curve_pb2 import PrCurvePluginData
from .x2num import makenp

_INVALID_TAG_CHARACTERS = _re.compile(r'[^-/\w\.]')
Expand Down Expand Up @@ -208,4 +209,46 @@ def text(tag, text):
smd = SummaryMetadata(plugin_data=PluginData)
tensor = TensorProto(dtype='DT_STRING', string_val=[text.encode(encoding='utf_8')], tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]))
return Summary(value=[Summary.Value(node_name=tag, metadata=smd, tensor=tensor)])


def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Build a ``Summary`` protobuf holding precision-recall curve data.

    Args:
        tag: Tag under which the summary value is stored.
        labels: Binary ground-truth labels, passed on to ``compute_curve``.
        predictions: Predicted scores, passed on to ``compute_curve``.
        num_thresholds: Number of thresholds for the curve; capped at 127.
        weights: Optional weights, passed on to ``compute_curve``.

    Returns:
        A ``Summary`` with a single value whose tensor is the flattened
        curve data and whose metadata marks it for the 'pr_curves' plugin.
    """
    # Values above 127 were observed to break the protobuf, so cap the count.
    num_thresholds = min(num_thresholds, 127)

    curve = compute_curve(
        labels, predictions, num_thresholds=num_thresholds, weights=weights)

    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds).SerializeToString()
    metadata = SummaryMetadata(plugin_data=[
        SummaryMetadata.PluginData(
            plugin_name='pr_curves', content=plugin_content),
    ])

    n_rows, n_cols = curve.shape[0], curve.shape[1]
    shape = TensorShapeProto(dim=[
        TensorShapeProto.Dim(size=n_rows),
        TensorShapeProto.Dim(size=n_cols),
    ])
    tensor = TensorProto(
        dtype='DT_FLOAT',
        float_val=curve.reshape(-1).tolist(),
        tensor_shape=shape)

    value = Summary.Value(tag=tag, metadata=metadata, tensor=tensor)
    return Summary(value=[value])

# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute the raw data points of a precision-recall curve.

    Predictions are bucketed into ``num_thresholds`` bins and the counts are
    accumulated from the highest bucket down, so entry ``i`` of every row
    describes the classifier at the threshold of bucket ``i``.

    Args:
        labels: Array-like of binary ground-truth labels (1 = positive).
        predictions: Array-like of scores, expected to lie in [0, 1].
            Scores whose bucket falls outside the histogram range are
            ignored by ``np.histogram``.
        num_thresholds: Number of buckets/thresholds. Defaults to 127 when
            ``None``.
        weights: Optional scalar or per-element weights; defaults to 1.0.

    Returns:
        A ``(6, num_thresholds)`` array whose rows are, in order:
        true positives, false positives, true negatives, false negatives,
        precision and recall.
    """
    # Lower bound for denominators so empty buckets yield 0 instead of NaN.
    _MINIMUM_COUNT = 1e-7

    if num_thresholds is None:
        # Same default the summary/writer entry points in this package use.
        num_thresholds = 127

    if weights is None:
        weights = 1.0

    # Accept plain sequences as well as arrays.
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the same dtype.
    float_labels = labels.astype(np.float64)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights)
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights)

    # Obtain the reverse cumulative sum: everything at or above a bucket
    # counts as predicted positive for that bucket's threshold.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    # Index 0 holds the totals, so the remainder is what was predicted negative.
    tn = fp[0] - fp
    fn = tp[0] - tp
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
    return np.stack((tp, fp, tn, fn, precision, recall))
17 changes: 16 additions & 1 deletion tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .src import summary_pb2
from .src import graph_pb2
from .event_file_writer import EventFileWriter
from .summary import scalar, histogram, image, audio, text
from .summary import scalar, histogram, image, audio, text, pr_curve
from .graph import graph
from .graph_onnx import gg
from .embedding import make_mat, make_sprite, make_tsv, append_pbtxt
Expand Down Expand Up @@ -442,6 +442,21 @@ def add_embedding(self, mat, metadata=None, label_img=None, global_step=None, ta
#new funcion to append to the config file a new embedding
append_pbtxt(metadata, label_img, self.file_writer.get_logdir(), str(global_step).zfill(5), tag)


def add_pr_curve(self, tag, labels, predictions, global_step=None, num_thresholds=127, weights=None):
    """Add a precision-recall curve to the event file.

    Args:
        tag (string): Data identifier
        labels (torch.Tensor): Ground truth data. Binary label for each element.
        predictions (torch.Tensor): The probability that an element is
            classified as true. Values should be in [0, 1].
        global_step (int): Global step value to record
        num_thresholds (int): Number of thresholds used to draw the curve.
        weights: Optional weights passed through to the curve computation.
    """
    from .x2num import makenp

    # Convert both inputs to numpy before building the summary.
    labels_np = makenp(labels)
    predictions_np = makenp(predictions)
    summary = pr_curve(tag, labels_np, predictions_np, num_thresholds, weights)
    self.file_writer.add_summary(summary, global_step)

def close(self):
if self.file_writer is None:
return # ignore double close
Expand Down

0 comments on commit 8f5387c

Please sign in to comment.