diff --git a/src/timesfm_torch/pytorch_patched_decoder.py b/src/timesfm_torch/pytorch_patched_decoder.py index 059a906..eb10be5 100644 --- a/src/timesfm_torch/pytorch_patched_decoder.py +++ b/src/timesfm_torch/pytorch_patched_decoder.py @@ -260,18 +260,16 @@ def __init__( self.output_dims = output_dims # Hidden Layer - self.hidden_layer = nn.Sequential( - nn.Linear(input_dims, hidden_dims), - nn.SiLU(), - ) - + self.hidden_layer = nn.Linear(input_dims, hidden_dims) + # Activation Function + self.act = nn.SiLU() # Output Layer self.output_layer = nn.Linear(hidden_dims, output_dims) # Residual Layer self.residual_layer = nn.Linear(input_dims, output_dims) def forward(self, x): - hidden = self.hidden_layer(x) + hidden = self.act(self.hidden_layer(x)) output = self.output_layer(hidden) residual = self.residual_layer(x) return output + residual