Skip to content

Commit

Permalink
fix bug in logging train AP and reduce # batches for which train AP i…
Browse files Browse the repository at this point in the history
…s calculated to limit memory usage
  • Loading branch information
tlpss committed Aug 31, 2023
1 parent 918a60a commit 8085dbb
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions keypoint_detection/models/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,7 @@ def shared_step(self, batch, batch_idx, include_visualization_data_in_result_dic

def training_step(self, train_batch, batch_idx):
log_images = batch_idx == 0 and self.current_epoch > 0
should_log_ap = (
self.is_ap_epoch()
) # and batch_idx < 20 # limit AP calculation to first 20 batches to save time
should_log_ap = self.is_ap_epoch() and batch_idx < 20 # limit AP calculation to first 20 batches to save time
include_vis_data = log_images or should_log_ap

result_dict = self.shared_step(
Expand Down Expand Up @@ -328,7 +326,7 @@ def validation_step(self, val_batch, batch_idx):
if self.is_ap_epoch():
self.update_ap_metrics(result_dict, self.ap_validation_metrics)

log_images = batch_idx == 0 and self.current_epoch > 0
log_images = batch_idx == 0 and self.current_epoch > 0 and self.is_ap_epoch()
if log_images:
image_grids = self.visualize_predictions_channels(result_dict)
self.log_image_grids(image_grids, mode="validation")
Expand All @@ -350,7 +348,14 @@ def test_step(self, test_batch, batch_idx):

def log_and_reset_mean_ap(self, mode: str):
mean_ap_per_threshold = torch.zeros(len(self.maximal_gt_keypoint_pixel_distances))
metrics = self.ap_test_metrics if mode == "test" else self.ap_validation_metrics
if mode == "train":
metrics = self.ap_training_metrics
elif mode == "validation":
metrics = self.ap_validation_metrics
elif mode == "test":
metrics = self.ap_test_metrics
else:
raise ValueError(f"mode {mode} not recognized")

# calculate APs for each channel and each threshold distance, and log them
print(f" # {mode} metrics:")
Expand Down

0 comments on commit 8085dbb

Please sign in to comment.