-
Notifications
You must be signed in to change notification settings - Fork 416
/
Copy pathpytorch_patched_decoder.py
796 lines (665 loc) · 25.7 KB
/
pytorch_patched_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pytorch version of patched decoder."""
import dataclasses
import math
from typing import List, Tuple
import torch
from torch import nn
import torch.nn.functional as F
def _create_quantiles() -> list[float]:
return [0.1, 0.25, 0.5, 0.75, 0.9]
@dataclasses.dataclass
class TimesFMConfig:
    """Config for initializing timesfm patched_decoder class.

    Attributes:
      num_layers: Number of stacked transformer (decoder) layers.
      num_heads: Number of query heads in each attention layer.
      num_kv_heads: Number of key/value heads; must divide ``num_heads``.
      hidden_size: Model width used by the input/output projections and the
        transformer layers.
      intermediate_size: Hidden width of the transformer MLPs and of the
        input/horizon residual blocks.
      head_dim: Size of each attention head.
      rms_norm_eps: Epsilon used inside the RMS normalization layers.
      patch_len: Number of time points grouped into one input patch.
      horizon_len: Number of future points produced per output patch.
      quantiles: Quantile levels predicted in addition to the mean.
      pad_val: Sentinel value; input points equal to it are treated as padding.
      tolerance: Tolerance for comparisons against ``pad_val`` / padding and
        the floor below which a standard deviation is replaced by 1.
      dtype: Name of the weight dtype.
      use_positional_embedding: Whether sinusoidal position embeddings are
        added to the patch embeddings.
    """

    num_layers: int = 20
    num_heads: int = 16
    num_kv_heads: int = 16
    hidden_size: int = 1280
    intermediate_size: int = 1280
    head_dim: int = 80
    rms_norm_eps: float = 1e-6
    patch_len: int = 32
    horizon_len: int = 128
    quantiles: List[float] = dataclasses.field(default_factory=_create_quantiles)
    pad_val: float = 1123581321.0
    tolerance: float = 1e-6
    # NOTE(review): "bfloat32" is not a torch dtype name and this field is not
    # read by any code visible in this module -- confirm the intended value.
    dtype: str = "bfloat32"
    use_positional_embedding: bool = True
def _masked_mean_std(
inputs: torch.Tensor,
padding: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculates mean and standard deviation of `inputs` across axis 1.
It excludes values where `padding` is 1.
Args:
inputs: A PyTorch tensor of shape [b, n, p].
padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.
Returns:
A tuple containing the mean and standard deviation.
We return the statistics of the first patch with more than three non-padded
values.
"""
# Selecting the first patch with more than 3 unpadded values.
pad_sum = torch.sum(1 - padding, dim=2)
def _get_patch_index(arr: torch.Tensor):
indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
return torch.where(row_sum == 0, arr.shape[1] - 1, indices)
patch_indices = _get_patch_index(pad_sum)
bidxs = torch.arange(inputs.shape[0])
arr = inputs[bidxs, patch_indices, :]
pad = padding[bidxs, patch_indices, :]
# Create a mask where padding is 0
mask = 1 - pad
# Calculate the number of valid elements
num_valid_elements = torch.sum(mask, dim=1)
num_valid_elements = torch.where(
num_valid_elements == 0,
torch.tensor(1,
dtype=num_valid_elements.dtype,
device=num_valid_elements.device),
num_valid_elements,
)
# Calculate the masked sum and squared sum
masked_sum = torch.sum(arr * mask, dim=1)
masked_squared_sum = torch.sum((arr * mask)**2, dim=1)
# Calculate the masked mean and standard deviation
masked_mean = masked_sum / num_valid_elements
masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
masked_var = torch.where(
masked_var < 0.0,
torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
masked_var,
)
masked_std = torch.sqrt(masked_var)
return masked_mean, masked_std
def _shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
"""Shifts rows of seq based on the first 0 in each row of the mask.
Args:
mask: mask tensor of shape [B, N]
seq: seq tensor of shape [B, N, P]
Returns:
Returns the shifted sequence.
"""
batch_size, num_seq, feature_dim = seq.shape
new_mask: torch.BoolTensor = mask == 0
# Use argmax to find the first True value in each row
indices = new_mask.to(torch.int32).argmax(dim=1)
# Handle rows with all zeros
indices[~new_mask.any(dim=1)] = -1
# Create index ranges for each sequence in the batch
idx_range = (torch.arange(num_seq).to(
seq.device).unsqueeze(0).unsqueeze(-1).expand(batch_size, -1,
feature_dim))
# Calculate shifted indices for each element in each sequence
shifted_idx = (idx_range - indices[:, None, None]) % num_seq
# Gather values from seq using shifted indices
shifted_seq = seq.gather(1, shifted_idx)
return shifted_seq
def get_large_negative_number(dtype: torch.dtype) -> torch.Tensor:
    """Return a 0-dim tensor of `dtype` equal to -0.7 times the dtype's max.

    Floating-point dtypes take their maximum from ``torch.finfo``; every
    other dtype uses ``torch.iinfo``.
    """
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return torch.tensor(-0.7 * info.max, dtype=dtype)
def apply_mask_to_logits(logits: torch.Tensor,
                         mask: torch.Tensor) -> torch.Tensor:
    """Applies a floating-point mask to a set of logits.

    A logit is kept where the corresponding `mask` entry is at least half of
    ``get_large_negative_number(logits.dtype)``; everywhere else it is
    replaced by that large negative number.

    Args:
      logits: A torch.Tensor of logit values.
      mask: A torch.Tensor of mask values broadcastable against `logits`;
        entries near the large negative number mark masked positions.

    Returns:
      Masked logits.
    """
    floor = get_large_negative_number(logits.dtype)
    keep = mask >= floor * 0.5
    return torch.where(keep, logits, floor)
def convert_paddings_to_mask(
        paddings: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Converts binary paddings to a logit mask ready to add to attention matrix.

    The input `paddings` tensor is not modified; a new tensor is returned.

    Args:
      paddings: binary torch.Tensor of shape [B, T], with 1 denoting padding
        token.
      dtype: data type of the input.

    Returns:
      A torch.Tensor of shape [B, 1, 1, T] ready to add to attention logits.
    """
    # `paddings[:, None, None, :]` is a view of the caller's tensor, so the
    # multiply must be out-of-place: an in-place `*=` would overwrite the
    # caller's paddings (StackedDecoder passes the same tensor on to the MLPs).
    return paddings[:, None, None, :] * get_large_negative_number(dtype)
def causal_mask(input_t: torch.Tensor) -> torch.Tensor:
    """Builds the additive causal attention mask for `input_t`.

    Args:
      input_t: A floating-point torch.Tensor whose dim 1 is the sequence
        length T (e.g. [B, T, D]).

    Returns:
      A tensor of shape [1, 1, T, T] on ``input_t.device`` with dtype
      ``input_t.dtype``: entry (i, j) is the large negative number when j > i
      (a future position) and 0 otherwise.
    """
    assert input_t.dtype.is_floating_point, input_t.dtype
    large_negative_number = get_large_negative_number(input_t.dtype)
    seq_len = input_t.shape[1]
    positions = torch.arange(seq_len)
    # is_future[i, j] is True when key j lies after query i.
    is_future = positions[None, :] > positions[:, None]
    mask = is_future.to(input_t.dtype) * large_negative_number
    return mask[None, None, :, :].to(input_t.device)
def merge_masks(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Merges 2 masks by taking their element-wise minimum.

    A log-scale mask is expected, but a 0/1 mask also works. When exactly one
    mask has a singleton query axis (dim 2), it is first expanded to
    [.., S, S] as ``min(mask[..., s], mask[..., t])``.

    Args:
      a: torch.Tensor of shape [1|B, 1, 1|T, S].
      b: torch.Tensor of shape [1|B, 1, 1|T, S].

    Returns:
      torch.Tensor of shape [1|B, 1, 1|T, S].
    """

    def _broadcast_key_mask(key_mask: torch.Tensor) -> torch.Tensor:
        # [.., 1, S] -> [.., S, S], keeping the smaller of the query-side and
        # key-side values at each position.
        query_side = key_mask.transpose(-1, -2)
        return torch.minimum(query_side, key_mask)

    needs_expand = a.shape[2] != b.shape[2]
    if needs_expand and a.shape[2] == 1:
        a = _broadcast_key_mask(a)
    elif needs_expand:
        assert b.shape[2] == 1
        b = _broadcast_key_mask(b)

    assert a.shape[1:] == b.shape[1:], f"a.shape={a.shape}, b.shape={b.shape}."
    return torch.minimum(a, b)
class ResidualBlock(nn.Module):
    """TimesFM residual block.

    Computes ``output_layer(silu(hidden_layer(x))) + residual_layer(x)``: a
    one-hidden-layer SiLU MLP plus a linear skip projection.
    """

    def __init__(
        self,
        input_dims,
        hidden_dims,
        output_dims,
    ):
        super().__init__()
        self.input_dims, self.hidden_dims, self.output_dims = (
            input_dims, hidden_dims, output_dims)
        # Creation order is kept as-is so random initialisation is unchanged.
        self.hidden_layer = nn.Linear(input_dims, hidden_dims)
        self.act = nn.SiLU()
        self.output_layer = nn.Linear(hidden_dims, output_dims)
        self.residual_layer = nn.Linear(input_dims, output_dims)

    def forward(self, x):
        skip = self.residual_layer(x)
        main = self.output_layer(self.act(self.hidden_layer(x)))
        return main + skip
class RMSNorm(torch.nn.Module):
    """Pax rms norm in pytorch.

    Normalises over the last axis by the root-mean-square and multiplies by a
    learned weight (or by ``1 + weight`` when `add_unit_offset` is set). The
    computation runs in float32 and the result is cast back to the input's
    dtype. The weight starts at zero.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = False,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normed = self._norm(x.float())
        gain = self.weight.float()
        if self.add_unit_offset:
            gain = 1 + gain
        return (normed * gain).type_as(x)
class TransformerMLP(nn.Module):
    """Pax transformer MLP in pytorch.

    LayerNorm -> Linear -> ReLU -> Linear, with the result zeroed at padded
    positions (when `paddings` is given) and added back onto the input.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.layer_norm = nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6)

    def forward(self, x, paddings=None):
        hidden = F.relu(self.gate_proj(self.layer_norm(x)))
        update = self.down_proj(hidden)
        if paddings is not None:
            # paddings is indexed as [B, T]; 1 marks a padded position.
            keep = 1.0 - paddings[:, :, None]
            update = update * keep
        return update + x
class TimesFMAttention(nn.Module):
"""Implements the attention used in TimesFM."""
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.hidden_size = hidden_size
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = nn.Parameter(
torch.empty((self.head_dim,), dtype=torch.float32),)
self.qkv_proj = nn.Linear(
self.hidden_size,
(self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)
def _per_dim_scaling(self, query: torch.Tensor) -> torch.Tensor:
# [batch_size, n_local_heads, input_len, head_dim]
r_softplus_0 = 1.442695041
softplus_func = torch.nn.Softplus()
scale = r_softplus_0 / math.sqrt(self.head_dim)
scale = scale * softplus_func(self.scaling)
return query * scale[None, None, None, :]
def forward(
self,
hidden_states: torch.Tensor,
mask: torch.Tensor,
kv_write_indices: torch.Tensor | None = None,
kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
hidden_states_shape = hidden_states.shape
assert len(hidden_states_shape) == 3
batch_size, input_len, _ = hidden_states_shape
qkv = self.qkv_proj(hidden_states)
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xq = self._per_dim_scaling(xq)
# Write new kv cache.
# [batch_size, input_len, n_local_kv_heads, head_dim]
if kv_cache is not None and kv_write_indices is not None:
k_cache, v_cache = kv_cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)
key = k_cache
value = v_cache
else:
key = xk
value = xv
if self.num_kv_heads != self.num_heads:
# [batch_size, max_seq_len, n_local_heads, head_dim]
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=2)
# [batch_size, n_local_heads, input_len, head_dim]
q = xq.transpose(1, 2)
# [batch_size, n_local_heads, max_seq_len, head_dim]
k = key.transpose(1, 2)
v = value.transpose(1, 2)
# [batch_size, n_local_heads, input_len, max_seq_len]
scores = torch.matmul(q, k.transpose(2, 3))
scores = scores + mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)
# [batch_size, n_local_heads, input_len, head_dim]
output = torch.matmul(scores, v)
# return scores, output.transpose(1, 2).contiguous()
# [batch_size, input_len, hidden_dim]
output = output.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
output = self.o_proj(output)
return scores, output
class TimesFMDecoderLayer(nn.Module):
    """Transformer layer.

    Pre-norm (RMSNorm) self-attention with a residual connection, followed by
    a TransformerMLP (which applies its own LayerNorm and residual add).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rms_norm_eps: float = 1e-6,
    ):
        super().__init__()
        self.self_attn = TimesFMAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
        )
        self.mlp = TransformerMLP(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
        )
        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: torch.Tensor,
        paddings: torch.Tensor,
        kv_write_indices: torch.Tensor | None = None,
        kv_cache: Tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Applies the layer.

        Args:
          hidden_states: [batch_size, input_len, hidden_size].
          mask: Additive attention mask passed to the attention block.
          paddings: Padding indicators passed to the MLP (1 = padded).
          kv_write_indices: Optional cache write positions for attention.
          kv_cache: Optional ``(k_cache, v_cache)`` for attention.

        Returns:
          ``(scores, hidden_states)``: the attention probabilities and the
          layer output, with the same shape as the input hidden states.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        scores, hidden_states = self.self_attn(
            hidden_states=hidden_states,
            mask=mask,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
        )
        hidden_states = residual + hidden_states

        # MLP (TransformerMLP adds its own residual connection).
        hidden_states = self.mlp(hidden_states, paddings=paddings)

        return scores, hidden_states
class StackedDecoder(nn.Module):
    """Stacked transformer layer.

    Runs `num_layers` TimesFMDecoderLayers in sequence under a combined
    causal + padding attention mask.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        num_layers: int,
        rms_norm_eps: float = 1e-6,
    ):
        super().__init__()

        self.layers = nn.ModuleList([
            TimesFMDecoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=head_dim,
                rms_norm_eps=rms_norm_eps,
            ) for _ in range(num_layers)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        paddings: torch.Tensor,
        kv_write_indices: torch.Tensor | None = None,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] | None = None,
    ) -> torch.Tensor:
        """Applies every layer in order.

        Args:
          hidden_states: [B, T, D] input embeddings.
          paddings: [B, T] padding indicators (1 = padded).
          kv_write_indices: Optional cache write positions, shared by all
            layers.
          kv_caches: Optional per-layer ``(k_cache, v_cache)`` list, indexed by
            layer position.

        Returns:
          The output hidden states of the last layer.
        """
        # [B, 1, 1, T] padding mask merged with the [1, 1, T, T] causal mask.
        padding_mask = convert_paddings_to_mask(paddings, hidden_states.dtype)
        atten_mask = causal_mask(hidden_states)
        mask = merge_masks(padding_mask, atten_mask)
        for i, layer in enumerate(self.layers):
            kv_cache = kv_caches[i] if kv_caches is not None else None
            _, hidden_states = layer(
                hidden_states=hidden_states,
                mask=mask,
                paddings=paddings,
                kv_write_indices=kv_write_indices,
                kv_cache=kv_cache,
            )
        return hidden_states
class PositionalEmbedding(torch.nn.Module):
    """Generates position embedding for a given 1-d sequence.

    Attributes:
      min_timescale: Start of the geometric index. Determines the periodicity of
        the added signal.
      max_timescale: End of the geometric index. Determines the frequency of the
        added signal.
      embedding_dims: Dimension of the embedding to be generated.
    """

    def __init__(
        self,
        embedding_dims: int,
        min_timescale: int = 1,
        max_timescale: int = 10_000,
    ) -> None:
        super().__init__()
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale
        self.embedding_dims = embedding_dims

    def forward(self, seq_length=None, position=None):
        """Generates a Tensor of sinusoids with different frequencies.

        Args:
          seq_length: an optional Python int defining the output sequence
            length; required if the `position` argument is not specified.
          position: [B, seq_length], optional position for each token in the
            sequence, only required when the sequence is packed.

        Returns:
          [B, seqlen, D] if `position` is specified, else [1, seqlen, D].
          The first D // 2 channels are sines, the next D // 2 cosines, and a
          zero channel is appended when D is odd.
        """
        if position is None:
            assert seq_length is not None
            # [1, seqlen]
            position = torch.arange(seq_length, dtype=torch.float32).unsqueeze(0)
        else:
            assert position.ndim == 2, position.shape

        num_timescales = self.embedding_dims // 2
        log_timescale_increment = math.log(
            float(self.max_timescale) / float(self.min_timescale)) / max(
                num_timescales - 1, 1)
        # Built on `position`'s device so caller-supplied positions on any
        # device work.
        inv_timescales = self.min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32,
                         device=position.device) * -log_timescale_increment)
        scaled_time = position.unsqueeze(2) * inv_timescales[None, None, :]
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        # Zero-pad the feature (last) axis so an odd embedding_dims is honoured.
        # F.pad's pad tuple starts from the last dimension.
        signal = F.pad(signal, (0, self.embedding_dims % 2))
        return signal
class PatchedTimeSeriesDecoder(nn.Module):
    """Patched time-series decoder.

    Input series are cut into patches of ``config.patch_len`` points,
    normalized, embedded, run through a causal stacked transformer, and
    projected to ``config.horizon_len`` future points per patch for the mean
    and each quantile in ``config.quantiles``.
    """

    def __init__(self, config: TimesFMConfig):
        super().__init__()
        self.config = config
        # Embeds each patch of values concatenated with its padding flags.
        self.input_ff_layer = ResidualBlock(
            input_dims=2 * config.patch_len,
            output_dims=config.hidden_size,
            hidden_dims=config.intermediate_size,
        )
        # Frequency embedding; accepts ids 0, 1 and 2.
        self.freq_emb = nn.Embedding(num_embeddings=3,
                                     embedding_dim=config.hidden_size)
        # Projects each patch embedding to horizon_len * (1 + #quantiles).
        self.horizon_ff_layer = ResidualBlock(
            input_dims=config.hidden_size,
            output_dims=config.horizon_len * (1 + len(config.quantiles)),
            hidden_dims=config.intermediate_size,
        )
        self.stacked_transformer = StackedDecoder(
            hidden_size=self.config.hidden_size,
            intermediate_size=self.config.intermediate_size,
            num_heads=self.config.num_heads,
            num_kv_heads=self.config.num_kv_heads,
            head_dim=self.config.head_dim,
            num_layers=self.config.num_layers,
            rms_norm_eps=self.config.rms_norm_eps,
        )
        if self.config.use_positional_embedding:
            self.position_emb = PositionalEmbedding(self.config.hidden_size)

    def _forward_transform(
        self, inputs: torch.Tensor, patched_pads: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Normalizes patched inputs with per-series statistics.

        Args:
          inputs: Patched series of shape [B, N, P].
          patched_pads: Padding flags of shape [B, N, P] (1 = padded).

        Returns:
          The normalized inputs (entries equal to ``pad_val`` keep that value)
          and the ``(mu, sigma)`` statistics, each of shape [B].
        """
        mu, sigma = _masked_mean_std(inputs, patched_pads)
        # Guard against dividing by a (near) zero standard deviation.
        sigma = torch.where(
            sigma < self.config.tolerance,
            torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
            sigma,
        )

        # Normalize each patch
        outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
        outputs = torch.where(
            torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
            torch.tensor(self.config.pad_val,
                         dtype=outputs.dtype,
                         device=outputs.device),
            outputs,
        )
        return outputs, (mu, sigma)

    def _reverse_transform(
            self, outputs: torch.Tensor, stats: tuple[torch.Tensor,
                                                      torch.Tensor]) -> torch.Tensor:
        """Undoes the normalization. Output is of shape [B, N, P, Q]."""
        mu, sigma = stats
        return outputs * sigma[:, None, None, None] + mu[:, None, None, None]

    def _preprocess_input(
        self,
        input_ts: torch.Tensor,
        input_padding: torch.Tensor,
    ) -> tuple[
            torch.Tensor,
            torch.Tensor,
            tuple[torch.Tensor, torch.Tensor] | None,
            torch.Tensor,
    ]:
        """Preprocess input for stacked transformer.

        Returns:
          ``(model_input [B, N, D], patched_padding [B, N], stats,
          patched_inputs [B, N, P])``.
        """
        # Reshape into patches. reshape (unlike view) also accepts
        # non-contiguous inputs and is still a view when possible.
        bsize = input_ts.shape[0]
        patched_inputs = input_ts.reshape(bsize, -1, self.config.patch_len)
        patched_pads = input_padding.reshape(bsize, -1, self.config.patch_len)

        # Replace padded points with 0.
        patched_inputs = torch.where(
            torch.abs(patched_pads - 1.0) < self.config.tolerance,
            torch.tensor(0.0,
                         dtype=patched_inputs.dtype,
                         device=patched_inputs.device),
            patched_inputs,
        )
        # Points equal to the pad sentinel are treated as padding too.
        patched_pads = torch.where(
            torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
            torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
            patched_pads,
        )
        patched_inputs, stats = self._forward_transform(patched_inputs,
                                                        patched_pads)

        # B x N x P
        patched_inputs = patched_inputs * (1.0 - patched_pads)
        concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
        # B x N x D
        model_input = self.input_ff_layer(concat_inputs)

        # A patch is marked padded only if every point in it is padded.
        patched_padding = torch.min(patched_pads, dim=-1)[0]
        if self.config.use_positional_embedding:
            pos_emb = self.position_emb(model_input.shape[1]).to(model_input.device)
            pos_emb = pos_emb.repeat(model_input.shape[0], 1, 1)
            pos_emb = _shift_padded_seq(patched_padding, pos_emb)
            model_input += pos_emb

        return model_input, patched_padding, stats, patched_inputs

    def _postprocess_output(
        self,
        model_output: torch.Tensor,
        num_outputs: int,
        stats: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Postprocess output of stacked transformer into B x N x H x Q."""

        # B x N x (H.Q)
        output_ts = self.horizon_ff_layer(model_output)

        b, n, _ = output_ts.shape
        output_ts = output_ts.view(b, n, self.config.horizon_len, num_outputs)

        return self._reverse_transform(output_ts, stats)

    def forward(
        self,
        input_ts: torch.Tensor,
        input_padding: torch.LongTensor,
        freq: torch.Tensor,
    ) -> torch.Tensor:
        """Returns forecasts of shape B x N x horizon_len x (1 + #quantiles)."""
        num_outputs = len(self.config.quantiles) + 1
        model_input, patched_padding, stats, _ = self._preprocess_input(
            input_ts=input_ts,
            input_padding=input_padding,
        )
        f_emb = self.freq_emb(freq)  # B x 1 x D
        model_input += f_emb
        model_output = self.stacked_transformer(model_input, patched_padding)

        output_ts = self._postprocess_output(model_output, num_outputs, stats)
        return output_ts

    def decode(
        self,
        input_ts: torch.Tensor,
        paddings: torch.Tensor,
        freq: torch.LongTensor,
        horizon_len: int,
        output_patch_len: int | None = None,
        max_len: int = 512,
        return_forecast_on_context: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Auto-regressive decoding without caching.

        Args:
          input_ts: input time-series and paddings. Time-series shape B x C.
          paddings: padding shape B x (C + H) where H is the prediction length.
          freq: frequency shape B x 1
          horizon_len: prediction length.
          output_patch_len: output length to be fetched from one step of
            auto-regressive decoding.
          max_len: maximum training context length.
          return_forecast_on_context: whether to return the model forecast on the
            context except the first input patch.

        Returns:
          Tuple of two forecasting results:
          - Point (mean) output predictions as a tensor with shape B x H'.
          - Full predictions (mean and quantiles) as a tensor with shape
            B x H' x (1 + # quantiles).
          In particular, if return_forecast_on_context is True, H' is H plus
          the forecastable context length, i.e. context_len - (first) patch_len.

        Raises:
          ValueError: if ``paddings.shape[1] != C + horizon_len``.
        """
        final_out = input_ts
        context_len = final_out.shape[1]
        full_outputs = []
        if paddings.shape[1] != final_out.shape[1] + horizon_len:
            raise ValueError(
                "Length of paddings must match length of input + horizon_len:"
                f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}")
        if output_patch_len is None:
            output_patch_len = self.config.horizon_len
        num_decode_patches = (horizon_len + output_patch_len -
                              1) // output_patch_len
        for step_index in range(num_decode_patches):
            current_padding = paddings[:, 0:final_out.shape[1]]
            input_ts = final_out[:, -max_len:]
            input_padding = current_padding[:, -max_len:]
            # B x N x horizon_len x (1 + #quantiles)
            fprop_outputs = self(input_ts, input_padding, freq)
            if return_forecast_on_context and step_index == 0:
                # For the first decoding step, collect the model forecast on the
                # context except the unavailable first input patch forecast:
                # the first patch_len outputs of patches 0..N-2.
                new_full_ts = fprop_outputs[:, :-1, :self.config.patch_len, :]
                # Flatten the *sliced* tensor (reshape: the slice is not
                # contiguous) to B x (N-1)*patch_len x (1 + #quantiles).
                new_full_ts = new_full_ts.reshape(new_full_ts.size(0), -1,
                                                  new_full_ts.size(3))

                full_outputs.append(new_full_ts)

            # (full batch, last patch, output_patch_len, index of mean forecast = 0)
            new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
            # (full batch, last patch, output_patch_len, all output indices)
            new_full_ts = fprop_outputs[:, -1, :output_patch_len, :]
            full_outputs.append(new_full_ts)
            final_out = torch.cat([final_out, new_ts], dim=-1)

        if return_forecast_on_context:
            # `full_outputs` indexing starts at after the first input patch.
            full_outputs = torch.cat(
                full_outputs,
                dim=1)[:, :(context_len - self.config.patch_len + horizon_len), :]
        else:
            # `full_outputs` indexing starts at the forecast horizon.
            full_outputs = torch.cat(full_outputs, dim=1)[:, 0:horizon_len, :]

        return (full_outputs[:, :, 0], full_outputs)