-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_transformations.py
33 lines (26 loc) · 1.13 KB
/
data_transformations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torchvision.transforms as transforms
# Public API of this module: every ready-made transform pipeline defined
# below is exported, so `from data_transformations import *` exposes all of
# them rather than only the first one.
__all__ = [
    'tensor_transform',
    'resnet_input_transform',
    'cifar10_input_transform',
]
# Note: custom data-transformation classes can also be written here.
# Plain (augmentation-free) pipeline for the SSL dataset: convert the input
# to a tensor, then normalize each of the three channels with the SSL
# dataset statistics.
tensor_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.5032, 0.4746, 0.4275),
            std=(0.2276, 0.2228, 0.2265),
        ),
    ]
)
# Augmenting pipeline for the SSL dataset:
#   1. random affine jitter (rotation within +/-5 degrees, translation up to
#      10% of each dimension, scale factor between 0.9 and 1.1),
#   2. resize to 96x96,
#   3. convert to a tensor,
#   4. normalize each channel with the SSL dataset statistics.
#
# RandomAffine is deliberately called without the legacy `resample` /
# `fillcolor` keywords: torchvision deprecated and later removed them (in
# favour of `interpolation` / `fill`), so passing them raises TypeError on
# current releases. The values previously passed (nearest-neighbour
# resampling, fill value 0, no shear) are the library defaults, so omitting
# them keeps the behaviour identical on both old and new torchvision versions.
resnet_input_transform = transforms.Compose([
    transforms.RandomAffine(
        degrees=(-5, 5),
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
    ),
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5032, 0.4746, 0.4275),
        std=(0.2276, 0.2228, 0.2265),
    ),
])
# Augmenting pipeline, to be used only for the CIFAR dataset:
#   1. random affine jitter (rotation within +/-5 degrees, translation up to
#      10% of each dimension, scale factor between 0.9 and 1.1),
#   2. resize to 48x48,
#   3. convert to a tensor,
#   4. normalize each channel with the CIFAR dataset statistics.
#
# RandomAffine is deliberately called without the legacy `resample` /
# `fillcolor` keywords: torchvision deprecated and later removed them (in
# favour of `interpolation` / `fill`), so passing them raises TypeError on
# current releases. The values previously passed (nearest-neighbour
# resampling, fill value 0, no shear) are the library defaults, so omitting
# them keeps the behaviour identical on both old and new torchvision versions.
cifar10_input_transform = transforms.Compose([
    transforms.RandomAffine(
        degrees=(-5, 5),
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
    ),
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4925, 0.4828, 0.4464),
        std=(0.2024, 0.1998, 0.2006),
    ),
])
# Reference normalization statistics (per-channel mean, std), kept for copy/paste:
# SSL data:
#   transforms.Normalize((0.5032, 0.4746, 0.4275), (0.2276, 0.2228, 0.2265))
# CIFAR data:
#   transforms.Normalize((0.4925, 0.4828, 0.4464), (0.2024, 0.1998, 0.2006))