multiple_reducers.py
import torch

from .base_reducer import BaseReducer
from .mean_reducer import MeanReducer


class MultipleReducers(BaseReducer):
    """Routes each named sub-loss to its own reducer.

    `reducers` maps sub-loss names to reducer modules; any sub-loss without
    an entry falls back to `default_reducer` (a MeanReducer unless one is
    supplied).
    """

    def __init__(self, reducers, default_reducer=None, **kwargs):
        super().__init__(**kwargs)
        self.reducers = torch.nn.ModuleDict(reducers)
        self.default_reducer = (
            MeanReducer() if default_reducer is None else default_reducer
        )

    def forward(self, loss_dict, embeddings, labels):
        self.reset_stats()
        # One reduced scalar per sub-loss, allocated on the same device and
        # dtype as the embeddings.
        sub_losses = torch.zeros(
            len(loss_dict), dtype=embeddings.dtype, device=embeddings.device
        )
        loss_count = 0
        for loss_name, loss_info in loss_dict.items():
            # Each reducer expects a single-entry loss dict.
            input_dict = {loss_name: loss_info}
            if loss_name in self.reducers:
                loss_val = self.reducers[loss_name](input_dict, embeddings, labels)
            else:
                loss_val = self.default_reducer(input_dict, embeddings, labels)
            sub_losses[loss_count] = loss_val
            loss_count += 1
        return self.sub_loss_reduction(sub_losses, embeddings, labels)

    def sub_loss_reduction(self, sub_losses, embeddings=None, labels=None):
        # Combine the per-sub-loss scalars into the final loss; override this
        # method to change the aggregation (the default is a plain sum).
        return torch.sum(sub_losses)
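
A minimal usage sketch, assuming this module ships as part of pytorch-metric-learning, where losses such as ContrastiveLoss emit named sub-losses ("pos_loss" and "neg_loss") that MultipleReducers can route; the threshold value 0.1 is illustrative, not prescribed:

from pytorch_metric_learning import losses, reducers

# Reduce the positive-pair sub-loss with a ThresholdReducer (averaging only
# losses above 0.1); the negative-pair sub-loss has no entry, so it falls
# through to the default MeanReducer.
reducer = reducers.MultipleReducers(
    {"pos_loss": reducers.ThresholdReducer(low=0.1)}
)
loss_func = losses.ContrastiveLoss(reducer=reducer)

Because the final aggregation lives in sub_loss_reduction, a subclass can replace the plain sum (for example, with a weighted sum) without touching the per-sub-loss routing logic in forward.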