
TypeError: <module '__main__'> is a built-in module #340

@dangmanhtruong1995

Description


Hi everyone, I'm trying to use jaxtyping to make my code easier to debug, but I've run into a problem I can't seem to fix. I was trying to use jaxtyping to check one method in a class, following the instructions from here: link. However, I got this error:

TypeError: <module '__main__'> is a built-in module

I don't know why I get this error even though I followed the tutorial. The code is below. Thank you.

import torchvision.transforms as transforms
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

from jaxtyping import Array, Float, PyTree, jaxtyped
from typeguard import typechecked as typechecker
from dataclasses import dataclass

# Ensure every computation happens on the GPU when available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class PatchEmbeddings(nn.Module):
    def __init__(self, img_size=96, patch_size=16, hidden_dim=512):
        super().__init__()

        # Store the input image size
        self.img_size = img_size

        # Store the size of each patch
        self.patch_size = patch_size

        # Calculate the total number of patches
        self.num_patches = (img_size // patch_size) ** 2

        # Create a convolutional layer to extract patch embeddings
        # in_channels=3 assumes the input image has 3 color channels (RGB)
        # out_channels=hidden_dim sets the number of output channels to match the hidden dimension
        # kernel_size=patch_size and stride=patch_size ensure each patch is separately embedded
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

    @jaxtyped(typechecker=typechecker)
    def forward(self, X: Float[torch.Tensor, "B C H W"]) -> Float[torch.Tensor, "B hidden_dim num_patches"]:
        # X: (B, C, img_size, img_size)
        
        # Extract patch embeddings from the input image
        X = self.conv(X) # (B, hidden_dim, img_size // patch_size, img_size // patch_size)

        # Flatten the spatial dimensions (height and width) of the patch embeddings
        # This step flattens the patch dimensions into a single dimension
        X = X.flatten(2) # (B, hidden_dim, self.num_patches)

        # Transpose the dimensions to obtain the shape [batch_size, num_patches, hidden_dim]
        # This step brings the num_patches dimension to the second position
        # X = X.transpose(1, 2) # (B, self.num_patches, hidden_dim)

        return X

# Testing
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_embeddings = PatchEmbeddings(img_size, patch_size, num_hiddens)
X = torch.zeros(batch_size, 3, img_size, img_size)
print(patch_embeddings(X).shape)
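
In case it helps, here is a stripped-down sketch that keeps only the decorator pattern and drops the rest of the model (the add_one function and its shapes are made up just for illustration). I assume it should hit the same error if the problem is with the decorator itself rather than with my class:

import torch
from jaxtyping import Float, jaxtyped
from typeguard import typechecked as typechecker

# Minimal function decorated the same way as PatchEmbeddings.forward above
@jaxtyped(typechecker=typechecker)
def add_one(x: Float[torch.Tensor, "batch dim"]) -> Float[torch.Tensor, "batch dim"]:
    # Element-wise addition keeps the shape, so the annotations should be satisfied
    return x + 1

print(add_one(torch.zeros(4, 8)).shape)  # expected: torch.Size([4, 8])

If this minimal version runs fine, then presumably the problem is something specific to how I set up the class-based code.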
