Skip to content

Commit

Permalink
Merge pull request sktime#260 from jdb78/fix/transformation
Browse files Browse the repository at this point in the history
Fix missing output transformation
  • Loading branch information
jdb78 authored Jan 12, 2021
2 parents 5152fcf + d2d51c7 commit 908a9aa
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Release Notes

## v0.8.2 Fix for output transformer (12/01/2021)

- Added missing output transformation which was switched off by default (#260)

## v0.8.1 Adding support for lag variables (10/01/2021)

### Added
Expand Down
21 changes: 15 additions & 6 deletions pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def transform_output(self, out: Dict[str, torch.Tensor]) -> torch.Tensor:
return out

# no transformation logic
elif self.output_transformer is None or out.get("output_transformation", None) is None:
elif self.output_transformer is None or out.get("output_transformation", True) is None:
out = out["prediction"]

# distribution transformation
Expand Down Expand Up @@ -1349,7 +1349,10 @@ def from_dataset(
return super().from_dataset(dataset, **kwargs)

def output_to_prediction(
self, normalized_prediction_parameters: torch.Tensor, target_scale: Union[List[torch.Tensor], torch.Tensor]
self,
normalized_prediction_parameters: torch.Tensor,
target_scale: Union[List[torch.Tensor], torch.Tensor],
**kwargs,
) -> Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]:
"""
Convert network output to rescaled and normalized prediction.
Expand All @@ -1359,17 +1362,18 @@ def output_to_prediction(
Args:
normalized_prediction_parameters (torch.Tensor): network prediction output
target_scale (Union[List[torch.Tensor], torch.Tensor]): target scale to rescale network output
**kwargs: extra arguments for dictionary passed to :py:meth:`~transform_output` method.
Returns:
Tuple[Union[List[torch.Tensor], torch.Tensor], torch.Tensor]: tuple of rescaled prediction and
normalized prediction (e.g. for input into next auto-regressive step)
"""
single_prediction = to_list(normalized_prediction_parameters)[0].ndim == 2
if single_prediction: # add time dimension as it is expected
normalized_prediction_parameters = apply_to_list(normalized_prediction_parameters, lambda x: x.unsqueeze(1))
# transform into real space
prediction_parameters = self.transform_output(
dict(
prediction=normalized_prediction_parameters,
target_scale=target_scale,
)
dict(prediction=normalized_prediction_parameters, target_scale=target_scale, **kwargs)
)
# todo: handle classification
# sample value(s) from distribution and select first sample
Expand All @@ -1386,6 +1390,11 @@ def output_to_prediction(
input_target = torch.cat(normalized_prediction, dim=-1)
else:
input_target = normalized_prediction # set next input target to normalized prediction

# remove time dimension
if single_prediction:
prediction = apply_to_list(prediction, lambda x: x.squeeze(1))
input_target = input_target.squeeze(1)
return prediction, input_target

def decode_autoregressive(
Expand Down

0 comments on commit 908a9aa

Please sign in to comment.