Skip to content

Commit ca73528

Browse files
Feature/1124 Added Pytorch Lightning Test support for all models (airctic#1125)
* feature/1124 Added Pytorch Lightning Test support for mmdet * feature/1124 Added Pytorch Lightning Test support for efficientdet * feature/1124 Added Pytorch Lightning Test support for torchvision * feature/1124 Added Pytorch Lightning Test support for yolov5 * feature/1124 Renamed shared evaluation method for mmdet * feature/1124 Renamed shared evaluation method for efficientdet * feature/1124 Renamed shared evaluation method for torchvision * feature/1124 Renamed shared evaluation method for yolov5 * feature/1124 Added test for PL test step of efficientdet * feature/1124 Added test for PL test step of mmdet * feature/1124 Added test for PL test step of torchvision models * feature/1124 Added test for PL test step of yolov5 * feature/1124 Updated object detection getting started guide to add PL test * feature/1124 Updated environment.yml to fix icevision version not found * feature/1124 Added unit tests for efficientdet lightning model adapter * feature/1124 Added unit tests for mmdet lightning model adapter * feature/1124 Added unit tests for torchvision lightning model adapter * feature/1124 Added unit tests for yolov5 lightning model adapter * feature/1124 Added unit tests for fastai unet lightning model adapter * feature/1124 Fixed memory consumption in PL tests
1 parent 356d29b commit ca73528

File tree

27 files changed

+1546
-66
lines changed

27 files changed

+1546
-66
lines changed

environment.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ dependencies:
1313
- fastai >=2.5.2,<2.6
1414
- pip
1515
- pip:
16-
- icevision[all]==0.12.0rc1
16+
- icevision[all]==0.12.0
1717
- --find-links https://download.openmmlab.com/mmcv/dist/cu102/torch1.10.0/index.html
1818
- mmcv-full==1.3.17
1919
- mmdet==2.17.0

icevision/models/fastai/unet/lightning.py

+28-10
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,46 @@ def training_step(self, batch, batch_idx):
3333
(xb, yb), _ = batch
3434
preds = self(xb)
3535

36-
loss = self.loss_func(preds, yb)
36+
loss = self.compute_loss(preds, yb)
3737

3838
self.log("train_loss", loss)
3939

4040
return loss
4141

42+
def compute_loss(self, preds, yb):
43+
return self.loss_func(preds, yb)
44+
4245
def validation_step(self, batch, batch_idx):
46+
self._shared_eval(batch=batch, loss_log_key="val")
47+
48+
def _shared_eval(self, batch, loss_log_key):
4349
(xb, yb), records = batch
4450

45-
with torch.no_grad():
46-
preds = self(xb)
47-
loss = self.loss_func(preds, yb)
51+
preds = self(xb)
52+
loss = self.compute_loss(preds, yb)
4853

49-
preds = unet.convert_raw_predictions(
50-
batch=xb,
51-
raw_preds=preds,
52-
records=records,
53-
)
54+
preds = self.convert_raw_predictions(
55+
batch=xb,
56+
raw_preds=preds,
57+
records=records,
58+
)
5459

5560
self.accumulate_metrics(preds)
5661

57-
self.log("val_loss", loss)
62+
self.log(f"{loss_log_key}_loss", loss)
63+
64+
def convert_raw_predictions(self, batch, raw_preds, records):
65+
return unet.convert_raw_predictions(
66+
batch=batch,
67+
raw_preds=raw_preds,
68+
records=records,
69+
)
5870

5971
def validation_epoch_end(self, outs):
6072
self.finalize_metrics()
73+
74+
def test_step(self, batch, batch_idx):
75+
self._shared_eval(batch=batch, loss_log_key="test")
76+
77+
def test_epoch_end(self, outs):
78+
self.finalize_metrics()

icevision/models/mmdet/lightning/model_adapter.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
class MMDetModelAdapter(LightningModelAdapter, ABC):
18-
"""Lightning module specialized for EfficientDet, with metrics support.
18+
"""Lightning module specialized for MMDet, with metrics support.
1919
2020
The methods `forward`, `training_step`, `validation_step`, `validation_epoch_end`
2121
are already overriden.
@@ -45,30 +45,34 @@ def training_step(self, batch, batch_idx):
4545
outputs = self.model.train_step(data=data, optimizer=None)
4646

4747
for k, v in outputs["log_vars"].items():
48-
self.log(f"train/{k}", v)
48+
self.log(f"train_{k}", v)
4949

5050
return outputs["loss"]
5151

5252
def validation_step(self, batch, batch_idx):
53+
self._shared_eval(batch, loss_log_key="val")
54+
55+
def _shared_eval(self, batch, loss_log_key):
5356
data, records = batch
5457

55-
self.model.eval()
56-
with torch.no_grad():
57-
outputs = self.model.train_step(data=data, optimizer=None)
58-
raw_preds = self.model.forward_test(
59-
imgs=[data["img"]], img_metas=[data["img_metas"]]
60-
)
58+
outputs = self.model.train_step(data=data, optimizer=None)
59+
raw_preds = self.model.forward_test(
60+
imgs=[data["img"]], img_metas=[data["img_metas"]]
61+
)
6162

6263
preds = self.convert_raw_predictions(
6364
batch=data, raw_preds=raw_preds, records=records
6465
)
6566
self.accumulate_metrics(preds)
6667

6768
for k, v in outputs["log_vars"].items():
68-
self.log(f"valid/{k}", v)
69-
70-
# TODO: is train and eval model automatically set by lighnting?
71-
self.model.train()
69+
self.log(f"{loss_log_key}_{k}", v)
7270

7371
def validation_epoch_end(self, outs):
7472
self.finalize_metrics()
73+
74+
def test_step(self, batch, batch_idx):
75+
self._shared_eval(batch=batch, loss_log_key="test")
76+
77+
def test_epoch_end(self, outs):
78+
self.finalize_metrics()

icevision/models/ross/efficientdet/lightning/model_adapter.py

+30-12
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,49 @@ def training_step(self, batch, batch_idx):
3131
(xb, yb), records = batch
3232
preds = self(xb, yb)
3333

34-
loss = efficientdet.loss_fn(preds, yb)
34+
loss = self.compute_loss(preds, yb)
3535

3636
for k, v in preds.items():
37-
self.log(f"train/{k}", v)
37+
self.log(f"train_{k}", v)
3838

3939
return loss
4040

41+
def compute_loss(self, preds, yb):
42+
return efficientdet.loss_fn(preds, yb)
43+
4144
def validation_step(self, batch, batch_idx):
45+
self._shared_eval(batch, loss_log_key="val")
46+
47+
def _shared_eval(self, batch, loss_log_key):
4248
(xb, yb), records = batch
4349

44-
with torch.no_grad():
45-
raw_preds = self(xb, yb)
46-
preds = efficientdet.convert_raw_predictions(
47-
batch=(xb, yb),
48-
raw_preds=raw_preds["detections"],
49-
records=records,
50-
detection_threshold=0.0,
51-
)
52-
loss = efficientdet.loss_fn(raw_preds, yb)
50+
raw_preds = self(xb, yb)
51+
52+
preds = self.convert_raw_predictions(xb, yb, raw_preds, records)
53+
54+
self.compute_loss(raw_preds, yb)
5355

5456
self.accumulate_metrics(preds)
5557

5658
for k, v in raw_preds.items():
5759
if "loss" in k:
58-
self.log(f"valid/{k}", v)
60+
self.log(f"{loss_log_key}_{k}", v)
61+
62+
def convert_raw_predictions(self, xb, yb, raw_preds, records):
63+
# Note: raw_preds["detections"] key is available only during Pytorch Lightning validation/test step
64+
# Calling the method manually (instead of letting the Trainer call it) will raise an exception.
65+
return efficientdet.convert_raw_predictions(
66+
batch=(xb, yb),
67+
raw_preds=raw_preds["detections"],
68+
records=records,
69+
detection_threshold=0.0,
70+
)
5971

6072
def validation_epoch_end(self, outs):
6173
self.finalize_metrics()
74+
75+
def test_step(self, batch, batch_idx):
76+
self._shared_eval(batch=batch, loss_log_key="test")
77+
78+
def test_epoch_end(self, outs):
79+
self.finalize_metrics()

icevision/models/torchvision/lightning_model_adapter.py

+24-12
Original file line numberDiff line numberDiff line change
@@ -29,26 +29,38 @@ def training_step(self, batch, batch_idx):
2929
(xb, yb), records = batch
3030
preds = self(xb, yb)
3131

32-
loss = loss_fn(preds, yb)
32+
loss = self.compute_loss(preds, yb)
3333
self.log("train_loss", loss)
3434

3535
return loss
3636

37+
def compute_loss(self, preds, yb):
38+
return loss_fn(preds, yb)
39+
3740
def validation_step(self, batch, batch_idx):
41+
self._shared_eval(batch, loss_log_key="val")
42+
43+
def _shared_eval(self, batch, loss_log_key):
3844
(xb, yb), records = batch
39-
with torch.no_grad():
40-
self.train()
41-
train_preds = self(xb, yb)
42-
loss = loss_fn(train_preds, yb)
4345

44-
self.eval()
45-
raw_preds = self(xb)
46-
preds = self.convert_raw_predictions(
47-
batch=batch, raw_preds=raw_preds, records=records
48-
)
49-
self.accumulate_metrics(preds=preds)
46+
self.train()
47+
preds = self(xb, yb)
48+
loss = self.compute_loss(preds, yb)
49+
50+
self.eval()
51+
raw_preds = self(xb)
52+
preds = self.convert_raw_predictions(
53+
batch=batch, raw_preds=raw_preds, records=records
54+
)
55+
self.accumulate_metrics(preds=preds)
5056

51-
self.log("val_loss", loss)
57+
self.log(f"{loss_log_key}_loss", loss)
5258

5359
def validation_epoch_end(self, outs):
5460
self.finalize_metrics()
61+
62+
def test_step(self, batch, batch_idx):
63+
self._shared_eval(batch=batch, loss_log_key="test")
64+
65+
def test_epoch_end(self, outs):
66+
self.finalize_metrics()

icevision/models/ultralytics/yolov5/lightning/model_adapter.py

+30-11
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,41 @@ def training_step(self, batch, batch_idx):
4040
return loss
4141

4242
def validation_step(self, batch, batch_idx):
43+
self._shared_eval(batch, loss_log_key="val")
44+
45+
def _shared_eval(self, batch, loss_log_key):
4346
(xb, yb), records = batch
4447

45-
with torch.no_grad():
46-
inference_out, training_out = self(xb)
47-
preds = yolov5.convert_raw_predictions(
48-
batch=xb,
49-
raw_preds=inference_out,
50-
records=records,
51-
detection_threshold=0.001,
52-
nms_iou_threshold=0.6,
53-
)
54-
loss = self.compute_loss(training_out, yb)[0]
48+
inference_out, training_out = self(xb)
49+
preds = self.convert_raw_predictions(
50+
batch=xb,
51+
raw_preds=inference_out,
52+
records=records,
53+
detection_threshold=0.001,
54+
nms_iou_threshold=0.6,
55+
)
56+
loss = self.compute_loss(training_out, yb)[0]
5557

5658
self.accumulate_metrics(preds)
5759

58-
self.log("val_loss", loss)
60+
self.log(f"{loss_log_key}_loss", loss)
61+
62+
def convert_raw_predictions(
63+
self, batch, raw_preds, records, detection_threshold, nms_iou_threshold
64+
):
65+
return yolov5.convert_raw_predictions(
66+
batch=batch,
67+
raw_preds=raw_preds,
68+
records=records,
69+
detection_threshold=detection_threshold,
70+
nms_iou_threshold=nms_iou_threshold,
71+
)
5972

6073
def validation_epoch_end(self, outs):
6174
self.finalize_metrics()
75+
76+
def test_step(self, batch, batch_idx):
77+
self._shared_eval(batch=batch, loss_log_key="test")
78+
79+
def test_epoch_end(self, outs):
80+
self.finalize_metrics()

notebooks/getting_started_object_detection.ipynb

+21-3
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,24 @@
802802
"trainer.fit(light_model, train_dl, valid_dl)"
803803
]
804804
},
805+
{
806+
"cell_type": "markdown",
807+
"metadata": {},
808+
"source": [
809+
"#### Testing using Pytorch Lightning\n",
810+
"For testing, it is recommended to use a separate test dataset that the model did not see during training but for demonstration purposes we'll re-use the validation dataset."
811+
]
812+
},
813+
{
814+
"cell_type": "code",
815+
"execution_count": null,
816+
"metadata": {},
817+
"outputs": [],
818+
"source": [
819+
"trainer = pl.Trainer()\n",
820+
"trainer.test(light_model, valid_dl)"
821+
]
822+
},
805823
{
806824
"cell_type": "markdown",
807825
"metadata": {
@@ -961,7 +979,7 @@
961979
"provenance": []
962980
},
963981
"kernelspec": {
964-
"display_name": "Python 3.8.5 ('icevision')",
982+
"display_name": "Python 3.10.4 64-bit",
965983
"language": "python",
966984
"name": "python3"
967985
},
@@ -975,7 +993,7 @@
975993
"name": "python",
976994
"nbconvert_exporter": "python",
977995
"pygments_lexer": "ipython3",
978-
"version": "3.8.5"
996+
"version": "3.10.4"
979997
},
980998
"metadata": {
981999
"interpreter": {
@@ -1002,7 +1020,7 @@
10021020
},
10031021
"vscode": {
10041022
"interpreter": {
1005-
"hash": "8de9f07afee82c69462511c5dd73ef92dc31ce4377be5cdd30293ad211c0627b"
1023+
"hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
10061024
}
10071025
},
10081026
"widgets": {

tests/engines/lightning/test_lightning_model_adapter.py

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ class DummLightningModelAdapter(LightningModelAdapter):
1616
pass
1717

1818

19-
# test if finalize metrics reports metrics correctly
2019
def test_finalze_metrics_reports_metrics_correctly(mocker):
2120
mocker.patch(
2221
"icevision.engines.lightning.lightning_model_adapter.LightningModelAdapter.log"

0 commit comments

Comments
 (0)