Skip to content

Commit

Permalink
Log transform usage
Browse files Browse the repository at this point in the history
Summary: Refactor build_transform() to return transform in the end. Call log_class_usage() before returning.

Reviewed By: kazhang

Differential Revision: D25653451

fbshipit-source-id: 099ade840ae9e8d2dff53fecb1100151ec0cf037
  • Loading branch information
Xin Lei authored and facebook-github-bot committed Dec 19, 2020
1 parent d756b94 commit 614fc55
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions classy_vision/dataset/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from classy_vision.generic.registry_utils import import_all_modules
from classy_vision.generic.util import log_class_usage

from .classy_transform import ClassyTransform

Expand Down Expand Up @@ -49,20 +50,23 @@ def build_transform(transform_config: Dict[str, Any]) -> Callable:
transform_args = copy.deepcopy(transform_config)
del transform_args["name"]
if name in TRANSFORM_REGISTRY:
return TRANSFORM_REGISTRY[name].from_config(transform_args)
# the name should be available in torchvision.transforms
# if users specify the torchvision transform name in snake case,
# we need to convert it to title case.
if not (hasattr(transforms, name) or hasattr(transforms_video, name)):
name = name.title().replace("_", "")
assert hasattr(transforms, name) or hasattr(transforms_video, name), (
f"{name} isn't a registered tranform"
", nor is it available in torchvision.transforms"
)
if hasattr(transforms, name):
return getattr(transforms, name)(**transform_args)
transform = TRANSFORM_REGISTRY[name].from_config(transform_args)
else:
return getattr(transforms_video, name)(**transform_args)
# the name should be available in torchvision.transforms
# if users specify the torchvision transform name in snake case,
# we need to convert it to title case.
if not (hasattr(transforms, name) or hasattr(transforms_video, name)):
name = name.title().replace("_", "")
assert hasattr(transforms, name) or hasattr(transforms_video, name), (
f"{name} isn't a registered tranform"
", nor is it available in torchvision.transforms"
)
if hasattr(transforms, name):
transform = getattr(transforms, name)(**transform_args)
else:
transform = getattr(transforms_video, name)(**transform_args)
log_class_usage("Transform", transform.__class__)
return transform


def build_transforms(transforms_config: List[Dict[str, Any]]) -> Callable:
Expand Down

0 comments on commit 614fc55

Please sign in to comment.