add support for inference sans positional encoding
siriuz42 committed Dec 12, 2024
1 parent 02bc2f2 commit 27f4037
Showing 4 changed files with 16 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/timesfm/patched_decoder.py
@@ -237,6 +237,7 @@ class PatchedTimeSeriesDecoder(base_layer.BaseLayer):
   stacked_transformer_params_tpl: LayerTpl = template_field(
       transformers.StackedTransformer)
   use_freq: bool = True
+  use_pos_emb: bool = True

   def setup(self) -> None:
     """Construct the model."""
@@ -332,16 +333,17 @@ def _preprocess_input(
     model_input = self.input_ff_layer(concat_inputs)
     # A patch should not be padded even if there is at least one zero.
     patched_padding = jnp.min(patched_pads, axis=-1)

-    if pos_emb is None:
-      position_emb = self.position_emb(seq_length=model_input.shape[1])
-    else:
-      position_emb = pos_emb
-    if self.do_eval:
-      if position_emb.shape[0] != model_input.shape[0]:
-        position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)
-      position_emb = _shift_padded_seq(patched_padding, position_emb)
-    model_input += position_emb
+    if self.use_pos_emb:
+      if pos_emb is None:
+        position_emb = self.position_emb(seq_length=model_input.shape[1])
+      else:
+        position_emb = pos_emb
+      if self.do_eval:
+        if position_emb.shape[0] != model_input.shape[0]:
+          position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)
+        position_emb = _shift_padded_seq(patched_padding, position_emb)
+      model_input += position_emb

     return model_input, patched_padding, stats, patched_inputs

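For intuition, the following is a minimal, self-contained JAX sketch of what the new use_pos_emb gate does during input preprocessing. It is illustrative only: it omits _shift_padded_seq and the do_eval handling of the real Pax layer, and the function and variable names here are not part of TimesFM.

  import jax.numpy as jnp

  def preprocess(model_input, position_emb, use_pos_emb=True):
    """Toy version of the gated positional-embedding addition.

    model_input:  [batch, num_patches, model_dims] patch embeddings.
    position_emb: [1, num_patches, model_dims] position table.
    """
    if not use_pos_emb:
      # New behavior: the transformer stack sees the patch embeddings
      # with no positional signal added.
      return model_input
    if position_emb.shape[0] != model_input.shape[0]:
      # Broadcast the single position table across the batch.
      position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)
    return model_input + position_emb

  x = jnp.ones((2, 8, 4))  # two series, eight patches, model_dims = 4
  pe = jnp.arange(8.0)[None, :, None] * jnp.ones((1, 8, 4))
  assert preprocess(x, pe, use_pos_emb=False).shape == (2, 8, 4)
  assert preprocess(x, pe, use_pos_emb=True).shape == (2, 8, 4)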
2 changes: 2 additions & 0 deletions src/timesfm/timesfm_base.py
@@ -109,6 +109,7 @@ class TimesFmHparams:
   per_core_batch_size: int = 32
   backend: Literal["cpu", "gpu", "tpu"] = "cpu"
   quantiles: Sequence[float] | None = DEFAULT_QUANTILES
+  use_positional_embedding: bool = True
   # Hparams beyond the model.
   point_forecast_mode: Literal["mean", "median"] = "median"

@@ -172,6 +173,7 @@ def __init__(self, hparams: TimesFmHparams,
     self.backend = hparams.backend
     self.quantiles = hparams.quantiles
     self.num_heads = hparams.num_heads
+    self.use_pos_emb = hparams.use_positional_embedding

     # Rewrite these values in __post_init__ for SPMD.
     self.num_cores = 1
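At the user-facing level, the same switch is now exposed as a TimesFmHparams field. A hedged usage sketch follows, assuming TimesFmHparams is importable from the timesfm package; the TimesFm constructor call and checkpoint loading are elided, since only the hparams part of the signature is shown above.

  import timesfm  # assumes TimesFmHparams is re-exported at the package level

  hparams = timesfm.TimesFmHparams(
      backend="cpu",
      per_core_batch_size=32,
      use_positional_embedding=False,  # new flag: run inference without positional encoding
  )
  # hparams is then passed to the TimesFm constructor together with a
  # checkpoint; checkpoint handling is unchanged by this commit.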
1 change: 1 addition & 0 deletions src/timesfm/timesfm_jax.py
@@ -117,6 +117,7 @@ def load_from_checkpoint(
         residual_block_tpl=pax_fiddle.Config(patched_decoder.ResidualBlock),
         quantiles=self.quantiles,
         use_freq=True,
+        use_pos_emb=self.use_pos_emb,
         stacked_transformer_params_tpl=pax_fiddle.Config(
             transformers.StackedTransformer,
             num_heads=self.num_heads,
1 change: 1 addition & 0 deletions src/timesfm/timesfm_torch.py
@@ -40,6 +40,7 @@ def __post_init__(self):
         horizon_len=self.output_patch_len,
         head_dim=self.model_dims // self.num_heads,
         quantiles=self.quantiles,
+        use_positional_embedding=self.use_pos_emb,
     )
     self._model = None
     self.num_cores = 1
