Skip to content

Commit

Permalink
Fix energy score
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Beitner committed Apr 10, 2022
1 parent 46fdd1e commit a0f54bb
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 52 deletions.
3 changes: 3 additions & 0 deletions docs/source/getting-started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ Alternatively, to install the package via conda:

PyTorch Forecasting is now installed from the conda-forge channel while PyTorch is installed from the pytorch channel.

To use the MQF2 loss (multivariate quantile loss), also install the CP-Flow package:
``pip install git+https://github.com/KelvinKan/CP-Flow.git@package-specific-version --no-deps``


Usage
-------------
Expand Down
22 changes: 14 additions & 8 deletions pytorch_forecasting/metrics/_mqf2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,12 @@ def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
layout=self.hidden_state.layout,
)

return self.quantile(alpha, hidden_state_repeat).reshape((numel_batch,) + sample_shape + (prediction_length,))
samples = (
self.quantile(alpha, hidden_state_repeat)
.reshape((numel_batch,) + sample_shape + (prediction_length,))
.transpose(0, 1)
)
return samples

def quantile(self, alpha: torch.Tensor, hidden_state: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Expand Down Expand Up @@ -441,11 +446,11 @@ def batch_shape(self) -> torch.Size:

@property
def event_shape(self) -> Tuple:
return ()
return (self.prediction_length,)

@property
def event_dim(self) -> int:
return 0
return 1


class TransformedMQF2Distribution(TransformedDistribution):
Expand All @@ -469,9 +474,7 @@ def scale_input(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return z, scale

def repeat_scale(self, scale: torch.Tensor) -> torch.Tensor:
if scale.ndim > 1:
scale = scale.squeeze(-1)
return scale.repeat_interleave(self.base_dist.context_length, 0)
return scale.squeeze(-1).repeat_interleave(self.base_dist.context_length, 0)

def log_prob(self, y: torch.Tensor) -> torch.Tensor:
prediction_length = self.base_dist.prediction_length
Expand All @@ -498,8 +501,11 @@ def energy_score(self, y: torch.Tensor) -> torch.Tensor:

def quantile(self, alpha: torch.Tensor, hidden_state: Optional[torch.Tensor] = None) -> torch.Tensor:
result = self.base_dist.quantile(alpha, hidden_state=hidden_state)
result = result.reshape(self.base_dist.hidden_state.size(0), -1, self.base_dist.prediction_length).transpose(
0, 1
)
for transform in self.transforms:
# transform separate for each prediction horizon
result = torch.stack(tuple(transform(r.squeeze(1)) for r in result.split(1, dim=1)), dim=1)
result = transform(result)

return result
return result.transpose(0, 1).reshape_as(alpha)
51 changes: 12 additions & 39 deletions pytorch_forecasting/metrics/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,14 @@ def rescale_parameters(


class MQF2DistributionLoss(DistributionLoss):
"""Multivariate quantile loss."""

eps = 1e-4

def __init__(
self,
prediction_length: int,
hidden_size: int = 30,
hidden_size: int = 4,
threshold_input: float = 100.0,
es_num_samples: int = 50,
beta: float = 1.0,
Expand All @@ -276,7 +277,6 @@ def __init__(
super().__init__()

from cpflows.flows import ActNorm
import cpflows.icnn
from cpflows.icnn import PICNN

from pytorch_forecasting.metrics._mqf2_utils import (
Expand All @@ -288,9 +288,7 @@ def __init__(

self.distribution_class = MQF2Distribution
self.transformed_distribution_class = TransformedMQF2Distribution
n_arguments = hidden_size / prediction_length
assert round(n_arguments) == n_arguments, "MQF2 requires hidden_size to be a multiple of prediction_length"
self.distribution_arguments = list(range(int(n_arguments)))
self.distribution_arguments = list(range(int(hidden_size)))
self.prediction_length = prediction_length
self.threshold_input = threshold_input
self.es_num_samples = es_num_samples
Expand All @@ -300,7 +298,7 @@ def __init__(
convexnet = PICNN(
dim=prediction_length,
dimh=icnn_hidden_size,
dimc=hidden_size,
dimc=hidden_size * prediction_length,
num_hidden_layers=icnn_num_layers,
symm_act_first=True,
)
Expand Down Expand Up @@ -336,8 +334,8 @@ def map_x_to_distribution(self, x: torch.Tensor) -> distributions.Distribution:
beta=self.beta,
)
# rescale
loc = x[..., -2]
scale = x[..., -1]
loc = x[..., -2][:, None]
scale = x[..., -1][:, None]
return self.transformed_distribution_class(distr, [distributions.AffineTransform(loc=loc, scale=scale)])

def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
Expand All @@ -352,42 +350,17 @@ def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
torch.Tensor: metric value on which backpropagation can be applied
"""
distribution = self.map_x_to_distribution(y_pred)
# clip y_actual to avoid infinite losses
if self.is_energy_score:
# todo: why clip to 0 to 1??? certainly not right in this case
loss = -distribution.energy_score(y_actual) # .clip(self.eps, 1 - self.eps))
loss = distribution.energy_score(y_actual)
else:
loss = -distribution.log_prob(y_actual) # .clip(self.eps, 1 - self.eps))
loss = -distribution.log_prob(y_actual)
return loss.reshape(-1, 1)

def rescale_parameters(
self, parameters: torch.Tensor, target_scale: torch.Tensor, encoder: BaseEstimator
) -> torch.Tensor:
return torch.concat([parameters.reshape(parameters.size(0), -1), target_scale], dim=-1)

@property
def event_shape(self) -> Tuple:
return ()

def sample(self, y_pred, n_samples: int) -> torch.Tensor:
"""
Sample from distribution.
Args:
y_pred: prediction output of network (shape batch_size x n_timesteps x n_parameters)
n_samples (int): number of samples to draw
Returns:
torch.Tensor: tensor with samples (shape batch_size x n_timesteps x n_samples)
"""
dist = self.map_x_to_distribution(y_pred)
samples = dist.sample((n_samples,))
if samples.ndim == 3:
samples = samples.permute(0, 2, 1)
elif samples.ndim == 2:
samples = samples.transpose(0, 1)
return samples

def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> torch.Tensor:
"""
Convert network prediction into a quantile prediction.
Expand All @@ -402,14 +375,14 @@ def to_quantiles(self, y_pred: torch.Tensor, quantiles: List[float] = None) -> t
"""
if quantiles is None:
quantiles = self.quantiles
distribution = self.map_x_to_distribution(y_pred.repeat_interleave(len(quantiles), dim=0))
distribution = self.map_x_to_distribution(y_pred)
alpha = (
torch.as_tensor(quantiles, device=y_pred.device)[:, None]
.repeat(y_pred.size(0), 1)
.expand(-1, self.prediction_length)
) # (batch_size * quantiles x prediction_length)

result = distribution.quantile(alpha) # (batch_size * quantiles x prediction_length)
)
hidden_state = distribution.base_dist.hidden_state.repeat_interleave(len(quantiles), dim=0)
result = distribution.quantile(alpha, hidden_state=hidden_state) # (batch_size * quantiles x prediction_length)

# reshape
result = result.reshape(-1, len(quantiles), self.prediction_length).transpose(
Expand Down
7 changes: 4 additions & 3 deletions pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,10 +847,11 @@ def plot_prediction(
raise ValueError(f"add_loss_to_title '{add_loss_to_title}'' is unkown")
if isinstance(loss, MASE):
loss_value = loss(y_raw[None], (y[-n_pred:][None], None), y[:n_pred][None])
elif isinstance(loss, DistributionLoss):
loss_value = "-"
elif isinstance(loss, Metric):
loss_value = loss(y_raw[None], (y[-n_pred:][None], None))
try:
loss_value = loss(y_raw[None], (y[-n_pred:][None], None))
except Exception:
loss_value = "-"
else:
loss_value = loss
ax.set_title(f"Loss {loss_value}")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def data_with_covariates():

def make_dataloaders(data_with_covariates, **kwargs):
training_cutoff = "2016-09-01"
max_encoder_length = 5
max_prediction_length = 2
max_encoder_length = 4
max_prediction_length = 3

kwargs.setdefault("target", "volume")
kwargs.setdefault("group_ids", ["agency", "sku"])
Expand Down

0 comments on commit a0f54bb

Please sign in to comment.