-
Notifications
You must be signed in to change notification settings - Fork 329
/
Copy pathchronos_bolt.py
640 lines (539 loc) · 23.1 KB
/
chronos_bolt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# Authors: Abdul Fatir Ansari <[email protected]>, Caner Turkmen <[email protected]>, Lorenzo Stella <[email protected]>
# Original source:
# https://github.com/autogluon/autogluon/blob/f57beb26cb769c6e0d484a6af2b89eab8aee73a8/timeseries/src/autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py
import copy
import logging
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoConfig
from transformers.models.t5.modeling_t5 import (
ACT2FN,
T5Config,
T5LayerNorm,
T5PreTrainedModel,
T5Stack,
)
from transformers.utils import ModelOutput
from .base import BaseChronosPipeline, ForecastType
# Module-level logger. Named after the module (not the file path) so it becomes a
# child of the package logger and honors configuration applied to parent loggers.
logger = logging.getLogger(__name__)
@dataclass
class ChronosBoltConfig:
    """
    Chronos-Bolt specific hyperparameters.

    Built from the ``chronos_config`` dictionary stored on the transformers config
    (``ChronosBoltConfig(**config.chronos_config)``).
    """

    # Maximum number of past steps fed to the encoder; longer inputs are
    # truncated to their most recent ``context_length`` values.
    context_length: int
    # Number of future steps produced by one forward pass of the model.
    prediction_length: int
    # Width of each input patch and the step between consecutive patches.
    input_patch_size: int
    input_patch_stride: int
    # Quantile levels emitted by the output head.
    quantiles: List[float]
    # When True, a learned [REG] token embedding is appended to the encoder input.
    use_reg_token: bool = False
@dataclass
class ChronosBoltOutput(ModelOutput):
    """
    Result container returned by ``ChronosBoltModelForForecasting.forward``.
    """

    # Training loss; stays None when ``forward`` is called without a target.
    loss: Optional[torch.Tensor] = None
    # Quantile forecasts mapped back to the scale of the input context.
    quantile_preds: Optional[torch.Tensor] = None
    # Attention outputs; not populated by ``forward`` in this module.
    attentions: Optional[torch.Tensor] = None
    cross_attentions: Optional[torch.Tensor] = None
class Patch(nn.Module):
    """
    Split the last dimension of a tensor into fixed-size (possibly overlapping) patches.

    The input is left-padded with NaN just enough that the final patch ends exactly
    on the last (most recent) element, so no observation is dropped by ``unfold``
    regardless of the relationship between ``patch_size`` and ``patch_stride``.
    When ``patch_stride == patch_size`` this is the same as padding the length up
    to the next multiple of ``patch_size``.
    """

    def __init__(self, patch_size: int, patch_stride: int) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.patch_stride = patch_stride

    def _left_pad_length(self, length: int) -> int:
        """Return how many NaNs to prepend so the last window ends on the last element."""
        if length < self.patch_size:
            # Too short for a single window: pad up to exactly one patch.
            return self.patch_size - length
        overhang = (length - self.patch_size) % self.patch_stride
        return (self.patch_stride - overhang) % self.patch_stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Patch ``x`` of shape (..., length) into shape (..., num_patches, patch_size).

        Padded positions are NaN so downstream code can treat them as missing.
        """
        pad_length = self._left_pad_length(x.shape[-1])
        if pad_length > 0:
            padding = torch.full(
                size=(*x.shape[:-1], pad_length),
                fill_value=torch.nan,
                dtype=x.dtype,
                device=x.device,
            )
            x = torch.concat((padding, x), dim=-1)
        return x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
class InstanceNorm(nn.Module):
    """
    Standardize each series along its last dimension (see also RevIN).

    Location and scale are estimated with NaN-aware reductions, so missing values
    do not influence the statistics.
    """

    def __init__(self, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        loc_scale: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Normalize ``x`` and return ``(normalized, (loc, scale))``.

        If ``loc_scale`` is supplied it is reused unchanged (for example, to
        normalize a target with the statistics of its context). Otherwise:
        ``loc`` is the NaN-ignoring mean (0 for an all-NaN row), ``scale`` is the
        NaN-ignoring root-mean-square deviation from ``loc`` (1 for an all-NaN row),
        and any zero ``scale`` is replaced by ``|loc| + eps``.
        """
        if loc_scale is not None:
            loc, scale = loc_scale
        else:
            loc = torch.nanmean(x, dim=-1, keepdim=True).nan_to_num(nan=0.0)
            squared_dev = (x - loc).square()
            rms_dev = torch.nanmean(squared_dev, dim=-1, keepdim=True).sqrt()
            scale = rms_dev.nan_to_num(nan=1.0)
            scale = torch.where(scale == 0, loc.abs() + self.eps, scale)
        normalized = (x - loc) / scale
        return normalized, (loc, scale)

    def inverse(
        self, x: torch.Tensor, loc_scale: Tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Map normalized values back to the original scale: ``loc + scale * x``."""
        loc, scale = loc_scale
        return loc + scale * x
class ResidualBlock(nn.Module):
    """
    Two-layer feed-forward block with a linear skip connection.

    Computes ``dropout(output_layer(act(hidden_layer(x)))) + residual_layer(x)``
    and, when ``use_layer_norm`` is set, passes the sum through ``T5LayerNorm``.
    The activation is looked up by name in ``ACT2FN``.
    """

    def __init__(
        self,
        in_dim: int,
        h_dim: int,
        out_dim: int,
        act_fn_name: str,
        dropout_p: float = 0.0,
        use_layer_norm: bool = False,
    ) -> None:
        super().__init__()
        # Submodules are created in a fixed order; reordering would change how the
        # default random parameter initialization consumes the RNG.
        self.dropout = nn.Dropout(dropout_p)
        self.hidden_layer = nn.Linear(in_dim, h_dim)
        self.act = ACT2FN[act_fn_name]
        self.output_layer = nn.Linear(h_dim, out_dim)
        self.residual_layer = nn.Linear(in_dim, out_dim)
        self.use_layer_norm = use_layer_norm
        if self.use_layer_norm:
            self.layer_norm = T5LayerNorm(out_dim)

    def forward(self, x: torch.Tensor):
        hidden = self.act(self.hidden_layer(x))
        projected = self.dropout(self.output_layer(hidden))
        combined = projected + self.residual_layer(x)
        return self.layer_norm(combined) if self.use_layer_norm else combined
class ChronosBoltModelForForecasting(T5PreTrainedModel):
    """
    Chronos-Bolt forecaster built on a T5 encoder/decoder stack.

    The context is instance-normalized, split into patches, and each patch
    (concatenated with its observation mask) is embedded by a residual MLP. The T5
    encoder processes the patch embeddings (plus an optional [REG] token), a
    single-step T5 decoder attends to them, and a residual MLP head produces
    ``num_quantiles * prediction_length`` values that are mapped back to the
    original scale.
    """

    _keys_to_ignore_on_load_missing = [
        r"input_patch_embedding\.",
        r"output_patch_embedding\.",
    ]
    _keys_to_ignore_on_load_unexpected = [r"lm_head.weight"]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: T5Config):
        assert hasattr(config, "chronos_config"), "Not a Chronos config file"

        super().__init__(config)
        self.model_dim = config.d_model

        self.chronos_config = ChronosBoltConfig(**config.chronos_config)

        # Only decoder_start_id (and optionally REG token)
        if self.chronos_config.use_reg_token:
            config.reg_token_id = 1

        config.vocab_size = 2 if self.chronos_config.use_reg_token else 1
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        # Input patch embedding layer. Each patch is concatenated with its mask in
        # `encode`, hence twice the patch size.
        self.input_patch_embedding = ResidualBlock(
            in_dim=self.chronos_config.input_patch_size * 2,
            h_dim=config.d_ff,
            out_dim=config.d_model,
            act_fn_name=config.dense_act_fn,
            dropout_p=config.dropout_rate,
        )

        # patching layer
        self.patch = Patch(
            patch_size=self.chronos_config.input_patch_size,
            patch_stride=self.chronos_config.input_patch_stride,
        )

        # instance normalization, also referred to as "scaling" in Chronos and GluonTS
        self.instance_norm = InstanceNorm()

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        self._init_decoder(config)

        self.num_quantiles = len(self.chronos_config.quantiles)
        quantiles = torch.tensor(self.chronos_config.quantiles, dtype=self.dtype)
        self.register_buffer("quantiles", quantiles, persistent=False)

        self.output_patch_embedding = ResidualBlock(
            in_dim=config.d_model,
            h_dim=config.d_ff,
            out_dim=self.num_quantiles * self.chronos_config.prediction_length,
            act_fn_name=config.dense_act_fn,
            dropout_p=config.dropout_rate,
        )

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def _init_weights(self, module):
        """Initialize the weights"""
        super()._init_weights(module)
        factor = self.config.initializer_factor
        if isinstance(module, self.__class__):
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, ResidualBlock):
            # NOTE(review): the std of hidden/residual layers is derived from the
            # *input patch* width for every ResidualBlock, including
            # output_patch_embedding whose in_dim is d_model. Kept unchanged so
            # training-from-scratch behavior is preserved.
            in_std = factor * ((self.chronos_config.input_patch_size * 2) ** -0.5)
            out_std = factor * (self.config.d_ff ** -0.5)
            # Order (weight then bias, hidden -> residual -> output) matches the
            # sequence of RNG draws of the original initialization.
            for layer, std in (
                (module.hidden_layer, in_std),
                (module.residual_layer, in_std),
                (module.output_layer, out_std),
            ):
                layer.weight.data.normal_(mean=0.0, std=std)
                if getattr(layer, "bias", None) is not None:
                    layer.bias.data.zero_()

    def encode(
        self, context: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[
        torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor
    ]:
        """
        Normalize, patch, embed and encode a batch of context series.

        Parameters
        ----------
        context
            Tensor of shape (batch_size, length); NaN marks missing values. Only the
            last ``chronos_config.context_length`` steps are used.
        mask
            Optional observation mask with the same shape as ``context`` (values
            > 0 mark observed steps). Defaults to the non-NaN positions of
            ``context``.

        Returns
        -------
        encoder_hidden_states, loc_scale, input_embeds, attention_mask
            The encoder's last hidden state, the ``(loc, scale)`` statistics used to
            normalize ``context``, the embeddings fed to the encoder (with the [REG]
            embedding appended when enabled) and the matching attention mask.
        """
        mask = (
            mask.to(context.dtype)
            if mask is not None
            else torch.isnan(context).logical_not().to(context.dtype)
        )

        batch_size, _ = context.shape
        context_length = self.chronos_config.context_length
        if context.shape[-1] > context_length:
            context = context[..., -context_length:]
            mask = mask[..., -context_length:]

        # scaling
        context, loc_scale = self.instance_norm(context)

        # The statistics above are computed in the input's precision (the pipeline
        # passes float32); only afterwards is the context cast to the model dtype.
        context = context.to(self.dtype)
        mask = mask.to(self.dtype)

        # patching
        patched_context = self.patch(context)
        # The patcher left-pads with NaN; treat padded positions as unobserved.
        patched_mask = torch.nan_to_num(self.patch(mask), nan=0.0)
        patched_context = torch.where(patched_mask > 0.0, patched_context, 0.0)
        # concat context and mask along patch dim
        patched_context = torch.cat([patched_context, patched_mask], dim=-1)

        # attention_mask = True if at least one item in the patch is observed
        attention_mask = patched_mask.sum(dim=-1) > 0  # (batch_size, num_patches)

        input_embeds = self.input_patch_embedding(patched_context)

        if self.chronos_config.use_reg_token:
            # Append [REG]
            reg_input_ids = torch.full(
                (batch_size, 1),
                self.config.reg_token_id,
                device=input_embeds.device,
            )
            reg_embeds = self.shared(reg_input_ids)
            input_embeds = torch.cat([input_embeds, reg_embeds], dim=-2)
            attention_mask = torch.cat(
                [
                    attention_mask.to(self.dtype),
                    torch.ones_like(reg_input_ids).to(self.dtype),
                ],
                dim=-1,
            )

        encoder_outputs = self.encoder(
            attention_mask=attention_mask,
            inputs_embeds=input_embeds,
        )

        return encoder_outputs[0], loc_scale, input_embeds, attention_mask

    def forward(
        self,
        context: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        target: Optional[torch.Tensor] = None,
        target_mask: Optional[torch.Tensor] = None,
    ) -> ChronosBoltOutput:
        """
        Predict quantiles for the next ``prediction_length`` steps.

        Parameters
        ----------
        context, mask
            See :meth:`encode`.
        target
            Optional future values of shape (batch_size, horizon) with
            ``horizon <= prediction_length``. When given, the quantile loss is
            computed in the normalized space of the context.
        target_mask
            Optional observation mask for ``target`` (bool or 0/1 values). NaN
            entries of ``target`` are always treated as unobserved.

        Returns
        -------
        ChronosBoltOutput
            ``quantile_preds`` of shape (batch_size, num_quantiles,
            prediction_length) in the original scale, and ``loss`` (None when no
            target is given).
        """
        batch_size = context.size(0)

        hidden_states, loc_scale, input_embeds, attention_mask = self.encode(
            context=context, mask=mask
        )
        sequence_output = self.decode(input_embeds, attention_mask, hidden_states)

        quantile_preds_shape = (
            batch_size,
            self.num_quantiles,
            self.chronos_config.prediction_length,
        )
        quantile_preds = self.output_patch_embedding(sequence_output).view(
            *quantile_preds_shape
        )

        loss = None
        if target is not None:
            loss = self._quantile_loss(quantile_preds, target, target_mask, loc_scale)

        # Unscale predictions
        quantile_preds = self.instance_norm.inverse(
            quantile_preds.view(batch_size, -1),
            loc_scale,
        ).view(*quantile_preds_shape)

        return ChronosBoltOutput(
            loss=loss,
            quantile_preds=quantile_preds,
        )

    def _quantile_loss(
        self,
        quantile_preds: torch.Tensor,
        target: torch.Tensor,
        target_mask: Optional[torch.Tensor],
        loc_scale: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Scalar quantile (pinball) loss of normalized predictions vs. ``target``."""
        loc, _ = loc_scale
        # normalize target with the context statistics (on their device), then
        # move it next to the predictions
        target, _ = self.instance_norm(target.to(loc.device), loc_scale)
        target = target.unsqueeze(1).to(quantile_preds.device)  # (batch, 1, horizon)
        assert self.chronos_config.prediction_length >= target.shape[-1]

        # A step counts as observed only if it is not NaN and, when a mask is
        # given, the mask marks it as observed (bool or numeric masks accepted).
        observed = ~torch.isnan(target)
        if target_mask is not None:
            observed = observed & target_mask.unsqueeze(1).to(
                device=quantile_preds.device, dtype=torch.bool
            )
        target_mask = observed
        target[~target_mask] = 0.0

        # pad target and target_mask if they are shorter than model's prediction_length
        prediction_length = self.chronos_config.prediction_length
        if prediction_length > target.shape[-1]:
            padding_shape = (
                *target.shape[:-1],
                prediction_length - target.shape[-1],
            )
            target = torch.cat([target, torch.zeros(padding_shape).to(target)], dim=-1)
            target_mask = torch.cat(
                [target_mask, torch.zeros(padding_shape).to(target_mask)], dim=-1
            )

        # 2 * |(y - q_hat) * (1{y <= q_hat} - tau)|, masked;
        # shape (batch_size, num_quantiles, prediction_length)
        loss = (
            2
            * torch.abs(
                (target - quantile_preds)
                * (
                    (target <= quantile_preds).float()
                    - self.quantiles.view(1, self.num_quantiles, 1)
                )
            )
            * target_mask.float()
        )
        # dim -2 is the quantile axis: average over quantile levels.
        loss = loss.mean(dim=-2)
        # dim -1 is the horizon axis: sum over the prediction horizon.
        # NOTE: this ordering fixes the loss scale used in training; keep it.
        loss = loss.sum(dim=-1)
        return loss.mean()  # Mean over batch

    def _init_decoder(self, config):
        """Create the T5 decoder stack (sharing the token embedding with the encoder)."""
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

    def decode(
        self,
        input_embeds,
        attention_mask,
        hidden_states,
        output_attentions=False,
    ):
        """
        Run a single decoder step from ``decoder_start_token_id``.

        Parameters
        ----------
        input_embeds: torch.Tensor
            Patched and embedded inputs. Shape (batch_size, num_patches[+1], d_model);
            only its batch size and device are used here.
        attention_mask: torch.Tensor
            Attention mask over the encoder positions, shape (batch_size,
            num_patches[+1]). As produced by ``encode`` it is a bool tensor, or a
            tensor in the model dtype when the [REG] token is used.
        hidden_states: torch.Tensor
            Hidden states returned by the encoder. Shape (batch_size,
            num_patches[+1], d_model).
        output_attentions: bool
            Forwarded to the decoder stack.

        Returns
        -------
        last_hidden_state
            Last hidden state returned by the decoder, of shape (batch_size, 1, d_model)
        """
        batch_size = input_embeds.shape[0]
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.config.decoder_start_token_id,
            device=input_embeds.device,
        )
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            output_attentions=output_attentions,
            return_dict=True,
        )

        return decoder_outputs.last_hidden_state  # sequence_outputs, b x 1 x d_model
class ChronosBoltPipeline(BaseChronosPipeline):
    """
    Forecasting pipeline around a :class:`ChronosBoltModelForForecasting`.

    Handles context preparation and truncation, device placement, and greedy
    unrolling for horizons longer than the model's native prediction length.
    """

    forecast_type: ForecastType = ForecastType.QUANTILES
    default_context_length: int = 2048

    def __init__(self, model: ChronosBoltModelForForecasting):
        super().__init__(inner_model=model)
        self.model = model

    @property
    def quantiles(self) -> List[float]:
        """Quantile levels the underlying model was trained to output."""
        return self.model.config.chronos_config["quantiles"]

    def _central_quantile_index(self) -> int:
        """Index of the training quantile level closest to 0.5 (first one on ties)."""
        distances = torch.abs(torch.tensor(self.quantiles) - 0.5)
        return int(distances.argmin())

    @torch.no_grad()
    def embed(
        self, context: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Get encoder embeddings for the given time series.

        Parameters
        ----------
        context
            Input series. This is either a 1D tensor, or a list
            of 1D tensors, or a 2D tensor whose first dimension
            is batch. In the latter case, use left-padding with
            ``torch.nan`` to align series of different lengths.

        Returns
        -------
        embeddings, loc_scale
            A tuple of two items: the encoder embeddings and the loc_scale,
            i.e., the mean and std of the original time series.
            The encoder embeddings are shaped (batch_size, num_patches + 1, d_model),
            where num_patches is the number of patches in the time series
            and the extra 1 is for the [REG] token (if used by the model).
        """
        context_tensor = self._prepare_and_validate_context(context=context)
        model_context_length = self.model.config.chronos_config["context_length"]

        if context_tensor.shape[-1] > model_context_length:
            context_tensor = context_tensor[..., -model_context_length:]

        context_tensor = context_tensor.to(
            device=self.model.device,
            dtype=torch.float32,
        )
        embeddings, loc_scale, *_ = self.model.encode(context=context_tensor)
        return embeddings.cpu(), (
            loc_scale[0].squeeze(-1).cpu(),
            loc_scale[1].squeeze(-1).cpu(),
        )

    def predict(  # type: ignore[override]
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        limit_prediction_length: bool = False,
    ) -> torch.Tensor:
        """
        Get forecasts for the given time series.

        Refer to the base method (``BaseChronosPipeline.predict``)
        for details on shared parameters.

        Additional parameters
        ---------------------
        limit_prediction_length
            Force prediction length smaller or equal than the
            built-in prediction length from the model. False by
            default. When true, fail loudly if longer predictions
            are requested, otherwise longer predictions are allowed.

        Returns
        -------
        torch.Tensor
            Forecasts of shape (batch_size, num_quantiles, prediction_length)
            where num_quantiles is the number of quantiles the model has been
            trained to output. For official Chronos-Bolt models, the value of
            num_quantiles is 9 for [0.1, 0.2, ..., 0.9]-quantiles.

        Raises
        ------
        ValueError
            When prediction_length is smaller than 1, or when
            limit_prediction_length is True and the prediction_length is
            greater than model's training prediction_length.
        """
        context_tensor = self._prepare_and_validate_context(context=context)

        model_context_length = self.model.config.chronos_config["context_length"]
        model_prediction_length = self.model.config.chronos_config["prediction_length"]
        if prediction_length is None:
            prediction_length = model_prediction_length

        if prediction_length < 1:
            raise ValueError(
                f"prediction_length must be a positive integer, got {prediction_length}"
            )

        if prediction_length > model_prediction_length:
            msg = (
                f"We recommend keeping prediction length <= {model_prediction_length}. "
                "The quality of longer predictions may degrade since the model is not optimized for it. "
            )
            if limit_prediction_length:
                msg += "You can turn off this check by setting `limit_prediction_length=False`."
                raise ValueError(msg)
            warnings.warn(msg)

        # We truncate the context here because otherwise batches with very long
        # context could take up large amounts of GPU memory unnecessarily.
        if context_tensor.shape[-1] > model_context_length:
            context_tensor = context_tensor[..., -model_context_length:]

        # TODO: We unroll the forecast of Chronos Bolt greedily with the full forecast
        # horizon that the model was trained with (i.e., 64). This results in variance collapsing
        # every 64 steps.
        context_tensor = context_tensor.to(
            device=self.model.device,
            dtype=torch.float32,
        )

        # The median-like forecast is fed back as context for the next chunk.
        central_idx = self._central_quantile_index()
        predictions = []
        remaining = prediction_length
        while remaining > 0:
            with torch.no_grad():
                prediction = self.model(
                    context=context_tensor,
                ).quantile_preds.to(context_tensor)

            predictions.append(prediction)
            remaining -= prediction.shape[-1]

            if remaining <= 0:
                break

            central_prediction = prediction[:, central_idx]
            context_tensor = torch.cat([context_tensor, central_prediction], dim=-1)

        return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
            dtype=torch.float32, device="cpu"
        )

    def predict_quantiles(
        self,
        context: Union[torch.Tensor, List[torch.Tensor]],
        prediction_length: Optional[int] = None,
        quantile_levels: Optional[List[float]] = None,
        **predict_kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Refer to the base method (``BaseChronosPipeline.predict_quantiles``).

        ``quantile_levels`` defaults to ``[0.1, 0.2, ..., 0.9]``.

        Returns
        -------
        quantiles, mean
            ``quantiles`` has shape (batch_size, prediction_length,
            len(quantile_levels)); ``mean`` has shape (batch_size,
            prediction_length) and is the forecast of the training quantile level
            closest to 0.5 (i.e. the median for official models).
        """
        if quantile_levels is None:
            quantile_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

        # shape (batch_size, prediction_length, len(training_quantile_levels))
        predictions = (
            self.predict(context, prediction_length=prediction_length, **predict_kwargs)
            .detach()
            .swapaxes(1, 2)
        )

        training_quantile_levels = self.quantiles

        if set(quantile_levels).issubset(set(training_quantile_levels)):
            # no need to perform intra/extrapolation
            quantiles = predictions[
                ..., [training_quantile_levels.index(q) for q in quantile_levels]
            ]
        else:
            # we rely on torch for interpolating quantiles if quantiles that
            # Chronos Bolt was trained on were not provided
            if min(quantile_levels) < min(training_quantile_levels) or max(
                quantile_levels
            ) > max(training_quantile_levels):
                logger.warning(
                    f"\tQuantiles to be predicted ({quantile_levels}) are not within the range of "
                    f"quantiles that Chronos-Bolt was trained on ({training_quantile_levels}). "
                    "Quantile predictions will be set to the minimum/maximum levels at which Chronos-Bolt "
                    "was trained on. This may significantly affect the quality of the predictions."
                )

            # TODO: this is a hack that assumes the model's quantiles during training (training_quantile_levels)
            # made up an equidistant grid along the quantile dimension. i.e., they were (0.1, 0.2, ..., 0.9).
            # While this holds for official Chronos-Bolt models, this may not be true in the future, and this
            # function may have to be revised.
            augmented_predictions = torch.cat(
                [predictions[..., [0]], predictions, predictions[..., [-1]]],
                dim=-1,
            )
            quantiles = torch.quantile(
                augmented_predictions,
                q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
                dim=-1,
            ).permute(1, 2, 0)
        # NOTE: the median is returned as the mean here. The level closest to 0.5 is
        # used (as in `predict`), so models without an exact 0.5 level still work.
        mean = predictions[:, :, self._central_quantile_index()]
        return quantiles, mean

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """
        Load the model, either from a local path or from the HuggingFace Hub.
        Supports the same arguments as ``AutoConfig`` and ``AutoModel``
        from ``transformers``.
        """

        config = AutoConfig.from_pretrained(*args, **kwargs)
        assert hasattr(config, "chronos_config"), "Not a Chronos config file"

        architectures = getattr(config, "architectures", None) or [None]
        architecture = architectures[0]
        class_ = globals().get(architecture) if architecture else None
        # Only accept actual model classes from this module; any other global named
        # by the config (e.g. this pipeline class or an imported module) is rejected.
        if not (
            isinstance(class_, type)
            and issubclass(class_, ChronosBoltModelForForecasting)
        ):
            logger.warning(
                f"Unknown architecture: {architecture}, defaulting to ChronosBoltModelForForecasting"
            )
            class_ = ChronosBoltModelForForecasting

        model = class_.from_pretrained(*args, **kwargs)
        return cls(model=model)