
Commit 0893f5d

Initial NaFlex ViT model and training support

1 parent e44f14d

12 files changed: +2928 −164 lines

timm/data/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -8,6 +8,13 @@
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .naflex_transforms import (
    ResizeToSequence,
    CenterCropToSequence,
    RandomCropToSequence,
    RandomResizedCropToSequence,
    ResizeKeepRatioToSequence,
)
from .readers import create_reader
from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet
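
With this hunk the sequence-targeted transforms are re-exported at the package level, so downstream code can pull them straight from timm.data (their constructor arguments live in naflex_transforms.py, which is not part of this excerpt):

from timm.data import (
    ResizeToSequence,
    RandomResizedCropToSequence,
)  # new exports in this commit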

timm/data/naflex_dataset.py

Lines changed: 422 additions & 0 deletions
Large diffs are not rendered by default.

timm/data/naflex_loader.py

Lines changed: 341 additions & 0 deletions

@@ -0,0 +1,341 @@
import math
from contextlib import suppress
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

import torch

from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .loader import _worker_init
from .naflex_dataset import VariableSeqMapWrapper
from .transforms_factory import create_transform


class NaFlexCollator:
    """Custom collator for batching NaFlex-style variable-resolution images."""

    def __init__(
            self,
            patch_size=16,
            max_seq_len=None,
    ):
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length at 384x384 (24 * 24 = 576)

    def __call__(self, batch):
        """
        Args:
            batch: List of tuples (patch_dict, target)

        Returns:
            A tuple of (input_dict, targets) where input_dict contains:
                - patches: Padded tensor of patches [B, N, P*P*C]
                - patch_coord: Coordinates for each patch [B, N, 2] in (y, x) order
                - patch_valid: Valid indicators [B, N] (False for padding)
                - seq_len: The padded sequence length used for this batch
        """
        assert isinstance(batch[0], tuple)
        batch_size = len(batch)

        # FIXME get seq len from sampler schedule, then resize to
        #  final size based on seq_len and patchify

        # Extract targets
        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)

        # Get patch dictionaries
        patch_dicts = [item[0] for item in batch]

        # If we have a maximum sequence length constraint, ensure we don't exceed it
        if self.max_seq_len is not None:
            max_patches = self.max_seq_len
        else:
            # Find the maximum number of patches in this batch
            max_patches = max(item['patches'].shape[0] for item in patch_dicts)

        # Get patch dimensionality
        patch_dim = patch_dicts[0]['patches'].shape[1]

        # Prepare zero-padded tensors for the batch
        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)

        # Fill in the tensors, truncating any sample that exceeds max_patches
        for i, patch_dict in enumerate(patch_dicts):
            num_patches = min(patch_dict['patches'].shape[0], max_patches)

            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]

        return {
            'patches': patches,
            'patch_coord': patch_coord,
            'patch_valid': patch_valid,
            'seq_len': max_patches,
        }, targets
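
A minimal usage sketch of the collator with two samples patchified to different lengths (the collator is not re-exported from timm.data in this commit, so it is imported from the module directly; the helper make_sample is purely illustrative):

import torch
from timm.data.naflex_loader import NaFlexCollator

collator = NaFlexCollator(patch_size=16, max_seq_len=576)

def make_sample(num_patches, label, patch_dim=16 * 16 * 3):
    # Shapes mirror what NaFlexCollator.__call__ expects from each patch_dict
    return (
        {
            'patches': torch.randn(num_patches, patch_dim),
            'patch_coord': torch.zeros(num_patches, 2, dtype=torch.int64),
            'patch_valid': torch.ones(num_patches, dtype=torch.bool),
        },
        label,
    )

input_dict, targets = collator([make_sample(196, 1), make_sample(576, 7)])
print(input_dict['patches'].shape)        # torch.Size([2, 576, 768])
print(input_dict['patch_valid'].sum(-1))  # tensor([196, 576]); zero-padding is marked invalid
print(input_dict['seq_len'], targets)     # 576 tensor([1, 7])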

class NaFlexPrefetchLoader:
    """Data prefetcher for NaFlex format which normalizes patches."""

    def __init__(
            self,
            loader,
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
            img_dtype=torch.float32,
            device=torch.device('cuda')
    ):
        self.loader = loader
        self.device = device
        self.img_dtype = img_dtype or torch.float32

        # Create mean/std tensors for normalization (will be applied to patches)
        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3)
        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3)

        # Check for CUDA/NPU availability
        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
        self.is_npu = device.type == 'npu' and torch.npu.is_available()

    def __iter__(self):
        first = True
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        elif self.is_npu:
            stream = torch.npu.Stream()
            stream_context = partial(torch.npu.stream, stream=stream)
        else:
            stream = None
            stream_context = suppress

        for next_input_dict, next_target in self.loader:
            with stream_context():
                # Move all tensors in input_dict to device
                for k, v in next_input_dict.items():
                    if isinstance(v, torch.Tensor):
                        dtype = self.img_dtype if k == 'patches' else None
                        next_input_dict[k] = next_input_dict[k].to(
                            device=self.device,
                            non_blocking=True,
                            dtype=dtype,
                        )

                next_target = next_target.to(device=self.device, non_blocking=True)

                # Normalize patch values (assuming patches are in format [B, N, P*P*C])
                batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape
                patches = next_input_dict['patches'].view(batch_size, -1, 3)  # to [B, N*P*P, C] for per-channel normalization
                patches = patches.sub(self.mean).div(self.std)

                # Reshape back
                next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels)

            if not first:
                yield input_dict, target
            else:
                first = False

            if stream is not None:
                if self.is_cuda:
                    torch.cuda.current_stream().wait_stream(stream)
                elif self.is_npu:
                    torch.npu.current_stream().wait_stream(stream)

            input_dict = next_input_dict
            target = next_target

        yield input_dict, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset
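
The normalization above relies on mean/std being pre-scaled by 255, so patches arriving in the 0-255 range are rescaled and normalized in a single pass; viewing the flat [B, N, P*P*C] patches as [B, N*P*P, C] lets the (1, 1, 3) per-channel stats broadcast. A standalone sketch of the same arithmetic on dummy data:

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3) * 255
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3) * 255

# Dummy batch: 2 samples, 576 patches, 16x16x3 pixel values per patch, 0-255 range
patches = torch.randint(0, 256, (2, 576, 16 * 16 * 3)).float()
normed = patches.view(2, -1, 3).sub(mean).div(std).reshape_as(patches)
print(normed.shape)  # torch.Size([2, 576, 768])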

def create_naflex_loader(
        dataset,
        patch_size: Union[Tuple[int, int], int] = 16,
        train_seq_lens: List[int] = (128, 256, 576, 784, 1024),  # Training sequence lengths
        max_seq_len: int = 576,  # Fixed sequence length for validation
        batch_size: int = 32,  # Batch size at max_seq_len (validation) and at max(train_seq_lens) (training)
        is_training: bool = False,

        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,

        num_workers: int = 4,
        distributed: bool = False,
        rank: int = 0,
        world_size: int = 1,
        seed: int = 42,
        epoch: int = 0,
        use_prefetcher: bool = True,
        pin_memory: bool = True,
        img_dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = torch.device('cuda'),
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
):
    """Create a data loader with dynamic sequence length sampling for training."""

    if is_training:
        # For training, use the dynamic sequence length mechanism
        assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader'

        transform_factory = partial(
            create_transform,
            is_training=True,
            no_aug=no_aug,
            train_crop_mode=train_crop_mode,
            scale=scale,
            ratio=ratio,
            hflip=hflip,
            vflip=vflip,
            color_jitter=color_jitter,
            color_jitter_prob=color_jitter_prob,
            grayscale_prob=grayscale_prob,
            gaussian_blur_prob=gaussian_blur_prob,
            auto_augment=auto_augment,
            interpolation=interpolation,
            mean=mean,
            std=std,
            crop_pct=crop_pct,
            crop_mode=crop_mode,
            crop_border_pixels=crop_border_pixels,
            re_prob=re_prob,
            re_mode=re_mode,
            re_count=re_count,
            use_prefetcher=use_prefetcher,
            naflex=True,
        )

        max_train_seq_len = max(train_seq_lens)
        max_tokens_per_batch = batch_size * max_train_seq_len
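        # NOTE: with the defaults above (batch_size=32, train seq lens up to 1024), this
        # budget is 32 * 1024 = 32768 tokens per batch, so shorter sequence lengths can
        # be paired with proportionally larger batches (e.g. seq_len 128 -> 256 samples,
        # seq_len 256 -> 128), assuming VariableSeqMapWrapper divides the token budget
        # by each seq_len; that wrapper's diff is not rendered in this commit view.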

        if isinstance(dataset, torch.utils.data.IterableDataset):
            assert False, "IterableDataset wrapper is a WIP"

        naflex_dataset = VariableSeqMapWrapper(
            dataset,
            transform_factory=transform_factory,
            patch_size=patch_size,
            seq_lens=train_seq_lens,
            max_tokens_per_batch=max_tokens_per_batch,
            seed=seed,
            distributed=distributed,
            rank=rank,
            world_size=world_size,
            shuffle=True,
            epoch=epoch,
        )

        # NOTE: Collation is handled by the dataset wrapper for training
        # Create the collator (handles fixed-size collation)
        # collate_fn = NaFlexCollator(
        #     patch_size=patch_size,
        #     max_seq_len=max(seq_lens) + 1,  # +1 for class token
        #     use_prefetcher=use_prefetcher
        # )

        loader = torch.utils.data.DataLoader(
            naflex_dataset,
            batch_size=None,  # batching is handled inside the dataset wrapper
            shuffle=False,
            num_workers=num_workers,
            sampler=None,
            # collate_fn=collate_fn,
            pin_memory=pin_memory,
            worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
            persistent_workers=persistent_workers,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
            )

    else:
        # For validation, use a fixed sequence length (unchanged across batches)
        dataset.transform = create_transform(
            is_training=False,
            interpolation=interpolation,
            mean=mean,
            std=std,
            # FIXME add crop args when sequence transforms support crop modes
            use_prefetcher=use_prefetcher,
            naflex=True,
            patch_size=patch_size,
            max_seq_len=max_seq_len,
            patchify=True,
        )

        # Create the collator
        collate_fn = NaFlexCollator(
            patch_size=patch_size,
            max_seq_len=max_seq_len,
        )

        # Handle distributed validation
        sampler = None
        if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
            # For validation, use OrderedDistributedSampler
            from timm.data.distributed_sampler import OrderedDistributedSampler
            sampler = OrderedDistributedSampler(dataset)

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=False,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=device,
            )

    return loader
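
For context, a usage sketch of the factory. This is hedged: the create_dataset arguments and dataset root are illustrative (an empty name uses timm's folder reader), and use_prefetcher is disabled so the sketch does not require a GPU:

from timm.data import create_dataset
from timm.data.naflex_loader import create_naflex_loader

# Hypothetical image folder; any timm dataset with integer targets should work
dataset = create_dataset('', root='/data/imagenet', split='train')
loader = create_naflex_loader(
    dataset,
    patch_size=16,
    train_seq_lens=(128, 256, 576, 784, 1024),
    batch_size=32,         # applies at the longest training sequence length
    is_training=True,
    use_prefetcher=False,  # keep the sketch CPU-only; the prefetcher defaults to CUDA
    num_workers=4,
)

for input_dict, targets in loader:
    # Each batch dict carries 'patches', 'patch_coord', 'patch_valid', 'seq_len'
    print(input_dict['patches'].shape, input_dict['seq_len'])
    break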
