Commit f56fd59

Allow subset reconstruction (#74)
* Implementation for randomized subset reconstruction
* Code Improvement

1 parent a3efe5d commit f56fd59

File tree

11 files changed (+440, -161 lines)

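Big picture: this commit adds optional `output_subset` arguments to the linear layers and a small gradient-scaling module, so a decoder can reconstruct only a randomly chosen subset of output features per step. The sketch below is purely illustrative and only assumes the `LinearLayer` added in `src/drvi/nn_modules/layer/linear_layer.py` (see its diff further down); the random sampling and the MSE loss are placeholder assumptions, not drvi's actual training loop.

```python
# Illustrative sketch only: reconstructing a random subset of outputs with the
# new LinearLayer. Sampling strategy and loss are assumptions for illustration.
import torch
from drvi.nn_modules.layer.linear_layer import LinearLayer

n_latent, n_genes, batch = 32, 2000, 8
decoder_out = LinearLayer(n_latent, n_genes)          # output layer added in this commit

z = torch.randn(batch, n_latent)                      # latent codes (placeholder)
subset = torch.randperm(n_genes)[:256]                # random subset of output features
x_subset_hat = decoder_out(z, output_subset=subset)   # (batch, 256), only the subset is computed

x_true = torch.randn(batch, n_genes)                  # placeholder targets
loss = torch.nn.functional.mse_loss(x_subset_hat, x_true[:, subset])
```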

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -14,6 +14,12 @@ and this project adheres to [Semantic Versioning][].
 ### Added
 
 - Add support for Python 3.13
+- Allow subset reconstruction
+- Allow gradient scaling in the last layer
+
+### Changed
+
+- Minor code improvements
 
 ## [0.2.0] - 2025-11-24
 
```
src/drvi/nn_modules/gradients.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -0,0 +1,28 @@
+import torch
+
+
+class GradScale(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, scale):
+        ctx.scale = scale
+        return x  # forward pass unchanged
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output * ctx.scale, None  # scale gradient only
+
+
+def grad_scale(x, scale):
+    return GradScale.apply(x, scale)
+
+
+class GradientScaler(torch.nn.Module):
+    def __init__(self, scale: float):
+        super().__init__()
+        self.register_buffer("scale", torch.tensor(scale, dtype=torch.float32))
+
+    def forward(self, x):
+        return grad_scale(x, self.scale)
+
+    def extra_repr(self):
+        return f"scale={self.scale.item()}"
```
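A quick check of the new module (a usage sketch, not part of the commit): the forward output is identical to the input, while gradients flowing back through the module are multiplied by `scale`.

```python
# Minimal sanity check for GradientScaler: forward is the identity,
# backward multiplies the incoming gradient by `scale`.
import torch
from drvi.nn_modules.gradients import GradientScaler

scaler = GradientScaler(scale=0.1)
x = torch.ones(3, requires_grad=True)
y = scaler(x)          # forward unchanged: y has the same values as x
y.sum().backward()
print(x.grad)          # tensor([0.1000, 0.1000, 0.1000]) -- scaled gradient
```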

src/drvi/nn_modules/layer/factory.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -4,7 +4,7 @@
 
 from torch import nn
 
-from drvi.nn_modules.layer.linear_layer import StackedLinearLayer
+from drvi.nn_modules.layer.linear_layer import LinearLayer, StackedLinearLayer
 from drvi.nn_modules.layer.structures import SimpleResidual
 
 if TYPE_CHECKING:
@@ -145,7 +145,7 @@ def get_normal_layer(
         if intermediate_layer is None:
             intermediate_layer = True
         if intermediate_layer and self.intermediate_arch == "FC":
-            layer = nn.Linear(d_in, d_out, bias)
+            layer = LinearLayer(d_in, d_out, bias)
         elif (not intermediate_layer) or self.intermediate_arch == "SAME":
             layer = self._get_normal_layer(d_in, d_out, bias=True, **kwargs)
         else:
@@ -243,7 +243,7 @@ class FCLayerFactory(LayerFactory):
     Notes
     -----
     This factory creates:
-    - Normal layers: `nn.Linear` layers
+    - Normal layers: `LinearLayer` layers
     - Stacked layers: `StackedLinearLayer` for processing multiple splits
 
     The "SAME" and "FC" architectures are equivalent for this factory since
@@ -264,7 +264,7 @@ class FCLayerFactory(LayerFactory):
     def __init__(self, intermediate_arch: Literal["SAME", "FC"] = "SAME", residual_preferred: bool = False) -> None:
         super().__init__(intermediate_arch=intermediate_arch, residual_preferred=residual_preferred)
 
-    def _get_normal_layer(self, d_in: int, d_out: int, bias: bool = True, **kwargs: Any) -> nn.Linear:
+    def _get_normal_layer(self, d_in: int, d_out: int, bias: bool = True, **kwargs: Any) -> LinearLayer:
         """Create a fully connected layer.
 
         Parameters
@@ -280,7 +280,7 @@ def _get_normal_layer(self, d_in: int, d_out: int, bias: bool = True, **kwargs:
 
         Returns
         -------
-        nn.Linear
+        LinearLayer
             A fully connected linear layer.
 
         Examples
@@ -290,7 +290,7 @@ def _get_normal_layer(self, d_in: int, d_out: int, bias: bool = True, **kwargs:
         >>> print(layer.weight.shape)  # torch.Size([128, 64])
         >>> print(layer.bias.shape)  # torch.Size([128])
         """
-        return nn.Linear(d_in, d_out, bias=bias)
+        return LinearLayer(d_in, d_out, bias=bias)
 
     def _get_stacked_layer(
         self, d_channel: int, d_in: int, d_out: int, bias: bool = True, **kwargs: Any
```
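The practical effect of this file's change: layers produced by `FCLayerFactory` are still `nn.Linear` instances (since `LinearLayer` subclasses it) but now also accept `output_subset`. The sketch below is an assumed usage, mirroring the factory's own docstring example and calling the internal `_get_normal_layer` helper directly for illustration.

```python
# Sketch: FCLayerFactory now yields LinearLayer, a drop-in nn.Linear
# replacement whose forward optionally computes only a subset of outputs.
import torch
from torch import nn
from drvi.nn_modules.layer.factory import FCLayerFactory

layer = FCLayerFactory()._get_normal_layer(64, 128)
print(isinstance(layer, nn.Linear))                            # True

x = torch.randn(2, 64)
print(layer(x).shape)                                          # torch.Size([2, 128])
print(layer(x, output_subset=torch.tensor([0, 5, 7])).shape)   # torch.Size([2, 3])
```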

src/drvi/nn_modules/layer/linear_layer.py

Lines changed: 33 additions & 7 deletions
```diff
@@ -5,6 +5,22 @@
 
 import torch
 from torch import nn
+from torch.nn import functional as F
+
+
+class LinearLayer(nn.Linear):
+    def forward(self, x: torch.Tensor, output_subset: torch.Tensor | None = None) -> torch.Tensor:
+        if output_subset is None:
+            # x: (..., i) -> output: (..., o)
+            return super().forward(x)
+        elif output_subset.dim() == 1:
+            # x: (..., i) -> output_subset: (o_subset)
+            bias = self.bias[output_subset] if self.bias is not None else None  # (o_subset)
+            weight = self.weight[output_subset]  # (o_subset, i)
+            return F.linear(x, weight, bias)  # (..., i) -> (..., o_subset)
+        else:
+            raise NotImplementedError()
+
 
 if TYPE_CHECKING:
     from typing import Any
@@ -39,7 +55,7 @@ class StackedLinearLayer(nn.Module):
     - Bias shape: (n_channels, out_features) if bias=True, None otherwise
 
     The forward pass applies the transformation to each channel independently:
-        output[b, c, o] = sum_i(input[b, c, i] * weight[c, i, o]) + bias[c, o]
+        output[b, c, o] = sum_i(x[b, c, i] * weight[c, i, o]) + bias[c, o]
 
     This is equivalent to applying n_channels separate linear layers in parallel,
     which is more efficient than using separate nn.Linear layers.
@@ -137,13 +153,15 @@ def _init_bias(self) -> None:
         bound = 1 / math.sqrt(fan_in)
         nn.init.uniform_(self.bias, -bound, bound)
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, output_subset: torch.Tensor | None = None) -> torch.Tensor:
         r"""Forward pass through the stacked linear layer.
 
         Parameters
         ----------
-        input
+        x
             Input tensor with shape (batch_size, n_channels, in_features).
+        output_subset
+            Subset of outputs to provide in the output.
 
         Returns
         -------
@@ -178,10 +196,18 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         >>> output = layer(x)
         >>> print(output.shape)  # torch.Size([2, 3, 5])
         """
-        mm = torch.einsum("bci,cio->bco", input, self.weight)
-        if self.bias is not None:
-            mm = mm + self.bias  # They will broadcast well
-        return mm
+        if output_subset is None or output_subset.dim() == 1:
+            # weight: (c, i, o), bias: (c, o)
+            # x: (b, c, i), output_subset: (o_subset) -> output: (b, c, o_subset)
+            weight = self.weight if output_subset is None else self.weight[:, :, output_subset]  # (c, i, o_subset)
+            # slower: mm = torch.einsum("bci,cio->bco", x, weight)
+            mm = torch.bmm(x.transpose(0, 1), weight).transpose(0, 1)  # (b, c, o_subset)
+            if self.bias is not None:
+                bias = self.bias if output_subset is None else self.bias[:, output_subset]  # (c, o_subset)
+                mm = mm + bias  # They (bco, co) will broadcast well
+            return mm
+        else:
+            raise NotImplementedError()
 
     def extra_repr(self) -> str:
         """String representation for printing the layer.
```
