Hi everyone, I'm trying to use jaxtyping to make my code easier to debug, but I ran into a problem I can't fix. I wanted jaxtyping to check a single method in a class, and I followed the instructions from here: link. However, I got this error:
"TypeError: <module '__main__'> is a built-in module".
I don't understand why this fails, since I followed the tutorial. The full code is below. Thank you.
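For context, the wording of this error has the same form as the message Python's standard `inspect.getfile` raises for a module object that has no `__file__` attribute, which is the case for `__main__` in a notebook or REPL. The sketch below only reproduces that message with a stand-in module; `fake_main` and `describe_getfile_failure` are hypothetical names, and the idea that the type checker reaches this code path while looking up the function's source is an assumption, not something confirmed here.

```python
import inspect
import types

# Stand-in for an interactive session's __main__: a module object without __file__.
# (Hypothetical; a real notebook or REPL provides its own __main__ module.)
fake_main = types.ModuleType("__main__")


def describe_getfile_failure(module):
    """Return the message inspect.getfile raises for a module lacking __file__."""
    try:
        inspect.getfile(module)
    except TypeError as exc:
        return str(exc)
    return None  # getfile succeeded, so there is nothing to report


message = describe_getfile_failure(fake_main)
print(message)
```

If that assumption holds, the same class would behave differently when saved to a `.py` file and run as a script, because `__main__` then has a `__file__`.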
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from jaxtyping import Array, Float, PyTree, jaxtyped
from typeguard import typechecked as typechecker
from dataclasses import dataclass
# Use the GPU when available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class PatchEmbeddings(nn.Module):
    def __init__(self, img_size=96, patch_size=16, hidden_dim=512):
        super().__init__()
        # Store the input image size
        self.img_size = img_size
        # Store the size of each patch
        self.patch_size = patch_size
        # Calculate the total number of patches
        self.num_patches = (img_size // patch_size) ** 2
        # Create a convolutional layer to extract patch embeddings
        # in_channels=3 assumes the input image has 3 color channels (RGB)
        # out_channels=hidden_dim sets the number of output channels to match the hidden dimension
        # kernel_size=patch_size and stride=patch_size ensure each patch is embedded separately
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

    @jaxtyped(typechecker=typechecker)
    def forward(self, X: Float[torch.Tensor, "B C H H"]) -> Float[torch.Tensor, "B D N"]:
        # X: (B, C, img_size, img_size)
        # Returns: (B, hidden_dim, num_patches), so the annotation has three axes (B, D, N)
        # Extract patch embeddings from the input image
        # set_trace()
        X = self.conv(X)  # (B, hidden_dim, img_size // patch_size, img_size // patch_size)
        # set_trace()
        # Flatten the two spatial dimensions (height and width) into a single patch dimension
        X = X.flatten(2)  # (B, hidden_dim, self.num_patches)
        # set_trace()
        # Transposing would give [batch_size, num_patches, hidden_dim]; currently disabled
        # X = X.transpose(1, 2)  # (B, self.num_patches, hidden_dim)
        return X
# Testing
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_embeddings = PatchEmbeddings(img_size, patch_size, num_hiddens)
X = torch.zeros(batch_size, 3, img_size, img_size)
patch_embeddings(X).shape
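The shape comments in `forward` can be checked by hand without running torch. This is a minimal sketch, assuming `kernel_size == stride` and no padding, as in `self.conv` above; `patch_embedding_shape` is a hypothetical helper, not part of torch or jaxtyping.

```python
def patch_embedding_shape(batch_size, img_size=96, patch_size=16, hidden_dim=512):
    """Shape after a conv with kernel_size == stride == patch_size, then flatten(2)."""
    # With no padding, each spatial side yields img_size // patch_size patches.
    patches_per_side = img_size // patch_size
    # flatten(2) merges the two spatial axes into a single patch axis.
    num_patches = patches_per_side ** 2
    return (batch_size, hidden_dim, num_patches)


shape = patch_embedding_shape(batch_size=4)
print(shape)  # (4, 512, 36)
```

With the test values above, the returned tensor has three axes, so a return annotation for `forward` needs to describe three axes rather than four.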