|
| 1 | +import math |
| 2 | +from contextlib import suppress |
| 3 | +from functools import partial |
| 4 | +from typing import Callable, List, Optional, Tuple, Union |
| 5 | + |
| 6 | +import torch |
| 7 | + |
| 8 | +from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
| 9 | +from .loader import _worker_init |
| 10 | +from .naflex_dataset import VariableSeqMapWrapper |
| 11 | +from .transforms_factory import create_transform |
| 12 | + |
| 13 | + |
| 14 | +class NaFlexCollator: |
| 15 | + """Custom collator for batching NaFlex-style variable-resolution images.""" |
| 16 | + |
| 17 | + def __init__( |
| 18 | + self, |
| 19 | + patch_size=16, |
| 20 | + max_seq_len=None, |
| 21 | + ): |
| 22 | + self.patch_size = patch_size |
| 23 | + self.max_seq_len = max_seq_len or 576 # Default ViT-B/16 sequence length (577 = 24*24) |
| 24 | + |
| 25 | + def __call__(self, batch): |
| 26 | + """ |
| 27 | + Args: |
| 28 | + batch: List of tuples (patch_dict, target) |
| 29 | +
|
| 30 | + Returns: |
| 31 | + A tuple of (input_dict, targets) where input_dict contains: |
| 32 | + - patches: Padded tensor of patches |
| 33 | + - patch_coord: Coordinates for each patch (y, x) |
| 34 | + - patch_valid: Valid indicators |
| 35 | + """ |
| 36 | + assert isinstance(batch[0], tuple) |
| 37 | + batch_size = len(batch) |
| 38 | + |
| 39 | + # FIXME |
| 40 | + # get seq len from sampler schedule |
| 41 | + |
| 42 | + # resize to final size based on seq_len and patchify |
| 43 | + |
| 44 | + # Extract targets |
| 45 | + targets = torch.tensor([item[1] for item in batch], dtype=torch.int64) |
| 46 | + |
| 47 | + # Get patch dictionaries |
| 48 | + patch_dicts = [item[0] for item in batch] |
| 49 | + |
| 50 | + # If we have a maximum sequence length constraint, ensure we don't exceed it |
| 51 | + if self.max_seq_len is not None: |
| 52 | + max_patches = self.max_seq_len |
| 53 | + else: |
| 54 | + # Find the maximum number of patches in this batch |
| 55 | + max_patches = max(item['patches'].shape[0] for item in patch_dicts) |
| 56 | + |
| 57 | + # Get patch dimensionality |
| 58 | + patch_dim = patch_dicts[0]['patches'].shape[1] |
| 59 | + |
| 60 | + # Prepare tensors for the batch |
| 61 | + patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32) |
| 62 | + patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64) # [B, N, 2] for (y, x) |
| 63 | + patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool) |
| 64 | + |
| 65 | + # Fill in the tensors |
| 66 | + for i, patch_dict in enumerate(patch_dicts): |
| 67 | + num_patches = min(patch_dict['patches'].shape[0], max_patches) |
| 68 | + |
| 69 | + patches[i, :num_patches] = patch_dict['patches'][:num_patches] |
| 70 | + patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches] |
| 71 | + patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches] |
| 72 | + |
| 73 | + return { |
| 74 | + 'patches': patches, |
| 75 | + 'patch_coord': patch_coord, |
| 76 | + 'patch_valid': patch_valid, |
| 77 | + 'seq_len': max_patches, |
| 78 | + }, targets |
| 79 | + |
| 80 | + |
| 81 | +class NaFlexPrefetchLoader: |
| 82 | + """Data prefetcher for NaFlex format which normalizes patches.""" |
| 83 | + |
| 84 | + def __init__( |
| 85 | + self, |
| 86 | + loader, |
| 87 | + mean=(0.485, 0.456, 0.406), |
| 88 | + std=(0.229, 0.224, 0.225), |
| 89 | + img_dtype=torch.float32, |
| 90 | + device=torch.device('cuda') |
| 91 | + ): |
| 92 | + self.loader = loader |
| 93 | + self.device = device |
| 94 | + self.img_dtype = img_dtype or torch.float32 |
| 95 | + |
| 96 | + # Create mean/std tensors for normalization (will be applied to patches) |
| 97 | + self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=self.img_dtype).view(1, 1, 3) |
| 98 | + self.std = torch.tensor([x * 255 for x in std], device=device, dtype=self.img_dtype).view(1, 1, 3) |
| 99 | + |
| 100 | + # Check for CUDA/NPU availability |
| 101 | + self.is_cuda = device.type == 'cuda' and torch.cuda.is_available() |
| 102 | + self.is_npu = device.type == 'npu' and torch.npu.is_available() |
| 103 | + |
| 104 | + def __iter__(self): |
| 105 | + first = True |
| 106 | + if self.is_cuda: |
| 107 | + stream = torch.cuda.Stream() |
| 108 | + stream_context = partial(torch.cuda.stream, stream=stream) |
| 109 | + elif self.is_npu: |
| 110 | + stream = torch.npu.Stream() |
| 111 | + stream_context = partial(torch.npu.stream, stream=stream) |
| 112 | + else: |
| 113 | + stream = None |
| 114 | + stream_context = suppress |
| 115 | + |
| 116 | + for next_input_dict, next_target in self.loader: |
| 117 | + with stream_context(): |
| 118 | + # Move all tensors in input_dict to device |
| 119 | + for k, v in next_input_dict.items(): |
| 120 | + if isinstance(v, torch.Tensor): |
| 121 | + dtype = self.img_dtype if k == 'patches' else None |
| 122 | + next_input_dict[k] = next_input_dict[k].to( |
| 123 | + device=self.device, |
| 124 | + non_blocking=True, |
| 125 | + dtype=dtype, |
| 126 | + ) |
| 127 | + |
| 128 | + next_target = next_target.to(device=self.device, non_blocking=True) |
| 129 | + |
| 130 | + # Normalize patch values (assuming patches are in format [B, N, P*P*C]) |
| 131 | + batch_size, num_patches, patch_pixels = next_input_dict['patches'].shape |
| 132 | + patches = next_input_dict['patches'].view(batch_size, -1, 3) # to [B*N, P*P, C] for normalization |
| 133 | + patches = patches.sub(self.mean).div(self.std) |
| 134 | + |
| 135 | + # Reshape back |
| 136 | + next_input_dict['patches'] = patches.reshape(batch_size, num_patches, patch_pixels) |
| 137 | + |
| 138 | + if not first: |
| 139 | + yield input_dict, target |
| 140 | + else: |
| 141 | + first = False |
| 142 | + |
| 143 | + if stream is not None: |
| 144 | + if self.is_cuda: |
| 145 | + torch.cuda.current_stream().wait_stream(stream) |
| 146 | + elif self.is_npu: |
| 147 | + torch.npu.current_stream().wait_stream(stream) |
| 148 | + |
| 149 | + input_dict = next_input_dict |
| 150 | + target = next_target |
| 151 | + |
| 152 | + yield input_dict, target |
| 153 | + |
| 154 | + def __len__(self): |
| 155 | + return len(self.loader) |
| 156 | + |
| 157 | + @property |
| 158 | + def sampler(self): |
| 159 | + return self.loader.sampler |
| 160 | + |
| 161 | + @property |
| 162 | + def dataset(self): |
| 163 | + return self.loader.dataset |
| 164 | + |
| 165 | + |
| 166 | +def create_naflex_loader( |
| 167 | + dataset, |
| 168 | + patch_size: Union[Tuple[int, int], int] = 16, |
| 169 | + train_seq_lens: List[int] = (128, 256, 576, 784, 1024), # Training sequence lengths |
| 170 | + max_seq_len: int = 576, # Fixed sequence length for validation |
| 171 | + batch_size: int = 32, # Used for max_seq_len and max(train_seq_lens) |
| 172 | + is_training: bool = False, |
| 173 | + |
| 174 | + no_aug: bool = False, |
| 175 | + re_prob: float = 0., |
| 176 | + re_mode: str = 'const', |
| 177 | + re_count: int = 1, |
| 178 | + re_split: bool = False, |
| 179 | + train_crop_mode: Optional[str] = None, |
| 180 | + scale: Optional[Tuple[float, float]] = None, |
| 181 | + ratio: Optional[Tuple[float, float]] = None, |
| 182 | + hflip: float = 0.5, |
| 183 | + vflip: float = 0., |
| 184 | + color_jitter: float = 0.4, |
| 185 | + color_jitter_prob: Optional[float] = None, |
| 186 | + grayscale_prob: float = 0., |
| 187 | + gaussian_blur_prob: float = 0., |
| 188 | + auto_augment: Optional[str] = None, |
| 189 | + num_aug_repeats: int = 0, |
| 190 | + num_aug_splits: int = 0, |
| 191 | + interpolation: str = 'bilinear', |
| 192 | + mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN, |
| 193 | + std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, |
| 194 | + crop_pct: Optional[float] = None, |
| 195 | + crop_mode: Optional[str] = None, |
| 196 | + crop_border_pixels: Optional[int] = None, |
| 197 | + |
| 198 | + num_workers: int = 4, |
| 199 | + distributed: bool = False, |
| 200 | + rank: int = 0, |
| 201 | + world_size: int = 1, |
| 202 | + seed: int = 42, |
| 203 | + epoch: int = 0, |
| 204 | + use_prefetcher: bool = True, |
| 205 | + pin_memory: bool = True, |
| 206 | + img_dtype: torch.dtype = torch.float32, |
| 207 | + device: Union[str, torch.device] = torch.device('cuda'), |
| 208 | + persistent_workers: bool = True, |
| 209 | + worker_seeding: str = 'all', |
| 210 | + ): |
| 211 | + """Create a data loader with dynamic sequence length sampling for training.""" |
| 212 | + |
| 213 | + if is_training: |
| 214 | + # For training, use the dynamic sequence length mechanism |
| 215 | + assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader' |
| 216 | + |
| 217 | + transform_factory = partial( |
| 218 | + create_transform, |
| 219 | + is_training=True, |
| 220 | + no_aug=no_aug, |
| 221 | + train_crop_mode=train_crop_mode, |
| 222 | + scale=scale, |
| 223 | + ratio=ratio, |
| 224 | + hflip=hflip, |
| 225 | + vflip=vflip, |
| 226 | + color_jitter=color_jitter, |
| 227 | + color_jitter_prob=color_jitter_prob, |
| 228 | + grayscale_prob=grayscale_prob, |
| 229 | + gaussian_blur_prob=gaussian_blur_prob, |
| 230 | + auto_augment=auto_augment, |
| 231 | + interpolation=interpolation, |
| 232 | + mean=mean, |
| 233 | + std=std, |
| 234 | + crop_pct=crop_pct, |
| 235 | + crop_mode=crop_mode, |
| 236 | + crop_border_pixels=crop_border_pixels, |
| 237 | + re_prob=re_prob, |
| 238 | + re_mode=re_mode, |
| 239 | + re_count=re_count, |
| 240 | + use_prefetcher=use_prefetcher, |
| 241 | + naflex=True, |
| 242 | + ) |
| 243 | + |
| 244 | + max_train_seq_len = max(train_seq_lens) |
| 245 | + max_tokens_per_batch = batch_size * max_train_seq_len |
| 246 | + |
| 247 | + if isinstance(dataset, torch.utils.data.IterableDataset): |
| 248 | + assert False, "IterableDataset Wrapper is a WIP" |
| 249 | + |
| 250 | + naflex_dataset = VariableSeqMapWrapper( |
| 251 | + dataset, |
| 252 | + transform_factory=transform_factory, |
| 253 | + patch_size=patch_size, |
| 254 | + seq_lens=train_seq_lens, |
| 255 | + max_tokens_per_batch=max_tokens_per_batch, |
| 256 | + seed=seed, |
| 257 | + distributed=distributed, |
| 258 | + rank=rank, |
| 259 | + world_size=world_size, |
| 260 | + shuffle=True, |
| 261 | + epoch=epoch, |
| 262 | + ) |
| 263 | + |
| 264 | + # NOTE: Collation is handled by the dataset wrapper for training |
| 265 | + # Create the collator (handles fixed-size collation) |
| 266 | + # collate_fn = NaFlexCollator( |
| 267 | + # patch_size=patch_size, |
| 268 | + # max_seq_len=max(seq_lens) + 1, # +1 for class token |
| 269 | + # use_prefetcher=use_prefetcher |
| 270 | + # ) |
| 271 | + |
| 272 | + loader = torch.utils.data.DataLoader( |
| 273 | + naflex_dataset, |
| 274 | + batch_size=None, |
| 275 | + shuffle=False, |
| 276 | + num_workers=num_workers, |
| 277 | + sampler=None, |
| 278 | + #collate_fn=collate_fn, |
| 279 | + pin_memory=pin_memory, |
| 280 | + worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding), |
| 281 | + persistent_workers=persistent_workers |
| 282 | + ) |
| 283 | + |
| 284 | + if use_prefetcher: |
| 285 | + loader = NaFlexPrefetchLoader( |
| 286 | + loader, |
| 287 | + mean=mean, |
| 288 | + std=std, |
| 289 | + img_dtype=img_dtype, |
| 290 | + device=device, |
| 291 | + ) |
| 292 | + |
| 293 | + else: |
| 294 | + # For validation, use fixed sequence length (unchanged) |
| 295 | + dataset.transform = create_transform( |
| 296 | + is_training=False, |
| 297 | + interpolation=interpolation, |
| 298 | + mean=mean, |
| 299 | + std=std, |
| 300 | + # FIXME add crop args when sequence transforms support crop modes |
| 301 | + use_prefetcher=use_prefetcher, |
| 302 | + naflex=True, |
| 303 | + patch_size=patch_size, |
| 304 | + max_seq_len=max_seq_len, |
| 305 | + patchify=True, |
| 306 | + ) |
| 307 | + |
| 308 | + # Create the collator |
| 309 | + collate_fn = NaFlexCollator( |
| 310 | + patch_size=patch_size, |
| 311 | + max_seq_len=max_seq_len, |
| 312 | + ) |
| 313 | + |
| 314 | + # Handle distributed training |
| 315 | + sampler = None |
| 316 | + if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): |
| 317 | + # For validation, use OrderedDistributedSampler |
| 318 | + from timm.data.distributed_sampler import OrderedDistributedSampler |
| 319 | + sampler = OrderedDistributedSampler(dataset) |
| 320 | + |
| 321 | + loader = torch.utils.data.DataLoader( |
| 322 | + dataset, |
| 323 | + batch_size=batch_size, |
| 324 | + shuffle=False, |
| 325 | + num_workers=num_workers, |
| 326 | + sampler=sampler, |
| 327 | + collate_fn=collate_fn, |
| 328 | + pin_memory=pin_memory, |
| 329 | + drop_last=False, |
| 330 | + ) |
| 331 | + |
| 332 | + if use_prefetcher: |
| 333 | + loader = NaFlexPrefetchLoader( |
| 334 | + loader, |
| 335 | + mean=mean, |
| 336 | + std=std, |
| 337 | + img_dtype=img_dtype, |
| 338 | + device=device, |
| 339 | + ) |
| 340 | + |
| 341 | + return loader |
0 commit comments