|
5 | 5 |
|
6 | 6 | import torch |
7 | 7 | from torch import nn |
| 8 | +from torch.nn import functional as F |
| 9 | + |
| 10 | + |
| 11 | +class LinearLayer(nn.Linear): |
| 12 | + def forward(self, x: torch.Tensor, output_subset: torch.Tensor | None = None) -> torch.Tensor: |
| 13 | + if output_subset is None: |
| 14 | + # x: (..., i) -> output: (..., o) |
| 15 | + return super().forward(x) |
| 16 | + elif output_subset.dim() == 1: |
| 18 | + # x: (..., i); output_subset: (o_subset,) holds indices of the requested output features
| 18 | + bias = self.bias[output_subset] if self.bias is not None else None # (o_subset) |
| 19 | + weight = self.weight[output_subset] # (o_subset, i) |
| 20 | + return F.linear(x, weight, bias) # (..., i) -> (..., o_subset) |
| 21 | + else: |
| 22 | + raise NotImplementedError("output_subset with more than one dimension is not supported")
| 23 | + |
8 | 24 |
|
9 | 25 | if TYPE_CHECKING: |
10 | 26 | from typing import Any |
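A quick sanity check of the new LinearLayer path (a minimal sketch, not part of the change itself; the sizes and the idx values are made up for illustration): slicing the weight rows before the matmul should give the same values as computing the full output and then indexing it.

import torch

layer = LinearLayer(in_features=4, out_features=6)
x = torch.randn(2, 4)
idx = torch.tensor([0, 3, 5])          # request only output features 0, 3 and 5

full = layer(x)                         # (2, 6)
subset = layer(x, output_subset=idx)    # (2, 3)

# Selecting weight rows before F.linear should equal indexing the full output.
assert torch.allclose(subset, full[:, idx])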
@@ -39,7 +55,7 @@ class StackedLinearLayer(nn.Module): |
39 | 55 | - Bias shape: (n_channels, out_features) if bias=True, None otherwise |
40 | 56 |
|
41 | 57 | The forward pass applies the transformation to each channel independently: |
42 | | - output[b, c, o] = sum_i(input[b, c, i] * weight[c, i, o]) + bias[c, o] |
| 58 | + output[b, c, o] = sum_i(x[b, c, i] * weight[c, i, o]) + bias[c, o] |
43 | 59 |
|
44 | 60 | This is equivalent to applying n_channels separate linear layers in parallel, |
45 | 61 | which is more efficient than using separate nn.Linear layers. |
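The docstring's claim that the stacked form is equivalent to n_channels separate linear layers can be illustrated with a small sketch (shapes here are arbitrary, not taken from the repository):

import torch

b, c, i, o = 2, 3, 4, 5
x = torch.randn(b, c, i)
weight = torch.randn(c, i, o)
bias = torch.randn(c, o)

# Stacked form: all channels in one einsum, bias broadcasts over the batch.
stacked = torch.einsum("bci,cio->bco", x, weight) + bias

# Loop form: one independent linear transform per channel.
looped = torch.stack([x[:, ch] @ weight[ch] + bias[ch] for ch in range(c)], dim=1)

assert torch.allclose(stacked, looped, atol=1e-6)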
@@ -137,13 +153,15 @@ def _init_bias(self) -> None: |
137 | 153 | bound = 1 / math.sqrt(fan_in) |
138 | 154 | nn.init.uniform_(self.bias, -bound, bound) |
139 | 155 |
|
140 | | - def forward(self, input: torch.Tensor) -> torch.Tensor: |
| 156 | + def forward(self, x: torch.Tensor, output_subset: torch.Tensor | None = None) -> torch.Tensor: |
141 | 157 | r"""Forward pass through the stacked linear layer. |
142 | 158 |
|
143 | 159 | Parameters |
144 | 160 | ---------- |
145 | | - input |
| 161 | + x |
146 | 162 | Input tensor with shape (batch_size, n_channels, in_features). |
| 163 | + output_subset |
| 164 | + Optional 1-D tensor of output feature indices; if given, only these outputs are computed and returned.
147 | 165 |
|
148 | 166 | Returns |
149 | 167 | ------- |
@@ -178,10 +196,18 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: |
178 | 196 | >>> output = layer(x) |
179 | 197 | >>> print(output.shape) # torch.Size([2, 3, 5]) |
180 | 198 | """ |
181 | | - mm = torch.einsum("bci,cio->bco", input, self.weight) |
182 | | - if self.bias is not None: |
183 | | - mm = mm + self.bias # They will broadcast well |
184 | | - return mm |
| 199 | + if output_subset is None or output_subset.dim() == 1: |
| 200 | + # weight: (c, i, o), bias: (c, o) |
| 201 | + # x: (b, c, i), output_subset: (o_subset) -> output: (b, c, o_subset) |
| 202 | + weight = self.weight if output_subset is None else self.weight[:, :, output_subset] # (c, i, o_subset) |
| 203 | + # slower: mm = torch.einsum("bci,cio->bco", x, weight) |
| 204 | + mm = torch.bmm(x.transpose(0, 1), weight).transpose(0, 1) # (b, c, o_subset) |
| 205 | + if self.bias is not None: |
| 206 | + bias = self.bias if output_subset is None else self.bias[:, output_subset] # (c, o_subset) |
| 207 | + mm = mm + bias # (b, c, o_subset) + (c, o_subset): bias broadcasts over the batch dimension
| 208 | + return mm |
| 209 | + else: |
| 210 | + raise NotImplementedError("output_subset with more than one dimension is not supported")
185 | 211 |
|
186 | 212 | def extra_repr(self) -> str: |
187 | 213 | """String representation for printing the layer. |
|
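The bmm-based path introduced in this forward should match the einsum it replaces, and slicing the weight along the output dimension should match indexing the full result; a minimal sketch under those assumptions (shapes are illustrative):

import torch

b, c, i, o = 2, 3, 4, 5
x = torch.randn(b, c, i)
weight = torch.randn(c, i, o)
idx = torch.tensor([0, 2])

# Old einsum path vs. new bmm path over all output features.
via_einsum = torch.einsum("bci,cio->bco", x, weight)
via_bmm = torch.bmm(x.transpose(0, 1), weight).transpose(0, 1)
assert torch.allclose(via_einsum, via_bmm, atol=1e-6)

# Slicing the output dimension of the weight first matches indexing afterwards.
sliced = torch.bmm(x.transpose(0, 1), weight[:, :, idx]).transpose(0, 1)
assert torch.allclose(sliced, via_bmm[:, :, idx], atol=1e-6)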