Skip to content

Commit 9debb48

Browse files
rsomani95lgvaz
authored andcommitted
add build_batch_kwargs to transform_dl
1 parent 38d46bf commit 9debb48

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

icevision/models/utils.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,28 @@ def freeze(params):
5252
p.requires_grad = False
5353

5454

55-
def transform_dl(dataset, build_batch, batch_tfms=None, **dataloader_kwargs):
56-
"""Creates collate_fn from build_batch by decorating it with apply_batch_tfms and unload_records"""
57-
collate_fn = apply_batch_tfms(build_batch, batch_tfms=batch_tfms)
55+
def transform_dl(
56+
dataset,
57+
build_batch,
58+
batch_tfms=None,
59+
build_batch_kwargs: dict = {},
60+
**dataloader_kwargs,
61+
):
62+
"""Creates collate_fn from build_batch (collate function) by decorating it with apply_batch_tfms and unload_records"""
63+
collate_fn = apply_batch_tfms(
64+
build_batch, batch_tfms=batch_tfms, **build_batch_kwargs
65+
)
5866
collate_fn = unload_records(collate_fn)
5967
return DataLoader(dataset=dataset, collate_fn=collate_fn, **dataloader_kwargs)
6068

6169

62-
def apply_batch_tfms(build_batch, batch_tfms=None):
70+
def apply_batch_tfms(build_batch, batch_tfms=None, **build_batch_kwargs):
6371
"""This decorator function applies batch_tfms to records before passing them to build_batch"""
6472

6573
def inner(records):
6674
if batch_tfms is not None:
6775
records = batch_tfms(records)
68-
return build_batch(records)
76+
return build_batch(records, **build_batch_kwargs)
6977

7078
return inner
7179

0 commit comments

Comments
 (0)