
Commit e2073e3

Move NaFlexCollate with dataset, remove stand alone collate_fn and remove redundancy
1 parent 39eb56f commit e2073e3

File tree

2 files changed: +67 −134


timm/data/naflex_dataset.py (+65 −63)
@@ -59,63 +59,65 @@ def calculate_batch_size(
     return batch_size
 
 
-def _collate_batch(
-        batch_samples: List[Tuple[Dict[str, torch.Tensor], Any]],
-        target_seq_len: int
-) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
-    """Collates processed samples into a batch, padding/truncating to target_seq_len."""
-    batch_patch_data = [item[0] for item in batch_samples]
-    batch_labels = [item[1] for item in batch_samples]
-
-    if not batch_patch_data:
-        return {}, torch.empty(0)
-
-    batch_size = len(batch_patch_data)
-    patch_dim = batch_patch_data[0]['patches'].shape[1]
-
-    # Initialize tensors with target sequence length
-    patches_batch = torch.zeros((batch_size, target_seq_len, patch_dim), dtype=torch.float32)
-    patch_coord_batch = torch.zeros((batch_size, target_seq_len, 2), dtype=torch.int64)
-    patch_valid_batch = torch.zeros((batch_size, target_seq_len), dtype=torch.bool)  # Use bool
-
-    for i, data in enumerate(batch_patch_data):
-        num_patches = data['patches'].shape[0]
-        # Take min(num_patches, target_seq_len) patches
-        n_copy = min(num_patches, target_seq_len)
-
-        patches_batch[i, :n_copy] = data['patches'][:n_copy]
-        patch_coord_batch[i, :n_copy] = data['patch_coord'][:n_copy]
-        patch_valid_batch[i, :n_copy] = data['patch_valid'][:n_copy]  # Copy validity flags
-
-    # Create the final input dict
-    input_dict = {
-        'patches': patches_batch,
-        'patch_coord': patch_coord_batch,
-        'patch_valid': patch_valid_batch,  # Boolean mask
-        # Note: 'seq_length' might be ambiguous. The target length is target_seq_len.
-        # The actual number of valid patches per sample varies.
-        # 'patch_valid' mask is the most reliable source of truth.
-    }
-
-    # Attempt to stack labels if they are tensors, otherwise return list
-    try:
-        if isinstance(batch_labels[0], torch.Tensor):
-            labels_tensor = torch.stack(batch_labels)
+class NaFlexCollator:
+    """Custom collator for batching NaFlex-style variable-resolution images."""
+
+    def __init__(
+            self,
+            max_seq_len=None,
+    ):
+        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (576 = 24*24)
+
+    def __call__(self, batch):
+        """
+        Args:
+            batch: List of tuples (patch_dict, target)
+
+        Returns:
+            A tuple of (input_dict, targets) where input_dict contains:
+                - patches: Padded tensor of patches
+                - patch_coord: Coordinates for each patch (y, x)
+                - patch_valid: Valid indicators
+        """
+        assert isinstance(batch[0], tuple)
+        batch_size = len(batch)
+
+        # Extract targets
+        # FIXME need to handle dense (float) targets, or is that always done downstream of this?
+        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
+
+        # Get patch dictionaries
+        patch_dicts = [item[0] for item in batch]
+
+        # If we have a maximum sequence length constraint, ensure we don't exceed it
+        if self.max_seq_len is not None:
+            max_patches = self.max_seq_len
         else:
-            # Convert numerical types to tensor, keep others as list (or handle specific types)
-            if isinstance(batch_labels[0], (int, float)):
-                labels_tensor = torch.tensor(batch_labels)
-            else:
-                # Cannot convert non-numerical labels easily, return as list
-                # Or handle specific conversion if needed
-                # For FakeDataset, labels are ints, so this works
-                labels_tensor = torch.tensor(batch_labels)  # Assuming labels are numerical
-    except Exception:
-        # Fallback if stacking fails (e.g., different shapes, types)
-        print("Warning: Could not stack labels into a tensor. Returning list of labels.")
-        labels_tensor = batch_labels  # Return as list
+            # Find the maximum number of patches in this batch
+            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
+
+        # Get patch dimensionality
+        patch_dim = patch_dicts[0]['patches'].shape[1]
+
+        # Prepare tensors for the batch
+        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
+        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
+        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
+
+        # Fill in the tensors
+        for i, patch_dict in enumerate(patch_dicts):
+            num_patches = min(patch_dict['patches'].shape[0], max_patches)
 
-    return input_dict, labels_tensor
+            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
+            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
+            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
+
+        return {
+            'patches': patches,
+            'patch_coord': patch_coord,
+            'patch_valid': patch_valid,
+            'seq_len': max_patches,
+        }, targets
 
 
 class VariableSeqMapWrapper(IterableDataset):
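
For a sense of the padding behaviour of the relocated collator, here is a minimal, self-contained sketch (not part of the commit); the patch counts, labels, and patch_dim=768 (16*16*3) are invented for illustration:

import torch
from timm.data.naflex_dataset import NaFlexCollator

def make_sample(num_patches, label, patch_dim=768):
    # Toy stand-in for the transform/Patchify output: a (patch_dict, target) tuple
    return {
        'patches': torch.randn(num_patches, patch_dim),
        'patch_coord': torch.zeros(num_patches, 2, dtype=torch.int64),
        'patch_valid': torch.ones(num_patches, dtype=torch.bool),
    }, label

collator = NaFlexCollator(max_seq_len=256)
input_dict, targets = collator([make_sample(100, 0), make_sample(256, 1)])

print(input_dict['patches'].shape)       # torch.Size([2, 256, 768])
print(input_dict['patch_valid'].sum(1))  # tensor([100, 256]); padding rows are masked out
print(input_dict['seq_len'], targets)    # 256 tensor([0, 1])

Every sample is padded (or truncated) to max_seq_len, and 'patch_valid' records which rows are real, so downstream attention can ignore the padding.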
@@ -161,15 +163,15 @@ def __init__(
         self.epoch = epoch
         self.batch_divisor = batch_divisor
 
-        # Pre-initialize transforms for each sequence length
+        # Pre-initialize transforms and collate fns for each sequence length
         self.transforms: Dict[int, Optional[Callable]] = {}
-        if transform_factory:
-            for seq_len in self.seq_lens:
+        self.collate_fns: Dict[int, Callable] = {}
+        for seq_len in self.seq_lens:
+            if transform_factory:
                 self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
-        else:
-            for seq_len in self.seq_lens:
-                self.transforms[seq_len] = None  # No transform
-
+            else:
+                self.transforms[seq_len] = None  # No transform
+            self.collate_fns[seq_len] = NaFlexCollator(seq_len)
         self.patchifier = Patchify(self.patch_size)
 
         # --- Canonical Schedule Calculation (Done Once) ---
@@ -417,6 +419,6 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
 
             # Collate the processed samples into a batch
             if batch_samples:  # Only yield if we successfully processed samples
-                yield _collate_batch(batch_samples, seq_len)
+                yield self.collate_fns[seq_len](batch_samples)
 
         # If batch_samples is empty after processing 'indices', an empty batch is skipped.
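
The two hunks above reduce to a simple pattern, sketched here with placeholder bucket sizes (not necessarily the wrapper's defaults): one fixed-length collator per sequence-length bucket, so every batch yielded for a given bucket has identical tensor shapes.

from timm.data.naflex_dataset import NaFlexCollator

seq_lens = (128, 256, 576)  # example buckets, chosen arbitrarily
collate_fns = {seq_len: NaFlexCollator(seq_len) for seq_len in seq_lens}

# __iter__ then dispatches on the bucket scheduled for each batch:
#     yield self.collate_fns[seq_len](batch_samples)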

timm/data/naflex_loader.py (+2 −71)
@@ -7,74 +7,10 @@
 
 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .loader import _worker_init
-from .naflex_dataset import VariableSeqMapWrapper
+from .naflex_dataset import VariableSeqMapWrapper, NaFlexCollator
 from .transforms_factory import create_transform
 
 
-class NaFlexCollator:
-    """Custom collator for batching NaFlex-style variable-resolution images."""
-
-    def __init__(
-            self,
-            patch_size=16,
-            max_seq_len=None,
-    ):
-        self.patch_size = patch_size
-        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length (576 = 24*24)
-
-    def __call__(self, batch):
-        """
-        Args:
-            batch: List of tuples (patch_dict, target)
-
-        Returns:
-            A tuple of (input_dict, targets) where input_dict contains:
-                - patches: Padded tensor of patches
-                - patch_coord: Coordinates for each patch (y, x)
-                - patch_valid: Valid indicators
-        """
-        assert isinstance(batch[0], tuple)
-        batch_size = len(batch)
-
-        # Resize to final size based on seq_len and patchify
-
-        # Extract targets
-        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
-
-        # Get patch dictionaries
-        patch_dicts = [item[0] for item in batch]
-
-        # If we have a maximum sequence length constraint, ensure we don't exceed it
-        if self.max_seq_len is not None:
-            max_patches = self.max_seq_len
-        else:
-            # Find the maximum number of patches in this batch
-            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
-
-        # Get patch dimensionality
-        patch_dim = patch_dicts[0]['patches'].shape[1]
-
-        # Prepare tensors for the batch
-        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
-        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
-        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
-
-        # Fill in the tensors
-        for i, patch_dict in enumerate(patch_dicts):
-            num_patches = min(patch_dict['patches'].shape[0], max_patches)
-
-            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
-            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
-            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
-
-        return {
-            'patches': patches,
-            'patch_coord': patch_coord,
-            'patch_valid': patch_valid,
-            'seq_len': max_patches,
-        }, targets
-
-
 class NaFlexPrefetchLoader:
     """Data prefetcher for NaFlex format which normalizes patches."""
 

@@ -261,9 +197,7 @@ def create_naflex_loader(
         # NOTE: Collation is handled by the dataset wrapper for training
         # Create the collator (handles fixed-size collation)
         # collate_fn = NaFlexCollator(
-        #     patch_size=patch_size,
         #     max_seq_len=max(seq_lens) + 1,  # +1 for class token
-        #     use_prefetcher=use_prefetcher
         # )
 
         loader = torch.utils.data.DataLoader(
@@ -303,10 +237,7 @@ def create_naflex_loader(
         )
 
         # Create the collator
-        collate_fn = NaFlexCollator(
-            patch_size=patch_size,
-            max_seq_len=max_seq_len,
-        )
+        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)
 
         # Handle distributed training
         sampler = None
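
With NaFlexCollator now living in naflex_dataset.py, a standard DataLoader can use it directly for fixed-length eval collation, as the change above does. A self-contained sketch of that wiring; ToyPatchDataset, its patch counts, and max_seq_len=144 are all invented stand-ins for the real pipeline:

import torch
from torch.utils.data import DataLoader, Dataset
from timm.data.naflex_dataset import NaFlexCollator

class ToyPatchDataset(Dataset):
    # Stand-in for a real patchified dataset: yields (patch_dict, target) pairs
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        n = 64 + 8 * idx  # variable patch count per sample
        return {
            'patches': torch.randn(n, 768),
            'patch_coord': torch.zeros(n, 2, dtype=torch.int64),
            'patch_valid': torch.ones(n, dtype=torch.bool),
        }, idx % 2

loader = DataLoader(ToyPatchDataset(), batch_size=4, collate_fn=NaFlexCollator(max_seq_len=144))
for input_dict, targets in loader:
    print(input_dict['patches'].shape, targets.shape)  # torch.Size([4, 144, 768]) torch.Size([4])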
