Skip to content

Commit d3cdb73

Browse files
committed
Allow transforms that accept torch.Tensor / np.ndarray from the beginning
1 parent 6f93905 commit d3cdb73

File tree

1 file changed

+54
-4
lines changed

1 file changed

+54
-4
lines changed

src/open_clip/transform.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from dataclasses import dataclass, asdict
55
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
66

7+
import numpy as np
78
import torch
89
import torchvision.transforms.functional as F
910
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
@@ -235,10 +236,59 @@ def __repr__(self) -> str:
235236
return f"{self.__class__.__name__}(size={self.size})"
236237

237238

239+
class MaybeConvertMode:
240+
"""Perform PIL convert("RGB") iff PIL image.
241+
"""
242+
243+
def __init__(self, mode="RGB") -> None:
244+
super().__init__()
245+
self.mode = mode
246+
247+
def __call__(self, pic):
248+
"""
249+
Args:
250+
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
251+
252+
Returns:
253+
Tensor: Converted image.
254+
"""
255+
# FIXME if we switched to torchvision v2 transforms, can handle image mode
256+
# conversion more consistently across all input types.
257+
if isinstance(pic, (np.ndarray, torch.Tensor)):
258+
return pic
259+
return pic.convert(self.mode)
260+
261+
def __repr__(self) -> str:
262+
return f"{self.__class__.__name__}()"
263+
264+
238265
def _convert_to_rgb(image):
239266
return image.convert('RGB')
240267

241268

269+
class MaybeToTensor(ToTensor):
270+
"""Convert a PIL Image or ndarray to tensor if it's not already one.
271+
"""
272+
273+
def __init__(self) -> None:
274+
super().__init__()
275+
276+
def __call__(self, pic) -> torch.Tensor:
277+
"""
278+
Args:
279+
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
280+
281+
Returns:
282+
Tensor: Converted image.
283+
"""
284+
if isinstance(pic, torch.Tensor):
285+
return pic
286+
return F.to_tensor(pic)
287+
288+
def __repr__(self) -> str:
289+
return f"{self.__class__.__name__}()"
290+
291+
242292
class color_jitter(object):
243293
"""
244294
Apply Color Jitter to the PIL image with a specified probability.
@@ -337,7 +387,7 @@ def image_transform(
337387
scale=aug_cfg_dict.pop('scale'),
338388
interpolation=InterpolationMode.BICUBIC,
339389
),
340-
_convert_to_rgb,
390+
MaybeConvertMode(),
341391
]
342392
if aug_cfg.color_jitter_prob:
343393
assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
@@ -349,7 +399,7 @@ def image_transform(
349399
gray_scale(aug_cfg.gray_scale_prob)
350400
])
351401
train_transform.extend([
352-
ToTensor(),
402+
MaybeToTensor(),
353403
normalize,
354404
])
355405
train_transform = Compose(train_transform)
@@ -383,8 +433,8 @@ def image_transform(
383433
transforms += [CenterCrop(image_size)]
384434

385435
transforms.extend([
386-
_convert_to_rgb,
387-
ToTensor(),
436+
MaybeConvertMode(),
437+
MaybeToTensor(),
388438
normalize,
389439
])
390440
return Compose(transforms)

0 commit comments

Comments
 (0)