Skip to content

unet3d - koila.errors.UnsupportedError #22

@etienne87

Description

@etienne87

I am trying to apply koila's lazy evaluation to a 3D U-Net (`Unet3D`).

# defining the model
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv3(in_channels, out_channels, stride, norm='BatchNorm3d', act='GELU'):
    """Build a 3x3x3 convolution followed by a normalization and an activation.

    Args:
        in_channels: Number of channels of the input volume.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution, in any form ``nn.Conv3d`` accepts.
        norm: Name of the ``torch.nn`` normalization class; it is instantiated
            with ``out_channels`` as its only argument.
        act: Name of the ``torch.nn`` activation class; it is instantiated
            without arguments.

    Returns:
        An ``nn.Sequential`` holding the convolution, the normalization layer
        and the activation, in that order.
    """
    return nn.Sequential(
        # Forward `stride` to the convolution: it used to be hard-coded to 1,
        # so callers asking for a strided block never got one.
        nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
        getattr(nn, norm)(out_channels),
        getattr(nn, act)(),
    )


def double_conv3(in_channels, out_channels, stride):
    """Chain two ``conv3`` blocks into a single ``nn.Sequential``.

    The first block maps ``in_channels`` to ``out_channels`` and is requested
    with stride 1; the second keeps ``out_channels`` and is requested with
    the given ``stride``.
    """
    first = conv3(in_channels, out_channels, 1)
    second = conv3(out_channels, out_channels, stride)
    return nn.Sequential(first, second)

def merge_skip(x, skip):
    """Resize ``x`` to the spatial size of ``skip`` and concatenate both.

    Args:
        x: Tensor to resize; its last three dimensions are interpolated.
        skip: Tensor whose last three dimensions give the target size.

    Returns:
        The resized ``x`` concatenated with ``skip`` along dimension 1
        (``x`` first, then ``skip``).
    """
    # F.upsample is deprecated and warns on every call; F.interpolate is the
    # documented equivalent and takes the same arguments.
    x = F.interpolate(x, size=skip.shape[-3:], mode='trilinear', align_corners=True)
    return torch.cat((x, skip), dim=1)



class Unet3D(nn.Module):
    """Encoder/decoder network built from ``double_conv3`` blocks.

    Every encoder output is kept, and each decoder consumes the latest
    feature map merged (via ``merge_skip``) with an earlier one.
    """

    def __init__(self, in_channels, out_channels, num_layers=4, base=16):
        """Create ``num_layers`` encoders and ``num_layers`` decoders.

        Args:
            in_channels: Channel count of the network input.
            out_channels: Channel count of the last decoder's output.
            num_layers: Number of encoder blocks (and of decoder blocks).
            base: Width of the first encoder; each following encoder
                doubles it.
        """
        super().__init__()

        widths = [base * 2 ** level for level in range(num_layers)]
        enc_channels = [in_channels] + widths
        dec_channels = widths[::-1] + [out_channels]

        # Encoders are built first, in order, each requested with stride 2.
        self.encoders = nn.ModuleList(
            double_conv3(cin, cout, 2)
            for cin, cout in zip(enc_channels, enc_channels[1:])
        )

        # Decoder i receives the previous output concatenated with the
        # feature map that fed encoder (num_layers - 1 - i), so its input
        # width is the sum of both channel counts.
        skip_widths = enc_channels[-2::-1]
        self.decoders = nn.ModuleList(
            double_conv3(cin_skip + cin_up, cout, 1)
            for cin_skip, cin_up, cout in zip(
                skip_widths, dec_channels, dec_channels[1:]
            )
        )

    def forward(self, x, return_all=False):
        """Run the encoders, then the decoders with skip connections.

        Args:
            x: Input tensor fed to the first encoder.
            return_all: When true, return the list holding the input, every
                encoder output and every decoder output, in that order.

        Returns:
            The last decoder output, or the full list when ``return_all``
            is true.
        """
        features = [x]
        for encoder in self.encoders:
            features.append(encoder(features[-1]))

        # Fixed before decoding: skip indices are relative to the encoder
        # part of the list, which keeps growing below.
        encoded = len(features)
        for depth, decoder in enumerate(self.decoders):
            skip = features[encoded - 2 - depth]
            merged = merge_skip(features[-1], skip)
            features.append(decoder(merged))

        return features if return_all else features[-1]

# test of koila on unet
def test_lazy():
    """Run one lazy forward/backward pass of ``Unet3D`` on CUDA tensors.

    NOTE(review): ``lazy`` and ``LazyTensor`` are not imported in this
    snippet; presumably they come from ``koila`` -- confirm.
    """
    model = Unet3D(1, 3)
    model.cuda()

    side = 64
    batch, channels, depth, height, width = 2, 1, side, side, side
    x = torch.randn(batch, channels, depth, height, width).cuda()
    # Integer class targets in [0, 3), one per voxel.
    t = torch.randint(0, 3, (batch, depth, height, width)).cuda()

    criterion = nn.CrossEntropyLoss()
    model.zero_grad()

    # Wrap input and target lazily, declaring dimension 0 as the batch axis.
    lazy_x, lazy_t = lazy(x, t, batch=0)
    lazy_out = model(lazy_x)
    lazy_loss = criterion(lazy_out, lazy_t)
    assert isinstance(lazy_loss, LazyTensor), type(lazy_loss)
    lazy_loss.backward()



# Running the test fails: it raises the koila.errors.UnsupportedError shown
# in the output below.
test_lazy()

This fails and outputs:

tensors = (tensor([[[[[-8.9936e-02, -7.9037e-02, -1.5048e-02,  ...,  2.9969e-01,
             2.9774e-01, -1.0489e-01],
        ...]]], device='cuda:0',
       grad_fn=<UpsampleTrilinear3DBackward1>), <koila.lazy.LazyTensor object at 0x7fa21bf99880>)
dim = 1, args = (), kwargs = {}, shapes = [torch.Size([2, 128, 64, 64, 64]), (2, 64, 64, 64, 64)]
no_dim = [torch.Size([2, 64, 64, 64]), (2, 64, 64, 64)], result_size = torch.Size([2, 64, 64, 64])
size = (2, 64, 64, 64)

    def cat(
        tensors: Sequence[TensorLike], dim: int = 0, *args: Any, **kwargs: Any
    ) -> PrePass:
        mute_unused_args(*args, **kwargs)

        if len(tensors) == 0:
            raise ValueError("Expected a sequence of tensors. Got empty sequence.")

        shapes = [t.size() for t in tensors]
        no_dim = [t[:dim] + t[dim + 1 :] for t in shapes]

        result_size = no_dim[0]
        for size in no_dim[1:]:
            if result_size != size:
                raise ValueError(
                    f"Dimension should be equal outside dim {dim}. Got {shapes}."
                )

        if len(set(interfaces.bat(t) for t in tensors)) != 1:
>           raise UnsupportedError
E           koila.errors.UnsupportedError

../miniconda3/envs/snakes/lib/python3.9/site-packages/koila/prepasses.py:423: UnsupportedError

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions