Skip to content

Commit d7037c3

Browse files
adamfarquharlgvaz
andauthored
Updated wandb integration and tutorial. (airctic#836)
* Updated wandb integration and tutorial. * Reformatted with black. * soft imports wandb_image Co-authored-by: lgvaz <[email protected]>
1 parent 19498f5 commit d7037c3

File tree

5 files changed

+195
-351
lines changed

5 files changed

+195
-351
lines changed

.gitignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,13 @@ checklink/cookies.txt
143143
logs/
144144
lightning_logs/
145145
examples/wandb
146+
notebooks/wandb
146147

147148
archives/
148149

149150
# mkdocs documentation
150151
/docs/site
151152
/docs/sources
152153

153-
*.pth
154+
*.pth
155+
notebooks/wandb/latest-run

icevision/visualize/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from icevision.visualize.utils import *
22
from icevision.visualize.draw_data import *
33
from icevision.visualize.show_data import *
4-
from icevision.visualize.wandb_img import *
4+
5+
from icevision.soft_dependencies import SoftDependencies
6+
7+
if SoftDependencies.wandb:
8+
from icevision.visualize.wandb_img import *

icevision/visualize/wandb_img.py

+77-71
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,113 @@
1-
__all__ = [
2-
"wandb_img_preds",
3-
]
1+
__all__ = ["wandb_img_preds", "wandb_image"]
42

5-
from icevision.imports import *
6-
from icevision.data import *
7-
from icevision.core import *
83

4+
from typing import List
95

10-
def bbox_wandb(bbox: BBox, label: int, class_id_to_label, pred_score=None):
11-
"""Creates a wandb compatible dictionary with bbox, label and score"""
6+
import wandb
7+
from icevision import BaseRecord, BBox
8+
from icevision.data.prediction import Prediction
9+
10+
11+
def wandb_img_preds(
12+
preds: List[Prediction], add_ground_truth: bool = False
13+
) -> List[wandb.Image]:
14+
return [wandb_image(pred, add_ground_truth=add_ground_truth) for pred in preds]
15+
16+
17+
def bbox_wandb(bbox: BBox, label_id: int, label_name: str, score=None) -> dict:
18+
"""Return a wandb compatible dictionary with bbox, label and score"""
1219
xmin, ymin, xmax, ymax = map(int, bbox.xyxy)
1320

1421
box_data = {
1522
"position": {"minX": xmin, "maxX": xmax, "minY": ymin, "maxY": ymax},
16-
"class_id": label,
23+
"class_id": int(label_id),
1724
"domain": "pixel",
1825
}
1926

20-
if pred_score:
21-
score = int(pred_score * 100)
22-
box_caption = f"{class_id_to_label[label]} ({score}%)"
27+
if score:
28+
score = int(score * 100)
29+
box_caption = f"{label_name} ({score}%)"
2330
box_data["score"] = score
2431
else:
25-
box_caption = f"{class_id_to_label[label]}"
32+
box_caption = label_name
2633

2734
box_data["box_caption"] = box_caption
2835

2936
return box_data
3037

3138

32-
def wandb_image(sample, pred, class_id_to_label, add_ground_truth=False):
33-
raw_image = sample["img"]
34-
true_bboxes = sample["bboxes"]
35-
true_labels = sample["labels"]
39+
def wandb_image(pred: Prediction, add_ground_truth: bool = False) -> wandb.Image:
40+
"""Return a wandb image corresponding to the a prediction.
41+
42+
Args:
43+
pred (Prediction): A prediction to log with WandB.
44+
Must have been created with keep_image = True.
45+
add_ground_truth (bool, optional): Add ground_truth information to the
46+
the WandB image. Defaults to False.
47+
48+
Returns:
49+
wandb.Image: Specifying the image, but also the predictions and possibly ground_truth.
50+
"""
51+
# FIXME: if pred does not have an img, then we lose.
52+
# FIXME: Not handling masks
53+
3654
# Check if "masks" key is the sample dictionnary
37-
if "masks" in sample:
38-
true_masks = sample["masks"]
55+
# if "masks" in sample: true_masks = sample["masks"]
3956

40-
pred_bboxes = pred["bboxes"]
41-
pred_labels = pred["labels"].tolist()
42-
pred_scores = pred["scores"]
4357
# Check if "masks" key is the pred dictionnary
44-
if "masks" in pred:
45-
pred_masks = pred["masks"]
46-
47-
# Predicted Boxes
48-
pred_all_boxes = []
49-
# Collect predicted bounding boxes for this image
50-
for b_i, bbox in enumerate(pred_bboxes):
51-
box_data = bbox_wandb(
52-
bbox, pred_labels[b_i], class_id_to_label, pred_score=pred_scores[b_i]
53-
)
54-
pred_all_boxes.append(box_data)
58+
# if "masks" in pred: pred_masks = pred["masks"]
5559

56-
# log to wandb: raw image, predictions, and dictionary of class labels for each class id
57-
boxes = {
58-
"predictions": {"box_data": pred_all_boxes, "class_labels": class_id_to_label}
60+
class_id_to_label = {
61+
id: label for id, label in enumerate(pred.detection.class_map._id2class)
5962
}
6063

64+
# Prediction
65+
box_data = list(
66+
map(
67+
bbox_wandb,
68+
pred.detection.bboxes,
69+
pred.detection.label_ids,
70+
pred.detection.labels,
71+
pred.detection.scores,
72+
)
73+
)
74+
75+
boxes = {"predictions": {"box_data": box_data, "class_labels": class_id_to_label}}
76+
6177
# Predicted Masks
6278
# Check if "masks" key is the pred dictionnary
63-
if "masks" in pred:
64-
mask_data = (pred_masks.data * pred["labels"][:, None, None]).max(0)
65-
masks = {
66-
"predictions": {"mask_data": mask_data, "class_labels": class_id_to_label}
67-
}
68-
else:
69-
masks = None
79+
# if "masks" in pred:
80+
# mask_data = (pred_masks.data * pred["labels"][:, None, None]).max(0)
81+
# masks = {
82+
# "predictions": {"mask_data": mask_data, "class_labels": class_id_to_label}
83+
# }
84+
# else:
85+
# masks = None
86+
masks = None
7087

7188
# Ground Truth
7289
if add_ground_truth:
73-
# Ground Truth Boxes
74-
true_all_boxes = []
75-
# Collect ground truth bounding boxes for this image
76-
for b_i, bbox in enumerate(true_bboxes):
77-
box_data = bbox_wandb(bbox, true_labels[b_i], class_id_to_label)
78-
true_all_boxes.append(box_data)
90+
box_data = list(
91+
map(
92+
bbox_wandb,
93+
pred.ground_truth.detection.bboxes,
94+
pred.ground_truth.detection.label_ids,
95+
pred.ground_truth.detection.labels,
96+
)
97+
)
7998

8099
boxes["ground_truth"] = {
81-
"box_data": true_all_boxes,
100+
"box_data": box_data,
82101
"class_labels": class_id_to_label,
83102
}
84103

85104
# # Ground Truth Masks
86105
# Check if "masks" key is the sample dictionnary
87-
if "masks" in sample:
88-
labels_arr = np.array(sample["labels"])
89-
mask_data = (true_masks.data * labels_arr[:, None, None]).max(0)
90-
masks["ground_truth"] = {
91-
"mask_data": mask_data,
92-
"class_labels": class_id_to_label,
93-
}
94-
95-
return wandb.Image(raw_image, boxes=boxes, masks=masks)
96-
97-
98-
def wandb_img_preds(samples, preds, class_map, add_ground_truth=False):
99-
class_id_to_label = {int(v): k for k, v in class_map.class2id.items()}
100-
101-
wandb_imgs = []
102-
for (sample, pred) in zip(samples, preds):
103-
img_wandb = wandb_image(
104-
sample, pred, class_id_to_label, add_ground_truth=add_ground_truth
105-
)
106-
wandb_imgs.append(img_wandb)
107-
return wandb_imgs
106+
# if "masks" in sample:
107+
# labels_arr = np.array(sample["labels"])
108+
# mask_data = (true_masks.data * labels_arr[:, None, None]).max(0)
109+
# masks["ground_truth"] = {
110+
# "mask_data": mask_data,
111+
# "class_labels": class_id_to_label,
112+
# }
113+
return wandb.Image(pred.img, boxes=boxes, masks=masks)

notebooks/getting_started_object_detection.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -952,7 +952,7 @@
952952
"name": "python",
953953
"nbconvert_exporter": "python",
954954
"pygments_lexer": "ipython3",
955-
"version": "3.7.10"
955+
"version": "3.8.8"
956956
},
957957
"metadata": {
958958
"interpreter": {
@@ -1314,5 +1314,5 @@
13141314
}
13151315
},
13161316
"nbformat": 4,
1317-
"nbformat_minor": 1
1317+
"nbformat_minor": 4
13181318
}

notebooks/wandb_efficientdet.ipynb

+108-276
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)