
Commit 9c1e008

Check in test data
1 parent bea9127 commit 9c1e008

9 files changed: +999, -700 lines changed

BrokenMixMatch.ipynb (+676): large diffs are not rendered by default.

MixMatch.ipynb (-700): this file was deleted.

__init__.py: whitespace-only changes.

cifar_subset.pkl (7.03 MB): binary file not shown.

cifar_utils.py (new file, +86):

"""Stolen from Stanford CS231n."""
from __future__ import print_function

from builtins import range
from six.moves import cPickle as pickle
import numpy as np
import os
import platform

CIFAR_10_DIR = 'cifar-10-batches-py'


def load_pickle(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
        return pickle.load(f)
    elif version[0] == '3':
        return pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))


def load_CIFAR_batch(filename):
    """Load a single batch of CIFAR."""
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        Y = np.array(Y)
        return X, Y


def load_CIFAR10(ROOT):
    """Load all of CIFAR-10."""
    xs = []
    ys = []
    for b in range(1, 6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b,))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte


def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=1000,
                     subtract_mean=True, cifar10_dir=CIFAR_10_DIR):
    """
    Load the CIFAR-10 dataset from disk and perform preprocessing to prepare
    it for classifiers. These are the same steps as we used for the SVM, but
    condensed to a single function.
    """
    # Load the raw CIFAR-10 data
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Subsample the data
    mask = list(range(num_training, num_training + num_validation))
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = list(range(num_training))
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = list(range(num_test))
    X_test = X_test[mask]
    y_test = y_test[mask]

    # Normalize the data: subtract the mean image
    if subtract_mean:
        mean_image = np.mean(X_train, axis=0)
        X_train -= mean_image
        X_val -= mean_image
        X_test -= mean_image

    # Transpose so that channels come first
    X_train = X_train.transpose(0, 3, 1, 2).copy()
    X_val = X_val.transpose(0, 3, 1, 2).copy()
    X_test = X_test.transpose(0, 3, 1, 2).copy()

    # Package data into a dictionary
    return {
        'X_train': X_train, 'y_train': y_train,
        'X_val': X_val, 'y_val': y_val,
        'X_test': X_test, 'y_test': y_test,
    }
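
For orientation, a minimal usage sketch (not part of the commit; assumes the cifar-10-batches-py directory has been downloaded and unpacked next to the script):

# Hypothetical usage of cifar_utils (illustrative only).
from cifar_utils import get_CIFAR10_data

data = get_CIFAR10_data(num_training=1000, num_validation=100, num_test=100)
for name, arr in data.items():
    print(name, arr.shape)  # e.g. X_train (1000, 3, 32, 32), channels-first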

dataset.py (new file, +116):

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn
import numpy as np
import torch


def shit_mult(a, arr):
    """Row-wise multiply: numpy won't broadcast a (n,) factor against (n, ...)."""
    new_arr = np.zeros_like(arr)
    for i in range(len(a)):
        new_arr[i] = arr[i] * a[i]
    return new_arr


class ArrayDataset(Dataset):
    """Alternates labeled and unlabeled examples; unlabeled examples carry an
    all-zero label vector so the collate_fn can tell them apart."""

    def __init__(self, X, y, X_unlabeled):
        super().__init__()
        self.X = X
        self.y = y
        self.X_unlabeled = X_unlabeled
        self.last_labeled = False

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, index):
        """Ignore `index`; sample randomly, alternating labeled and unlabeled."""
        if self.last_labeled:
            self.last_labeled = False
            zeros = np.zeros_like(self.y[0])
            return (self.X_unlabeled[np.random.randint(0, len(self.X_unlabeled))], zeros)
        else:
            self.last_labeled = True
            idx = np.random.randint(0, len(self.X))
            return (self.X[idx], self.y[idx])


to_arr = lambda x: x.detach().numpy()


def sharpen(x, T):
    """Sharpen a distribution toward its mode; T < 1 makes it peakier."""
    numerator = x ** (1 / T)
    return numerator / numerator.sum(dim=1, keepdim=True)


def mixup_torch(x1, x2, y1, y2, alpha):
    beta = torch.Tensor(np.random.beta(alpha, alpha, x1.shape[0]))
    beta = torch.max(beta, 1 - beta)  # keep each mix closer to (x1, y1)
    return lc2(x1, x2, beta), lc2(y1, y2, beta)


def lc2(x1, x2, l):
    """Per-example linear combination l * x1 + (1 - l) * x2."""
    orig = torch.cat([(x1[i] * l[i]).unsqueeze(0) for i in range(len(l))])
    other = torch.cat([(x2[i] * (1 - l[i])).unsqueeze(0) for i in range(len(l))])
    mixed = orig + other
    if len(mixed.shape) == 3: mixed = mixed.unsqueeze(0)  # bs=2
    return mixed


class Flatten(nn.Module):
    def forward(self, x): return x.view(x.size(0), -1)


class MixupLoader(DataLoader):

    def __init__(self, ds, batch_size, T=0.5, K=2, alpha=0.75, verbose=False):
        self.bs = batch_size
        assert self.bs % 2 == 0  # half labeled, half unlabeled
        self.ds = ds
        self.T = T
        self.K = K
        self.alpha = alpha
        self.verbose = verbose
        super().__init__(ds, collate_fn=self.collate_fn, batch_size=self.bs,
                         num_workers=0)

    def get_pseudo_labels(self, ub):
        # Scales logits by 1/K as a stand-in for averaging predictions over
        # the K augmented copies, then sharpens.
        preds = self.model(ub) / self.K
        qb = sharpen(preds, self.T).detach()
        return qb

    @staticmethod
    def augment_fn(X):
        # TODO(SS): fix me
        return X

    def collate_fn(loader, examples):  # bound method: `loader` is `self`
        K, T, alpha = loader.K, loader.T, loader.alpha
        C = lambda arrs: np.concatenate(np.expand_dims(arrs, 0))  # stack a list of arrays
        # Labeled examples have one-hot labels; unlabeled ones a zero vector.
        X_labeled = C([X for X, y_ in examples if y_.sum() == 1])
        y = torch.Tensor(np.array([y_ for X, y_ in examples if y_.sum() == 1]))
        X_unlabeled = C([X for X, y_ in examples if y_.sum() == 0])

        xb = torch.Tensor(loader.augment_fn(X_labeled))
        n_labeled = len(X_labeled)
        ub = torch.cat([torch.Tensor(loader.augment_fn(X_unlabeled)) for _ in range(K)])  # unlabeled
        qb = loader.get_pseudo_labels(ub)
        Ux = ub
        Uy = qb  # ub already holds the K augmented copies, so qb needs no further tiling
        indices = torch.randperm(xb.size(0) + Ux.size(0))

        Wx = torch.cat([xb, Ux], dim=0)[indices]
        Wy = torch.cat([y, Uy], dim=0)[indices]
        np.testing.assert_allclose(to_arr(Wy).sum(1), 1., rtol=3)  # rows should sum to 1

        X, p = mixup_torch(xb, Wx[:n_labeled], y, Wy[:n_labeled], alpha)
        U, q = mixup_torch(Ux, Wx[n_labeled:], Uy, Wy[n_labeled:], alpha)
        X = torch.cat([X, U], dim=0)
        Y = torch.cat([p, q], dim=0)
        if loader.verbose:
            print(X_labeled.shape, X_unlabeled.shape)
            print(f'Wx: {Wx.shape}')
            print(f'p: {to_arr(p)}')
            print(f'Returning: X final: {X.shape}, Y final: {np.round(to_arr(Y), 3)}')
        return X, Y
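
As a quick sanity check, a toy sketch of mixup_torch (illustrative values, not part of the commit):

# Toy check of mixup_torch (hypothetical).
import torch
from dataset import mixup_torch

x1 = torch.ones(4, 3, 32, 32)
x2 = torch.zeros(4, 3, 32, 32)
y1 = torch.eye(10)[[0, 1, 2, 3]]  # one-hot labels
y2 = torch.eye(10)[[4, 5, 6, 7]]
X, Y = mixup_torch(x1, x2, y1, y2, alpha=0.75)
# max(beta, 1 - beta) keeps each mix closer to (x1, y1),
# so every pixel of X is at least 0.5 here.
print(X.shape, Y.shape, float(X.min()))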

debug_npy.py (new file, +15):

import numpy as np


def sharpen_npy(x, T):
    numerator = x ** (1 / T)
    return numerator / numerator.sum(axis=1, keepdims=True)


def lin_comb(a, b, frac_a): return (frac_a * a) + (1 - frac_a) * b


def mixup(x1, x2, y1, y2, alpha):
    beta = np.random.beta(alpha, alpha, x1.shape[0])
    beta = np.maximum(beta, 1 - beta)
    return lin_comb(x1, x2, beta), lin_comb(y1, y2, beta)
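
A worked example of the sharpening step (illustrative numbers; T=0.5 is the MixMatch default used elsewhere in this commit):

# Worked example of sharpen_npy (hypothetical).
import numpy as np
from debug_npy import sharpen_npy

p = np.array([[0.6, 0.3, 0.1]])
print(sharpen_npy(p, T=0.5))
# T=0.5 squares each probability and renormalizes:
# [0.36, 0.09, 0.01] / 0.46, roughly [[0.783 0.196 0.022]], a peakier distribution.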

layers.py (new file, +45):

import numpy as np
import torch

from dataset import shit_mult


def mixmatch(X_labeled, y, X_unlabeled, model, augment_fn, T=0.5, K=2, alpha=0.75):
    """Generate labeled and unlabeled batches for mixmatch. Helpers are below."""
    xb = augment_fn(X_labeled)
    n_labeled = len(xb)
    ub = [augment_fn(X_unlabeled) for _ in range(K)]  # K augmentations of the unlabeled batch
    qb = sharpen(sum(map(model, ub)) / K, T)  # average the K predictions, then sharpen
    C = np.concatenate
    Ux = C(ub, axis=0)
    Uy = C([qb for _ in range(K)], axis=0)
    indices = np.random.permutation(len(xb) + len(Ux))  # shuffle() returns None; permutation() returns indices
    Wx = C([Ux, xb], axis=0)[indices]
    Wy = C([Uy, y], axis=0)[indices]  # Uy, not qb, so Wy lines up with Wx
    X, p = mixup(xb, Wx[:n_labeled], y, Wy[:n_labeled], alpha)
    U, q = mixup(Ux, Wx[n_labeled:], Uy, Wy[n_labeled:], alpha)
    return C([X, U], axis=0), C([p, q], axis=0), n_labeled  # concatenate along the batch axis


def sharpen(x, T):
    numerator = x ** (1 / T)
    return numerator / numerator.sum(axis=1, keepdims=True)


def lin_comb(a, b, frac_a):
    try:
        return (frac_a * a) + (1 - frac_a) * b
    except ValueError:  # shapes that numpy refuses to broadcast
        return shit_mult(frac_a, a) + shit_mult(1 - frac_a, b)


def mixup(x1, x2, y1, y2, alpha):
    beta = np.random.beta(alpha, alpha, x1.shape[0])  # both Beta parameters must be positive
    beta = np.maximum(beta, 1 - beta)
    return lin_comb(x1, x2, beta), lin_comb(y1, y2, beta)


class MixMatchLoss(torch.nn.Module):
    def __init__(self, lambda_u=100):
        super().__init__()
        self.lambda_u = lambda_u
        self.xent = torch.nn.CrossEntropyLoss()
        self.mse = torch.nn.MSELoss()

    def forward(self, preds, y, n_labeled):
        # NB: CrossEntropyLoss expects hard class indices in older PyTorch;
        # the soft-label failure is noted in test_mixmatch.py.
        labeled_loss = self.xent(preds[:n_labeled], y[:n_labeled])
        unlabeled_loss = self.mse(preds[n_labeled:], y[n_labeled:])
        return labeled_loss + (self.lambda_u * unlabeled_loss)
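
To see the shapes mixmatch produces end to end, a small sketch with a toy numpy model (all names below are illustrative, not from the commit):

# Shape walk-through of mixmatch (hypothetical).
import numpy as np
from layers import mixmatch

n_l, n_u, n_classes = 4, 4, 10
X_l = np.random.randn(n_l, 8)
y_l = np.eye(n_classes)[np.random.randint(0, n_classes, size=n_l)]
X_u = np.random.randn(n_u, 8)
model = lambda x: np.full((len(x), n_classes), 1.0 / n_classes)  # uniform toy predictions

X, Y, n_labeled = mixmatch(X_l, y_l, X_u, model, augment_fn=lambda x: x)
print(X.shape, Y.shape, n_labeled)  # (12, 8) (12, 10) 4 with K=2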

test_mixmatch.py (new file, +61):

import pickle
import unittest

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from .dataset import ArrayDataset, MixupLoader

to_arr = lambda x: x.detach().numpy()

arr = [1, 2, 3, 4]
img = np.array(arr * 4).reshape(4, 4)


def pickle_save(obj, path):
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def pickle_load(path):
    with open(path, 'rb') as f:
        return pickle.load(f, encoding='latin1')


class MixMatchLoss(torch.nn.Module):
    def __init__(self, lambda_u=100):
        super().__init__()
        self.lambda_u = lambda_u

    def forward(self, preds, y, n_labeled):
        # This line fails because y is continuous (soft labels),
        # while cross_entropy expects hard class indices.
        labeled_loss = F.cross_entropy(preds[:n_labeled], y[:n_labeled])
        unlabeled_loss = F.mse_loss(preds[n_labeled:], y[n_labeled:])
        return labeled_loss + (self.lambda_u * unlabeled_loss)


class Flatten(nn.Module):
    def forward(self, x): return x.view(x.size(0), -1)


model = nn.Sequential(
    nn.Conv2d(3, 2, 3, stride=1, padding=1),
    Flatten(),
    nn.Linear(2 * 32 * 32, 10),
)


class TestMismatch(unittest.TestCase):

    def test_mixup_torch(self):
        (X_labeled, y_labeled, X_unlabeled) = pickle_load('cifar_subset.pkl')
        ds = ArrayDataset(X_labeled[:12], y_labeled[:12], X_unlabeled[:12])
        BS = 4
        loader = MixupLoader(ds, batch_size=BS)
        loader.model = model
        loss_fn = MixMatchLoss()
        for xb, yb in loader:
            preds = F.softmax(model.forward(xb), dim=1)
            loss = loss_fn(preds, yb, BS // 2)
            print(loss)
            break
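
Note: because of the relative import (from .dataset import ...), the test has to run as part of a package; an illustrative invocation (the package name mixmatch is assumed, not from the commit):

# python -m unittest mixmatch.test_mixmatch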
