@@ -59,63 +59,65 @@ def calculate_batch_size(
    return batch_size


-def _collate_batch(
-        batch_samples: List[Tuple[Dict[str, torch.Tensor], Any]],
-        target_seq_len: int
-) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
-    """Collates processed samples into a batch, padding/truncating to target_seq_len."""
-    batch_patch_data = [item[0] for item in batch_samples]
-    batch_labels = [item[1] for item in batch_samples]
-
-    if not batch_patch_data:
-        return {}, torch.empty(0)
-
-    batch_size = len(batch_patch_data)
-    patch_dim = batch_patch_data[0]['patches'].shape[1]
-
-    # Initialize tensors with target sequence length
-    patches_batch = torch.zeros((batch_size, target_seq_len, patch_dim), dtype=torch.float32)
-    patch_coord_batch = torch.zeros((batch_size, target_seq_len, 2), dtype=torch.int64)
-    patch_valid_batch = torch.zeros((batch_size, target_seq_len), dtype=torch.bool)  # Use bool
-
-    for i, data in enumerate(batch_patch_data):
-        num_patches = data['patches'].shape[0]
-        # Take min(num_patches, target_seq_len) patches
-        n_copy = min(num_patches, target_seq_len)
-
-        patches_batch[i, :n_copy] = data['patches'][:n_copy]
-        patch_coord_batch[i, :n_copy] = data['patch_coord'][:n_copy]
-        patch_valid_batch[i, :n_copy] = data['patch_valid'][:n_copy]  # Copy validity flags
-
-    # Create the final input dict
-    input_dict = {
-        'patches': patches_batch,
-        'patch_coord': patch_coord_batch,
-        'patch_valid': patch_valid_batch,  # Boolean mask
-        # Note: 'seq_length' might be ambiguous. The target length is target_seq_len.
-        # The actual number of valid patches per sample varies.
-        # 'patch_valid' mask is the most reliable source of truth.
-    }
-
-    # Attempt to stack labels if they are tensors, otherwise return list
-    try:
-        if isinstance(batch_labels[0], torch.Tensor):
-            labels_tensor = torch.stack(batch_labels)
+class NaFlexCollator:
+    """Custom collator for batching NaFlex-style variable-resolution images."""
+
+    def __init__(
+            self,
+            max_seq_len=None,
+    ):
+        self.max_seq_len = max_seq_len or 576  # Default ViT-B/16 sequence length at 384x384 (24 * 24 = 576)
+
+    def __call__(self, batch):
+        """
+        Args:
+            batch: List of tuples (patch_dict, target)
+
+        Returns:
+            A tuple of (input_dict, targets) where input_dict contains:
+                - patches: Padded tensor of patches
+                - patch_coord: Coordinates for each patch (y, x)
+                - patch_valid: Valid indicators
+        """
+        assert isinstance(batch[0], tuple)
+        batch_size = len(batch)
+
+        # Extract targets
+        # FIXME need to handle dense (float) targets or always done downstream of this?
+        targets = torch.tensor([item[1] for item in batch], dtype=torch.int64)
+
+        # Get patch dictionaries
+        patch_dicts = [item[0] for item in batch]
+
+        # If we have a maximum sequence length constraint, ensure we don't exceed it
+        if self.max_seq_len is not None:
+            max_patches = self.max_seq_len
        else:
-            # Convert numerical types to tensor, keep others as list (or handle specific types)
-            if isinstance(batch_labels[0], (int, float)):
-                labels_tensor = torch.tensor(batch_labels)
-            else:
-                # Cannot convert non-numerical labels easily, return as list
-                # Or handle specific conversion if needed
-                # For FakeDataset, labels are ints, so this works
-                labels_tensor = torch.tensor(batch_labels)  # Assuming labels are numerical
-    except Exception:
-        # Fallback if stacking fails (e.g., different shapes, types)
-        print("Warning: Could not stack labels into a tensor. Returning list of labels.")
-        labels_tensor = batch_labels  # Return as list
+            # Find the maximum number of patches in this batch
+            max_patches = max(item['patches'].shape[0] for item in patch_dicts)
+
+        # Get patch dimensionality
+        patch_dim = patch_dicts[0]['patches'].shape[1]
+
+        # Prepare tensors for the batch
+        patches = torch.zeros((batch_size, max_patches, patch_dim), dtype=torch.float32)
+        patch_coord = torch.zeros((batch_size, max_patches, 2), dtype=torch.int64)  # [B, N, 2] for (y, x)
+        patch_valid = torch.zeros((batch_size, max_patches), dtype=torch.bool)
+
+        # Fill in the tensors
+        for i, patch_dict in enumerate(patch_dicts):
+            num_patches = min(patch_dict['patches'].shape[0], max_patches)

-    return input_dict, labels_tensor
+            patches[i, :num_patches] = patch_dict['patches'][:num_patches]
+            patch_coord[i, :num_patches] = patch_dict['patch_coord'][:num_patches]
+            patch_valid[i, :num_patches] = patch_dict['patch_valid'][:num_patches]
+
+        return {
+            'patches': patches,
+            'patch_coord': patch_coord,
+            'patch_valid': patch_valid,
+            'seq_len': max_patches,
+        }, targets


class VariableSeqMapWrapper(IterableDataset):
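For readers unfamiliar with how a collate class like this is consumed: the sketch below is not part of the commit, it shows NaFlexCollator plugged into a stock torch.utils.data.DataLoader. The dataset name is hypothetical and is assumed to yield (patch_dict, target) tuples whose dict already contains 'patches', 'patch_coord', and 'patch_valid'.

# Hedged usage sketch, not part of this commit. Assumes `patchified_dataset` is a
# hypothetical map-style dataset returning (patch_dict, target) tuples.
from torch.utils.data import DataLoader

collator = NaFlexCollator(max_seq_len=576)
loader = DataLoader(patchified_dataset, batch_size=32, collate_fn=collator)

input_dict, targets = next(iter(loader))
# input_dict['patches']:     float32 [32, 576, patch_dim], zero-padded past each sample's patches
# input_dict['patch_coord']: int64   [32, 576, 2] holding (y, x) grid coordinates
# input_dict['patch_valid']: bool    [32, 576] mask separating real patches from padding
# targets:                   int64   [32]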
@@ -161,15 +163,15 @@ def __init__(
        self.epoch = epoch
        self.batch_divisor = batch_divisor

-        # Pre-initialize transforms for each sequence length
+        # Pre-initialize transforms and collate fns for each sequence length
        self.transforms: Dict[int, Optional[Callable]] = {}
-        if transform_factory:
-            for seq_len in self.seq_lens:
+        self.collate_fns: Dict[int, Callable] = {}
+        for seq_len in self.seq_lens:
+            if transform_factory:
                self.transforms[seq_len] = transform_factory(max_seq_len=seq_len, patch_size=self.patch_size)
-        else:
-            for seq_len in self.seq_lens:
-                self.transforms[seq_len] = None  # No transform
-
+            else:
+                self.transforms[seq_len] = None  # No transform
+            self.collate_fns[seq_len] = NaFlexCollator(seq_len)
        self.patchifier = Patchify(self.patch_size)

        # --- Canonical Schedule Calculation (Done Once) ---
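One consequence of the change above, illustrated with a hedged sketch (values are hypothetical, not from the commit): because a NaFlexCollator is constructed per sequence length, every batch produced at a given schedule step is padded to that step's fixed length, so tensor shapes stay constant within a step and only change when seq_len changes.

# Hypothetical values, not part of this commit; seq_len is passed positionally as max_seq_len.
collate_256 = NaFlexCollator(256)
collate_576 = NaFlexCollator(576)
assert collate_256.max_seq_len == 256
assert collate_576.max_seq_len == 576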
@@ -417,6 +419,6 @@ def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:

            # Collate the processed samples into a batch
            if batch_samples:  # Only yield if we successfully processed samples
-                yield _collate_batch(batch_samples, seq_len)
+                yield self.collate_fns[seq_len](batch_samples)

            # If batch_samples is empty after processing 'indices', an empty batch is skipped.
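A hedged sketch of what the updated __iter__ now yields (not part of the commit): each item is a fully collated (input_dict, targets) pair for one batch, already padded to that step's seq_len, so a wrapping DataLoader, if one is used at all, should be given batch_size=None to avoid re-batching. `wrapper` below is assumed to be an already-constructed VariableSeqMapWrapper.

# Assumes `wrapper` is an existing VariableSeqMapWrapper instance; not part of this commit.
for input_dict, targets in wrapper:
    # Shapes are internally consistent for every yielded batch:
    assert input_dict['patches'].shape[1] == input_dict['seq_len']
    assert input_dict['patch_valid'].shape == input_dict['patches'].shape[:2]
    assert targets.shape[0] == input_dict['patches'].shape[0]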