Description
Summary
Introduce improved skip connections to the Lensiformer architecture to help preserve fine-grained spatial features and improve gradient flow. The modification wraps each transformer block in a residual connection that adds the block's input tokens to its output. Because the first block's input is the initial tokenization, the original token representations are carried forward additively through every later block, which is intended to improve accuracy and stability.
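The gradient-flow argument can be illustrated with a deliberately simplified sketch. It assumes scalar, linear blocks of the form `x -> w * x`, which real transformer blocks are not; the per-block gains in `weights` are hypothetical values chosen only to make the effect visible.

```python
from math import prod

def plain_chain_gradient(weights):
    """d(output)/d(input) when each block computes x -> w * x."""
    return prod(weights)

def residual_chain_gradient(weights):
    """d(output)/d(input) when each block computes x -> x + w * x."""
    return prod(1.0 + w for w in weights)

# Hypothetical per-block gains, chosen small on purpose.
weights = [0.1] * 8

plain = plain_chain_gradient(weights)
residual = residual_chain_gradient(weights)

print(f"plain stack gradient:    {plain:.3e}")
print(f"residual stack gradient: {residual:.3f}")
```

In this toy setting the plain stack's gradient shrinks geometrically with depth, while the residual stack keeps an identity term in every factor, so the gradient does not collapse toward zero.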
Proposed Changes
- Store Initial Tokens: Before entering the transformer block loop, clone the tokens produced by the initial tokenization step to serve as the first skip connection.
- Residual Update: For each transformer block, compute the block output and add the stored skip tokens to it element-wise. Store the result as the skip connection for the next iteration.
- Propagation: Repeat this for all transformer blocks, so that features from the initial tokens are carried cumulatively through every block.
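The three steps above can be written as a recurrence. The notation is introduced here for illustration and does not appear in the code: $x_0$ denotes the initial tokens, $v$ the lens-corrected tokens passed as `value`, $F_\ell$ the $\ell$-th transformer block, and $L$ the number of blocks.

```latex
\[
\begin{aligned}
x_\ell &= x_{\ell-1} + F_\ell(x_{\ell-1}, v), \qquad \ell = 1, \dots, L \\
x_L &= x_0 + \sum_{\ell=1}^{L} F_\ell(x_{\ell-1}, v)
\end{aligned}
\]
```

The second line follows by unrolling the first: the initial tokens $x_0$ reach the final representation unchanged, and each block contributes an additive correction on top of them.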
Example Code Modification
Below is an updated snippet from lensiformer.py demonstrating the improved skip connections in the forward method:
batch_size = images.size(0)
# Tokenize input images into patches
initial_patches = self.initial_tokenizer(images.reshape(batch_size, 1, self.image_size, self.image_size))
# Encode images and patches
lens_corrected_images = self.encoder(images, initial_patches)
lens_corrected_patches = self.secondary_tokenizer(lens_corrected_images.reshape(batch_size, 1, self.image_size, self.image_size))
# Initialize skip connection using the initial tokenization
skip_patches = initial_patches.clone()
# Pass through transformer blocks with improved skip (residual) connections
for block in self.transformer_blocks:
    block_output = block(key=initial_patches, value=lens_corrected_patches)
    # Combine the block output with the skip using a residual connection
    initial_patches = block_output + skip_patches
    # Update skip_patches for the next iteration
    skip_patches = initial_patches.clone()
# Flatten the patches
flattened_patches = self.flatten_layer(initial_patches)
# Generate final predictions
final_predictions = self.feedforward_layer(flattened_patches)
return final_predictions
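One observation that may help during review: in the snippet above, `skip_patches` always holds the same values as `initial_patches` at the start of each iteration, so the loop appears equivalent to `x = x + block(x, v)`. The dependency-free sketch below checks this under stated assumptions: plain Python lists stand in for tensors, `list(...)` stands in for `.clone()`, and the two lambda blocks are hypothetical placeholders, not the real attention blocks.

```python
def add(a, b):
    """Element-wise sum of two equally sized token lists."""
    return [x + y for x, y in zip(a, b)]

def forward_with_skip_variable(tokens, values, blocks):
    """Mirrors the snippet above: explicit skip copy, refreshed per block."""
    skip = list(tokens)                    # stands in for .clone()
    for block in blocks:
        block_output = block(tokens, values)
        tokens = add(block_output, skip)   # residual addition
        skip = list(tokens)                # refresh skip for the next block
    return tokens

def forward_compact(tokens, values, blocks):
    """Same recurrence without the separate skip variable."""
    for block in blocks:
        tokens = add(tokens, block(tokens, values))
    return tokens

# Hypothetical stand-in blocks taking (key, value) like the real ones.
blocks = [
    lambda k, v: [0.5 * a + b for a, b in zip(k, v)],
    lambda k, v: [a * b for a, b in zip(k, v)],
]
tokens = [1.0, 2.0, 3.0]
values = [0.5, -1.0, 2.0]

explicit = forward_with_skip_variable(tokens, values, blocks)
compact = forward_compact(tokens, values, blocks)
print(explicit == compact)
```

In PyTorch, `block_output + skip_patches` is an out-of-place addition that returns a new tensor, so the `.clone()` calls are likely unnecessary for correctness as long as the blocks do not modify their inputs in place. Keeping the explicit `skip_patches` variable is still a reasonable choice if it makes the intent easier to read.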
Impact
By adding each block's input tokens back to its output, the initial token information is carried through every layer. The network is therefore expected to retain more detailed spatial features throughout its depth, potentially leading to improved accuracy and precision in gravitational lensing inversion tasks.
Please review the above changes and let me know if any further refinements are needed.