Skip to content

Commit

Permalink
Add notebook for quantitatively assessing data characteristics
Browse files Browse the repository at this point in the history
  • Loading branch information
imJiawen committed Jan 5, 2025
1 parent 72a685d commit bc4ed88
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 148 deletions.
127 changes: 1 addition & 126 deletions probts/model/forecast_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,43 +13,6 @@
from probts.utils.save_utils import update_metrics, calculate_weighted_average, load_checkpoint, get_hor_str
from probts.utils.utils import init_class_helper


import itertools
from collections import ChainMap
from gluonts.ev.ts_stats import seasonal_error
from typing import Iterable, List
from gluonts.model import Forecast

from gluonts.ev.metrics import (
MSE,
MAE,
MASE,
MAPE,
SMAPE,
MSIS,
RMSE,
NRMSE,
ND,
MeanWeightedSumQuantileLoss,
)

# Instantiate the metric definitions from ``gluonts.ev.metrics``.
# Each entry is a definition object, not a ready evaluator: it has to be
# called (e.g. ``metric(axis=None)``) to obtain the evaluator that is
# updated with data batches.
metrics_func = [
# MSE is listed twice with different ``forecast_type`` values:
# "mean" and 0.5 (presumably the median / 0.5-quantile forecast -- TODO confirm).
MSE(forecast_type="mean"),
MSE(forecast_type=0.5),
# The remaining metrics are created with their default arguments.
MAE(),
MASE(),
MAPE(),
SMAPE(),
MSIS(),
RMSE(),
NRMSE(),
ND(),
# Quantile loss over the nine quantile levels 0.1 ... 0.9.
MeanWeightedSumQuantileLoss(
quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
),
]

def get_weights(sampling_weight_scheme, max_hor):
'''
return: w [max_hor]
Expand Down Expand Up @@ -150,34 +113,6 @@ def evaluate(self, batch, stage='',dataloader_idx=None):
norm_metrics = self.evaluator(norm_future_data, forecasts, past_data=norm_past_data, freq=self.forecaster.freq)
self.metrics_dict = update_metrics(norm_metrics, stage, 'norm', target_dict=self.metrics_dict)

###########

evaluators = {}
for metric in metrics_func:
evaluator = metric(axis=None)
evaluators[evaluator.name] = evaluator

input_batches = iter(orin_past_data.cpu())
label_batches = iter(orin_future_data.cpu())
forecast_batches = iter(denorm_forecasts.cpu())
season_length = get_seasonality(self.forecaster.freq)

for input_batch, label_batch, forecast_batch in zip(
input_batches, label_batches, forecast_batches
):
data_batch = self._get_data_batch(
input_batch,
label_batch,
forecast_batch,
seasonality=season_length,
mask_invalid_label=True,
allow_nan_forecast=False,
)

for evaluator in evaluators.values():
evaluator.update(data_batch)
############

l = orin_future_data.shape[1]

if stage != 'test' and self.sampling_weight_scheme not in ['fix', 'none']:
Expand Down Expand Up @@ -262,64 +197,4 @@ def configure_optimizers(self):

return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

return optimizer

def _get_data_batch(
    self,
    input_batch,
    label_batch,
    forecast_batch,
    seasonality: Optional[int] = None,
    mask_invalid_label: bool = True,
    allow_nan_forecast: bool = False,
) -> ChainMap:
    """Assemble the mapping consumed by the ``gluonts.ev`` evaluators.

    Args:
        input_batch: Iterable of past-value series; one seasonal error is
            computed per element (time is taken to be the last axis).
        label_batch: Iterable of target series; stacked along a new axis 0.
        forecast_batch: Forecasts, wrapped in ``BatchForecast``.
        seasonality: Season length forwarded unchanged to ``seasonal_error``.
        mask_invalid_label: If True, labels and inputs are wrapped with
            ``np.ma.masked_invalid`` before use.
        allow_nan_forecast: Forwarded to ``BatchForecast`` as ``allow_nan``.

    Returns:
        A ``ChainMap`` that first looks up ``"label"`` and
        ``"seasonal_error"`` and then falls back to the wrapped forecasts.
    """
    labels = np.stack(list(label_batch), axis=0)
    if mask_invalid_label:
        labels = np.ma.masked_invalid(labels)

    # One seasonal-error value per input series, computed along the last axis.
    per_series_errors = [
        seasonal_error(
            np.ma.masked_invalid(series) if mask_invalid_label else series,
            seasonality=seasonality,
            time_axis=-1,
        )
        for series in input_batch
    ]

    extras = {
        "label": labels,
        "seasonal_error": np.array(per_series_errors),
    }
    wrapped_forecasts = BatchForecast(forecast_batch, allow_nan=allow_nan_forecast)
    return ChainMap(extras, wrapped_forecasts)  # type: ignore


from dataclasses import dataclass

@dataclass
class BatchForecast:
    """
    Wrapper around a sequence of forecasts that adds a batch dimension to
    the array returned by ``__getitem__``, for compatibility with
    ``gluonts.ev``.

    Every lookup key yields the same array: each element of ``forecasts``
    is transposed via ``.T`` and the results are stacked along a new
    axis 0. The key itself is not used to select a forecast type.
    """

    # Only ``.T`` is used on each element. The annotation is a string so
    # that it is not evaluated when the class is created.
    forecasts: "List[Forecast]"
    # When False (the default), NaN values in the stacked forecasts are an error.
    allow_nan: bool = False

    def __getitem__(self, name):
        """Return all forecasts, each transposed, stacked along axis 0.

        Args:
            name: Lookup key; accepted for mapping-style access but ignored.

        Raises:
            ValueError: If the stacked array contains NaN and ``allow_nan``
                is False.
        """
        stacked = np.stack([forecast.T for forecast in self.forecasts], axis=0)

        # Reject NaN forecasts unless explicitly allowed, so that metrics are
        # not silently computed from invalid values.
        if not self.allow_nan and np.isnan(stacked).any():
            raise ValueError("Forecast contains NaN values")

        return stacked
return optimizer
13 changes: 0 additions & 13 deletions probts/utils/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,6 @@ def get_sequence_metrics(self, targets, forecasts, seasonal_error=None, samples_
[metrics[self.weighted_loss_name(q)] for q in self.quantiles]
)

# scaled CRPS

# for q in self.quantiles:
# q_forecasts = np.quantile(forecasts, q, axis=samples_dim)
# metrics['mean_weighted_'+self.loss_name(q)] = np.mean(scaled_quantile_loss(targets, q_forecasts, q, seasonal_error))
# metrics['mean_scale_'+self.weighted_loss_name(q)] = \
# metrics['mean_scale_'+self.loss_name(q)] / metrics["abs_target_sum"]
# metrics[self.coverage_name(q)] = coverage(targets, q_forecasts)


metrics["MAE_Coverage"] = np.mean(
[
np.abs(metrics[self.coverage_name(q)] - np.array([q]))
Expand Down Expand Up @@ -118,9 +108,6 @@ def __call__(self, targets, forecasts, past_data, freq, loss_weights=None):
Dict[String, float]
metrics
"""
# targets = targets.cpu().detach().numpy()
# forecasts = forecasts.cpu().detach().numpy()
# past_data = past_data.cpu().detach().numpy()

targets = process_tensor(targets)
forecasts = process_tensor(forecasts)
Expand Down
9 changes: 0 additions & 9 deletions probts/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,6 @@ def smape(target: np.ndarray, forecast: np.ndarray) -> float:
np.abs(target - forecast) / (np.abs(target) + np.abs(forecast))
)


# def quantile_loss(target: np.ndarray, forecast: np.ndarray, q: float) -> float:
# r"""
# .. math::

# quantile\_loss = 2 * sum(|(Y - \hat{Y}) * ((Y <= \hat{Y}) - q)|)
# """
# return 2 * np.sum(np.abs((forecast - target) * ((target <= forecast) - q)))

def quantile_loss(target: np.ndarray, forecast: np.ndarray, q: float) -> float:
r"""
.. math::
Expand Down

0 comments on commit bc4ed88

Please sign in to comment.