-
Notifications
You must be signed in to change notification settings - Fork 65
Open
Description
I am trying to apply koila's lazy evaluation to a 3D U-Net (`Unet3D`, defined below).
# defining the model
import torch
import torch.nn as nn
import torch.nn.functional as F
def conv3(in_channels, out_channels, stride, norm='BatchNorm3d', act='GELU'):
    """Build a 3x3x3 convolution followed by a normalization and an activation.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels; also passed to the norm layer.
        stride: stride of the convolution. With padding 1, stride 1 keeps the
            spatial size and stride 2 halves it (rounding up).
        norm: name of a ``torch.nn`` class, instantiated as ``norm(out_channels)``.
        act: name of a ``torch.nn`` class, instantiated with no arguments.

    Returns:
        ``nn.Sequential(Conv3d, norm, activation)``.
    """
    return nn.Sequential(
        # `stride` used to be ignored (hard-coded to 1), so callers requesting
        # stride=2 silently got no downsampling.
        nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
        getattr(nn, norm)(out_channels),
        getattr(nn, act)(),
    )
def double_conv3(in_channels, out_channels, stride):
    """Stack two ``conv3`` units into one ``nn.Sequential``.

    The first unit maps ``in_channels`` -> ``out_channels`` with stride 1;
    the second keeps ``out_channels`` and receives the caller's ``stride``.
    """
    first = conv3(in_channels, out_channels, 1)
    second = conv3(out_channels, out_channels, stride)
    return nn.Sequential(first, second)
def merge_skip(x, skip):
    """Resize ``x`` to ``skip``'s spatial size and concatenate the two on dim 1.

    ``x`` is resized with trilinear interpolation (``align_corners=True``) to
    ``skip.shape[-3:]``. The result holds ``x``'s channels followed by
    ``skip``'s channels along dim 1.
    Expects 5-D (N, C, D, H, W) inputs, which ``mode='trilinear'`` requires.
    """
    # F.upsample is a deprecated alias of F.interpolate that warns on every call.
    x = F.interpolate(x, size=skip.shape[-3:], mode='trilinear', align_corners=True)
    return torch.cat((x, skip), dim=1)
class Unet3D(nn.Module):
    """3D U-Net built from ``double_conv3`` stages with concatenated skips.

    Args:
        in_channels: channels of the input volume.
        out_channels: channels produced by the last decoder stage.
        num_layers: number of encoder stages (and of decoder stages).
        base: width of the first encoder stage; stage k has ``base * 2**k``
            channels.
    """

    def __init__(self, in_channels, out_channels, num_layers=4, base=16):
        super().__init__()
        widths = [base * 2 ** k for k in range(num_layers)]
        enc_channels = [in_channels] + widths
        dec_channels = widths[::-1] + [out_channels]

        # Encoder stage k maps enc_channels[k] -> enc_channels[k + 1], stride 2.
        self.encoders = nn.ModuleList(
            double_conv3(c_in, c_out, 2)
            for c_in, c_out in zip(enc_channels[:-1], enc_channels[1:])
        )

        # Decoder stage k consumes the previous stage's output concatenated with
        # an encoder feature, walking the encoder widths back from the
        # second-to-last one down to the raw input.
        skip_widths = enc_channels[-2::-1]
        self.decoders = nn.ModuleList(
            double_conv3(skip_c + up_c, c_out, 1)
            for skip_c, up_c, c_out in zip(skip_widths, dec_channels[:-1], dec_channels[1:])
        )

    def forward(self, x, return_all=False):
        """Run encoders then decoders.

        Returns the list of every intermediate feature (input first) when
        ``return_all`` is true; otherwise only the last decoder output.
        """
        feats = [x]
        for stage in self.encoders:
            feats.append(stage(feats[-1]))

        depth = len(feats)
        for i, stage in enumerate(self.decoders):
            merged = merge_skip(feats[-1], feats[depth - 2 - i])
            feats.append(stage(merged))

        return feats if return_all else feats[-1]
# test of koila on unet
def test_lazy():
    """Run one koila-lazy forward/backward pass of ``Unet3D(1, 3)`` on CUDA.

    A random (2, 1, 64, 64, 64) input and a random integer target in [0, 3)
    of shape (2, 64, 64, 64) are wrapped with ``lazy(..., batch=0)``. The
    cross-entropy loss must come back as a ``LazyTensor`` before
    ``backward()`` is called.
    """
    model = Unet3D(1, 3)
    model.cuda()
    size = 64
    batch, channels = 2, 1
    inputs = torch.randn(batch, channels, size, size, size).cuda()
    targets = torch.randint(0, 3, (batch, size, size, size)).cuda()
    criterion = nn.CrossEntropyLoss()
    model.zero_grad()
    lazy_inputs, lazy_targets = lazy(inputs, targets, batch=0)
    loss = criterion(model(lazy_inputs), lazy_targets)
    assert isinstance(loss, LazyTensor), type(loss)
    loss.backward()
# This fails: it raises koila.errors.UnsupportedError (see the traceback below).
test_lazy()
Running it fails with the following output:
tensors = (tensor([[[[[-8.9936e-02, -7.9037e-02, -1.5048e-02, ..., 2.9969e-01,
2.9774e-01, -1.0489e-01],
...]]], device='cuda:0',
grad_fn=<UpsampleTrilinear3DBackward1>), <koila.lazy.LazyTensor object at 0x7fa21bf99880>)
dim = 1, args = (), kwargs = {}, shapes = [torch.Size([2, 128, 64, 64, 64]), (2, 64, 64, 64, 64)]
no_dim = [torch.Size([2, 64, 64, 64]), (2, 64, 64, 64)], result_size = torch.Size([2, 64, 64, 64])
size = (2, 64, 64, 64)
def cat(
tensors: Sequence[TensorLike], dim: int = 0, *args: Any, **kwargs: Any
) -> PrePass:
mute_unused_args(*args, **kwargs)
if len(tensors) == 0:
raise ValueError("Expected a sequence of tensors. Got empty sequence.")
shapes = [t.size() for t in tensors]
no_dim = [t[:dim] + t[dim + 1 :] for t in shapes]
result_size = no_dim[0]
for size in no_dim[1:]:
if result_size != size:
raise ValueError(
f"Dimension should be equal outside dim {dim}. Got {shapes}."
)
if len(set(interfaces.bat(t) for t in tensors)) != 1:
> raise UnsupportedError
E koila.errors.UnsupportedError
../miniconda3/envs/snakes/lib/python3.9/site-packages/koila/prepasses.py:423: UnsupportedError
Metadata
Metadata
Assignees
Labels
No labels