Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 465127497
  • Loading branch information
rui1996 authored and tensorflower-gardener committed Aug 3, 2022
1 parent 02ac43f commit c1f02ec
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions official/projects/yt8m/eval_utils/eval_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def top_k_by_class(predictions, labels, k=20):
Args:
predictions: A numpy matrix containing the outputs of the model. Dimensions
are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels. Dimensions are
'batch' x 'num_classes'.
k: the top k non-zero entries to preserve in each prediction.
Returns:
Expand Down Expand Up @@ -143,9 +143,10 @@ def top_k_triplets(predictions, labels, k=20):
Args:
predictions: A numpy matrix containing the outputs of the model. Dimensions
are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels.
Dimensions are 'batch' x 'num_classes'.
labels: A numpy matrix containing the ground truth labels. Dimensions are
'batch' x 'num_classes'.
k: The number top predictions to pick.
Returns:
a sparse list of tuples in (prediction, class) format.
"""
Expand Down Expand Up @@ -175,7 +176,7 @@ def __init__(self, num_class, top_k, top_n):
self.sum_hit_at_one = 0.0
self.sum_perr = 0.0
self.map_calculator = map_calculator.MeanAveragePrecisionCalculator(
num_class, top_n=top_n)
num_class, filter_empty_classes=False, top_n=top_n)
self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator()
self.top_k = top_k
self.num_examples = 0
Expand Down Expand Up @@ -217,9 +218,13 @@ def accumulate(self, predictions, labels):

return {"hit_at_one": mean_hit_at_one, "perr": mean_perr}

def get(self):
def get(self, return_per_class_ap=False):
"""Calculate the evaluation metrics for the whole epoch.
Args:
return_per_class_ap: a bool variable to determine whether return the
detailed class-wise ap for more detailed analysis. Default is `False`.
Raises:
ValueError: If no examples were accumulated.
Expand All @@ -243,6 +248,10 @@ def get(self):
"map": mean_ap,
"gap": gap
}

if return_per_class_ap:
epoch_info_dict["per_class_ap"] = aps

return epoch_info_dict

def clear(self):
Expand Down

0 comments on commit c1f02ec

Please sign in to comment.