-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_utils.py
94 lines (71 loc) · 2.65 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from utilities import SquarePad
from tqdm import tqdm
import torch
from PIL import Image
import pandas as pd
import numpy as np
import os
import json
def parse_data(filepath, img_dir, filename):
    """Build a metadata CSV describing every image found in ``img_dir``.

    Every entry of ``img_dir`` is expected to be named ``<int>.<ext>``. The
    integer part is used as a row label into the CSV at ``filepath`` to look
    up that image's ``full_name``, ``date``, ``era`` and ``name`` fields. One
    output row per image is collected under the columns ``img``, ``artist``,
    ``date``, ``era`` and ``source``, and the result is written to
    ``data/<filename>``.

    Args:
        filepath: Path of the source CSV, read with ``pandas.read_csv``.
        img_dir: Directory whose entries are the image files.
        filename: Name of the CSV file created inside ``data/``.

    Raises:
        OSError: If an entry of ``img_dir`` cannot be opened by PIL.
        ValueError: If a file name does not start with an integer.
        KeyError: If that integer is not a row label of the source CSV.
    """
    data = pd.read_csv(filepath)
    columns = ['img', 'artist', 'date', 'era', 'source']
    metadata = pd.DataFrame(columns=columns)

    for image in tqdm(sorted(os.listdir(img_dir))):
        img_path = img_dir + '/' + image

        # The image content is never used here; the file is opened only so
        # that an unreadable entry fails early. The context manager releases
        # the file handle right away instead of leaking one per image.
        with Image.open(img_path):
            pass

        index = int(image.split('.')[0])
        info = data.loc[index]
        metadata.loc[index] = [
            img_path,
            info['full_name'],
            info['date'],
            info['era'],
            info['name'],
        ]

    # Make sure the output directory exists before writing into it.
    os.makedirs('data', exist_ok=True)
    metadata.to_csv('data/' + filename)
class DataProcessor():
    """Indexable image dataset backed by a metadata CSV.

    The CSV must provide the columns ``img``, ``artist``, ``date`` and
    ``era``. Indexing an instance loads the image referenced by ``img``,
    applies the transform pipeline selected by ``mode`` and returns it
    together with the integer-encoded artist and era labels and the raw
    ``date`` value.
    """

    def __init__(self, datapath, image_size, mode, prefix='', mappings=None):
        """Load the metadata and prepare transforms and label mappings.

        Args:
            datapath: Path of the metadata CSV, read with ``pandas.read_csv``.
            image_size: Side length of the square images produced by the
                transform pipelines.
            mode: ``'train'`` selects the augmenting pipeline, ``'test'`` the
                deterministic one. Any other value leaves images
                untransformed (see ``__getitem__``).
            prefix: Path component placed in front of each ``img`` value.
            mappings: Optional dict with keys ``'artist'`` and ``'era'``,
                each mapping a label to its integer id. When omitted, the
                mappings are derived from the CSV and also written to
                ``data/mapping.json``.
        """
        # NOTE(review): these channel statistics look like the commonly used
        # CIFAR-10 values - confirm they are appropriate for this dataset.
        normalize = transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.247, 0.243, 0.261],
        )

        # Transform images
        self.transform_train = transforms.Compose([
            transforms.RandomChoice([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
            ]),
            transforms.RandomResizedCrop((image_size, image_size), scale=(0.05, 1.0)),
            transforms.ToTensor(),
            normalize,
        ])
        self.transform_test = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalize,
        ])

        self.prefix = prefix
        self.mode = mode
        self.data = pd.read_csv(datapath)

        # Retrieve unique labels (in order of first appearance) and their counts
        self.artists = self.data['artist'].unique()
        self.eras = self.data['era'].unique()
        self.artists_weights = self.data['artist'].value_counts()
        self.eras_weights = self.data['era'].value_counts()

        # Mapping from unique label to int label
        if mappings is None:
            self.artist_map = {label: i for i, label in enumerate(self.artists)}
            self.era_map = {label: i for i, label in enumerate(self.eras)}
            # Persist the freshly built mappings so they can be passed back in
            # via ``mappings`` later. The ``with`` block guarantees the file is
            # flushed and closed.
            os.makedirs('data', exist_ok=True)
            with open('data/mapping.json', 'w') as fh:
                json.dump({'artist': self.artist_map, 'era': self.era_map}, fh)
        else:
            self.artist_map = mappings['artist']
            self.era_map = mappings['era']

    def get_labels(self):
        """Return ``[number of unique artists, number of unique eras]``."""
        return [len(self.artists), len(self.eras)]

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, artist_id, date, era_id)`` for row ``idx``.

        The image is converted to RGB and then passed through the train or
        test pipeline according to ``self.mode``. For any other mode the
        RGB PIL image is returned without transformation.

        Raises:
            KeyError: If the row's artist or era is missing from the mappings.
        """
        row = self.data.iloc[idx]

        # NOTE(review): with the default prefix '' this path starts with '/',
        # i.e. it is resolved from the filesystem root - confirm that callers
        # always pass a prefix or rely on this.
        with Image.open(f'{self.prefix}/{row.img}') as raw:
            # convert() returns a new image, so the file can be closed here.
            img = raw.convert('RGB')

        if self.mode == 'train':
            img = self.transform_train(img)
        elif self.mode == 'test':
            img = self.transform_test(img)

        artist = self.artist_map[row.artist]
        era = self.era_map[row.era]
        return img, artist, row.date, era