Commit

Fix doc strings and minor argument update for cross entropy losses (facebookresearch#651)

Summary:
Pull Request resolved: facebookresearch#651

Updated doc strings and changed `normalize_targets` to a bool

Reviewed By: lauragustafson

Differential Revision: D24902068

fbshipit-source-id: f0031a9814dea9b73a809e111cb75f527ed8495d
mannatsingh authored and facebook-github-bot committed Nov 16, 2020
1 parent e2eba06 commit 932aabc
Showing 3 changed files with 20 additions and 20 deletions.
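Before the per-file diffs, a minimal sketch of what the change means for loss configs. The dict shape follows the test file below; treat the exact values as illustrative:

# Before this commit: "normalize_targets" was None or the string "count_based".
old_config = {
    "name": "soft_target_cross_entropy",
    "ignore_index": -1,
    "reduction": "mean",
    "normalize_targets": "count_based",
}

# After this commit: "normalize_targets" is a plain bool. True keeps the old
# "count_based" behavior; False replaces the old None.
new_config = {
    "name": "soft_target_cross_entropy",
    "ignore_index": -1,
    "reduction": "mean",
    "normalize_targets": True,
}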
19 changes: 10 additions & 9 deletions classy_vision/losses/label_smoothing_loss.py
@@ -16,22 +16,25 @@
 @register_loss("label_smoothing_cross_entropy")
 class LabelSmoothingCrossEntropyLoss(ClassyLoss):
-    def __init__(self, ignore_index, reduction, smoothing_param):
+    def __init__(self, ignore_index=-100, reduction="mean", smoothing_param=None):
         """Initializer for the label smoothed cross entropy loss.
         This decreases the gap between output scores and encourages generalization.
-        Targets provided to forward can be one-hot vectors (NxC) or class indices(Nx1)
-        Config params:
-        'weight': weight of sample (not yet implemented),
-        'ignore_index': sample should be ignored for loss (optional),
-        'smoothing_param': value to be added to each target entry
+        Targets provided to forward can be one-hot vectors (NxC) or class indices (Nx1).
+        This normalizes the targets to a sum of 1 based on the total count of positive
+        targets for a given sample before applying label smoothing.
+
+        Args:
+            ignore_index: sample should be ignored for loss if the class is this value
+            reduction: specifies reduction to apply to the output
+            smoothing_param: value to be added to each target entry
         """
         super().__init__()
         self._ignore_index = ignore_index
         self._reduction = reduction
         self._smoothing_param = smoothing_param
         self.loss_function = SoftTargetCrossEntropyLoss(
-            self._ignore_index, self._reduction, None
+            self._ignore_index, self._reduction, normalize_targets=False
         )
         self._eps = np.finfo(np.float32).eps

@@ -47,7 +50,6 @@ def from_config(cls, config: Dict[str, Any]) -> "LabelSmoothingCrossEntropyLoss":
             A LabelSmoothingCrossEntropyLoss instance.
         """

-        assert "weight" not in config, '"weight" not implemented'
         assert (
             "smoothing_param" in config
         ), "Label Smoothing needs a smoothing parameter"
@@ -79,7 +81,6 @@ def compute_valid_targets(self, target, classes):
         return valid_targets

     def smooth_targets(self, valid_targets, classes):
-
         """
         This function takes valid (no ignore values present) one-hot target vectors
         and computes smoothed target vectors (normalized) according to the loss's
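The hunk above is truncated by the diff view, but the constructor change already shows the new defaults. A minimal usage sketch, assuming classy_vision is installed and the class resolves via the file path in the header; the tensor shapes are illustrative:

import torch

from classy_vision.losses.label_smoothing_loss import LabelSmoothingCrossEntropyLoss

# ignore_index and reduction now default to -100 and "mean", so only
# smoothing_param must be supplied by the caller.
loss = LabelSmoothingCrossEntropyLoss(smoothing_param=0.1)

logits = torch.randn(4, 10)             # N x C output scores
targets = torch.randint(0, 10, (4, 1))  # class indices (N x 1); per the docstring, one-hot (N x C) also works
value = loss(logits, targets)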
19 changes: 9 additions & 10 deletions classy_vision/losses/soft_target_cross_entropy_loss.py
@@ -15,19 +15,20 @@

 @register_loss("soft_target_cross_entropy")
 class SoftTargetCrossEntropyLoss(ClassyLoss):
-    def __init__(self, ignore_index, reduction, normalize_targets):
+    def __init__(self, ignore_index=-100, reduction="mean", normalize_targets=True):
         """Initializer for the soft target cross-entropy loss.
         This allows the targets for the cross entropy loss to be multilabel.
-        Config params:
-        'weight': weight of sample (not yet implemented),
-        'ignore_index': sample should be ignored for loss (optional),
-        'reduction': specifies reduction to apply to the output (optional),
+        Args:
+            ignore_index: sample should be ignored for loss if the class is this value
+            reduction: specifies reduction to apply to the output
+            normalize_targets: whether the targets should be normalized to a sum of 1
+                based on the total count of positive targets for a given sample
         """
         super(SoftTargetCrossEntropyLoss, self).__init__()
         self._ignore_index = ignore_index
         self._reduction = reduction
-        assert normalize_targets in [None, "count_based"]
+        assert isinstance(normalize_targets, bool)
         self._normalize_targets = normalize_targets
         if self._reduction != "mean":
             raise NotImplementedError(
@@ -47,12 +48,10 @@ def from_config(cls, config: Dict[str, Any]) -> "SoftTargetCrossEntropyLoss":
             A SoftTargetCrossEntropyLoss instance.
         """

-        if "weight" in config:
-            raise NotImplementedError('"weight" not implemented')
         return cls(
             ignore_index=config.get("ignore_index", -100),
             reduction=config.get("reduction", "mean"),
-            normalize_targets=config.get("normalize_targets", "count_based"),
+            normalize_targets=config.get("normalize_targets", True),
         )

     def forward(self, output, target):
@@ -77,7 +76,7 @@ def forward(self, output, target):
         )
         valid_mask = target != self._ignore_index
         valid_targets = target.float() * valid_mask.float()
-        if self._normalize_targets == "count_based":
+        if self._normalize_targets:
             valid_targets /= self._eps + valid_targets.sum(dim=1, keepdim=True)
         per_sample_per_target_loss = -valid_targets * F.log_softmax(output, -1)
         # perform reduction
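The normalization branch above is easy to re-derive by hand. A small sketch for a batch with no ignored entries (so valid_mask is all ones); up to the eps term it matches what the forward pass computes with normalize_targets=True and reduction="mean":

import torch
import torch.nn.functional as F

output = torch.tensor([[1.0, 2.0, 3.0]])
target = torch.tensor([[1.0, 0.0, 1.0]])  # two positive labels in this sample

# Count-based normalization: scale each target row to sum to 1.
normalized = target / target.sum(dim=1, keepdim=True)
per_sample = -(normalized * F.log_softmax(output, -1)).sum(dim=1)
loss = per_sample.mean()  # reduction="mean"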
2 changes: 1 addition & 1 deletion test/losses_soft_target_cross_entropy_loss_test.py
@@ -59,7 +59,7 @@ def test_unnormalized_soft_target_cross_entropy(self):
             "name": "soft_target_cross_entropy",
             "ignore_index": -1,
             "reduction": "mean",
-            "normalize_targets": None,
+            "normalize_targets": False,
         }
         crit = SoftTargetCrossEntropyLoss.from_config(config)
         outputs = self._get_outputs()
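Equivalently, the updated config can be exercised outside the test harness; a sketch assuming the class is importable from the module path in the file header above:

from classy_vision.losses.soft_target_cross_entropy_loss import SoftTargetCrossEntropyLoss

# "normalize_targets": False now requests the unnormalized loss that
# "normalize_targets": None selected before this commit.
config = {
    "name": "soft_target_cross_entropy",
    "ignore_index": -1,
    "reduction": "mean",
    "normalize_targets": False,
}
crit = SoftTargetCrossEntropyLoss.from_config(config)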
