44from dataclasses import dataclass , asdict
55from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
66
7+ import numpy as np
78import torch
89import torchvision .transforms .functional as F
910from torchvision .transforms import Normalize , Compose , RandomResizedCrop , InterpolationMode , ToTensor , Resize , \
@@ -235,10 +236,59 @@ def __repr__(self) -> str:
235236 return f"{ self .__class__ .__name__ } (size={ self .size } )"
236237
237238
class MaybeConvertMode:
    """Convert an image to ``mode`` via its ``convert`` method, unless it is an array or tensor.

    ``numpy.ndarray`` and ``torch.Tensor`` inputs are returned unchanged; any
    other input is expected to provide ``convert(mode)`` (e.g. a PIL Image).

    Args:
        mode: Target mode passed to ``pic.convert``. Defaults to ``"RGB"``.
    """

    def __init__(self, mode="RGB") -> None:
        super().__init__()
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.Tensor): Image to convert.

        Returns:
            The result of ``pic.convert(self.mode)``, or ``pic`` itself when it
            is a ``numpy.ndarray`` or ``torch.Tensor``.
        """
        # FIXME if we switched to torchvision v2 transforms, can handle image mode
        # conversion more consistently across all input types.
        if isinstance(pic, (np.ndarray, torch.Tensor)):
            # Arrays and tensors are passed through untouched.
            return pic
        return pic.convert(self.mode)

    def __repr__(self) -> str:
        # Include the configured mode so differently configured instances are
        # distinguishable when a transform pipeline is printed.
        return f"{self.__class__.__name__}(mode={self.mode})"
263+
264+
238265def _convert_to_rgb (image ):
239266 return image .convert ('RGB' )
240267
241268
class MaybeToTensor(ToTensor):
    """``ToTensor`` variant that leaves inputs which are already tensors untouched.

    Any input that is not a ``torch.Tensor`` is handed to ``F.to_tensor``.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, pic) -> torch.Tensor:
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.Tensor): Image to be
                converted; a ``torch.Tensor`` is returned as-is.

        Returns:
            Tensor: ``pic`` itself if it is already a tensor, otherwise the
            result of ``F.to_tensor(pic)``.
        """
        needs_conversion = not isinstance(pic, torch.Tensor)
        if needs_conversion:
            return F.to_tensor(pic)
        return pic

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}()"
290+
291+
242292class color_jitter (object ):
243293 """
244294 Apply Color Jitter to the PIL image with a specified probability.
@@ -337,7 +387,7 @@ def image_transform(
337387 scale = aug_cfg_dict .pop ('scale' ),
338388 interpolation = InterpolationMode .BICUBIC ,
339389 ),
340- _convert_to_rgb ,
390+ MaybeConvertMode () ,
341391 ]
342392 if aug_cfg .color_jitter_prob :
343393 assert aug_cfg .color_jitter is not None and len (aug_cfg .color_jitter ) == 4
@@ -349,7 +399,7 @@ def image_transform(
349399 gray_scale (aug_cfg .gray_scale_prob )
350400 ])
351401 train_transform .extend ([
352- ToTensor (),
402+ MaybeToTensor (),
353403 normalize ,
354404 ])
355405 train_transform = Compose (train_transform )
@@ -383,8 +433,8 @@ def image_transform(
383433 transforms += [CenterCrop (image_size )]
384434
385435 transforms .extend ([
386- _convert_to_rgb ,
387- ToTensor (),
436+ MaybeConvertMode () ,
437+ MaybeToTensor (),
388438 normalize ,
389439 ])
390440 return Compose (transforms )
0 commit comments