Commit: Add files via upload
HiHippie authored Oct 19, 2021
1 parent 4c79ed3 commit 318a995
Showing 13 changed files with 699 additions and 0 deletions.
165 changes: 165 additions & 0 deletions FSLTask.py
@@ -0,0 +1,165 @@
import os
import pickle
import numpy as np
import torch

# ========================================================
# Useful paths and module-level constants

_cacheDir = "./cache"
_maxRuns = 10000
_min_examples = -1

# ========================================================
# Module internal functions and variables

_randStates = None
_rsCfg = None


def _load_pickle(file):
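    # The pickle file stores a dict mapping each class label to a list of feature
    # vectors; flatten it into one 'data' tensor with a matching 'labels' tensor.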
with open(file, 'rb') as f:
data = pickle.load(f)
labels = [np.full(shape=len(data[key]), fill_value=key)
for key in data]
data = [features for key in data for features in data[key]]
dataset = dict()
dataset['data'] = torch.FloatTensor(np.stack(data, axis=0))
dataset['labels'] = torch.LongTensor(np.concatenate(labels))
return dataset


# =========================================================
# Callable variables and functions from outside the module

data = None
labels = None
dsName = None

_datasetFeaturesFiles = {"miniImagenet": "./checkpoints/miniImagenet/novel_features.plk",
"CUB": "./checkpoints/CUB/novel_features.plk",
'tieredImagenet': "./checkpoints/tieredImagenet/novel_features.plk"}

def loadDataSet(dsname, datasetfeatfiles):
if dsname not in datasetfeatfiles:
        raise NameError('Unknown dataset: {}'.format(dsname))

global dsName, data, labels, _randStates, _rsCfg, _min_examples
dsName = dsname
_randStates = None
_rsCfg = None

# Loading data from files on computer
# home = expanduser("~")
dataset = _load_pickle(datasetfeatfiles[dsname])

# Computing the number of items per class in the dataset
_min_examples = dataset["labels"].shape[0]
for i in range(dataset["labels"].shape[0]):
if torch.where(dataset["labels"] == dataset["labels"][i])[0].shape[0] > 0:
_min_examples = min(_min_examples, torch.where(
dataset["labels"] == dataset["labels"][i])[0].shape[0])
print("Guaranteed number of items per class: {:d}\n".format(_min_examples))

# Generating data tensors
data = torch.zeros((0, _min_examples, dataset["data"].shape[1]))
labels = dataset["labels"].clone()
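    # Group features by class: each pass extracts the first _min_examples features
    # of one remaining class, then drops that class from the working label list.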
while labels.shape[0] > 0:
indices = torch.where(dataset["labels"] == labels[0])[0]
data = torch.cat([data, dataset["data"][indices, :]
[:_min_examples].view(1, _min_examples, -1)], dim=0)
indices = torch.where(labels != labels[0])[0]
labels = labels[indices]
print("Total of {:d} classes, {:d} elements each, with dimension {:d}\n".format(
data.shape[0], data.shape[1], data.shape[2]))


def GenerateRun(iRun, cfg, regenRState=False, generate=True):
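    # Restore the cached NumPy random state for run iRun (unless the states are
    # being generated), then sample cfg['ways'] classes and
    # cfg['shot'] + cfg['queries'] shuffled examples per class for this task.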
global _randStates, data, _min_examples
if not regenRState:
np.random.set_state(_randStates[iRun])

classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]]
shuffle_indices = np.arange(_min_examples)
dataset = None
if generate:
dataset = torch.zeros(
(cfg['ways'], cfg['shot']+cfg['queries'], data.shape[2]))
for i in range(cfg['ways']):
shuffle_indices = np.random.permutation(shuffle_indices)
if generate:
dataset[i] = data[classes[i], shuffle_indices,
:][:cfg['shot']+cfg['queries']]

return dataset


def ClassesInRun(iRun, cfg):
global _randStates, data
np.random.set_state(_randStates[iRun])

classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]]
return classes


def setRandomStates(cfg):
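    # Cache one NumPy random state per run, keyed by dataset and task shape, so the
    # exact same tasks can be regenerated across executions; states are saved on disk.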
global _randStates, _maxRuns, _rsCfg
if _rsCfg == cfg:
return

    rsFile = os.path.join(_cacheDir, "RandStates_{}_s{}_q{}_w{}".format(
        dsName, cfg['shot'], cfg['queries'], cfg['ways']))
    os.makedirs(_cacheDir, exist_ok=True)  # torch.save below fails if the cache directory is missing
if not os.path.exists(rsFile):
print("{} does not exist, regenerating it...".format(rsFile))
np.random.seed(0)
_randStates = []
for iRun in range(_maxRuns):
_randStates.append(np.random.get_state())
GenerateRun(iRun, cfg, regenRState=True, generate=False)
torch.save(_randStates, rsFile)
else:
print("reloading random states from file....")
_randStates = torch.load(rsFile)
_rsCfg = cfg


def GenerateRunSet(start=None, end=None, cfg=None):
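    # Build a tensor of shape (runs, ways, shot + queries, feature_dim) by stacking
    # the individual tasks produced by GenerateRun.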
global dataset, _maxRuns
if start is None:
start = 0
if end is None:
end = _maxRuns
if cfg is None:
cfg = {"shot": 1, "ways": 5, "queries": 15}

setRandomStates(cfg)
print("generating task from {} to {}".format(start, end))

dataset = torch.zeros(
(end-start, cfg['ways'], cfg['shot']+cfg['queries'], data.shape[2]))
for iRun in range(end-start):
dataset[iRun] = GenerateRun(start+iRun, cfg)

return dataset


# Quick self-test when the module is run as a script
if __name__ == "__main__":
_datasetFeaturesFiles = {"miniImagenet": "./checkpoints/miniImagenet/novel_features.plk",
"CUB": "./checkpoints/CUB/novel_features.plk",
'tieredImagenet': "./checkpoints/tieredImagenet/novel_features.plk"}
print("Testing Task loader for Few Shot Learning")
loadDataSet('CUB',_datasetFeaturesFiles)

cfg = {"shot": 1, "ways": 5, "queries": 15}
setRandomStates(cfg)

run10 = GenerateRun(10, cfg)
print("First call:", run10[:2, :2, :2])

run10 = GenerateRun(10, cfg)
print("Second call:", run10[:2, :2, :2])

ds = GenerateRunSet(start=2, end=12, cfg=cfg)
print("Third call:", ds[8, :2, :2, :2])
print(ds.size())
160 changes: 160 additions & 0 deletions Learning2Capture.py
@@ -0,0 +1,160 @@
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from tqdm import tqdm
import os
from sample_method.WideSearch import WS
from sample_method.DeepSearch import DS
import FSLTask
import random
import pickle
import configargparse
from Proto_train import proto_train
from sklearn.neural_network import MLPClassifier

use_gpu = torch.cuda.is_available()

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

def sample_case(ld_dict, shot, way=5, num_qry=15):
    # Sample one meta task: 'way' classes with 'shot' support and up to 'num_qry' query examples each
sample_class = random.sample(list(ld_dict.keys()), way)
train_input = []
test_input = []
test_label = []
train_label = []
for each_class in sample_class:
total_samples = shot + num_qry
if len(ld_dict[each_class]) < total_samples:
total_samples = len(ld_dict[each_class])

samples = random.sample(ld_dict[each_class], total_samples)
train_label += [each_class] * len(samples[:shot])
test_label += [each_class] * len(samples[shot:])
train_input += samples[:shot]
test_input += samples[shot:]
train_input = np.array(train_input).astype(np.float32)
test_input = np.array(test_input).astype(np.float32)
return train_input, test_input, train_label, test_label


def main(args):
# ---- data loading

beta = 0.5
if_tukey_transform = True
if_sample = True


_datasetFeaturesFiles = "../checkpoints/{}_{}/features.plk".format(args.dataset, args.backbone)


with open(_datasetFeaturesFiles, 'rb') as f:
myfeatures = pickle.load(f)

novel_feature = myfeatures[2]
# base_feature = myfeatures[4]


# ---- classification for each task
lr_acc_list, svm_acc_list, nn_acc_list, proto_acc_list = [], [], [], []

for i in tqdm(range(args.n_runs)): # ndatas: (n_runs, n_samples, dimension)

support_data, query_data, support_label, query_label = \
sample_case(ld_dict=novel_feature, shot=args.shots,way=args.ways, num_qry=args.n_queries)

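        # Reorder samples so they are grouped by shot/query index first, then class:
        # reshape to (ways, k, dim), transpose to (k, ways, dim), and flatten back.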
support_label = np.array(support_label).reshape((args.ways, -1)).T.reshape(-1)
support_data = np.array(support_data).reshape((args.ways, args.shots, -1)).transpose(1, 0, 2).reshape(args.ways * args.shots, -1)

query_label = np.array(query_label).reshape((args.ways, -1)).T.reshape(-1)
query_data = np.array(query_data).reshape((args.ways, args.n_queries, -1)).transpose(1, 0, 2).reshape(args.ways * args.n_queries, -1)


        # ---- Tukey's ladder-of-powers transform (x -> x**beta) to reduce feature skew
        if if_tukey_transform:
            support_data = np.power(support_data, beta)
            query_data = np.power(query_data, beta)

# ---- feature sampling
if if_sample:
            if args.method in ('WS', 'Prototype'):
                # augment the support set with samples selected by wide search
                sampled_data, sampled_label = WS(args, support_data, support_label, query_data)
            elif args.method == 'DS':
                # augment the support set with samples selected by deep search
                sampled_data, sampled_label = DS(args, support_data, support_label, query_data)

X_aug = np.concatenate([support_data, sampled_data])
Y_aug = np.concatenate([support_label, sampled_label])

else:
X_aug, Y_aug = support_data, support_label

if args.method == 'Prototype':

proto_test_acc = proto_train(
[X_aug, Y_aug, query_data, query_label],
args.ways, args.shots, args.num_latent, args.topk)
proto_acc_list.append(proto_test_acc.item())
# print('【Prototype】【%d/%d】%s %d way %d shot ACC : %f' % (
# i, n_runs, dataset, n_ways, n_shot, float(np.mean(proto_acc_list))))

else:

# ---- LR train classifier
LRclassifier = LogisticRegression(max_iter=1000).fit(X=X_aug, y=Y_aug)
predicts = LRclassifier.predict(query_data)
acc = np.mean(predicts == query_label)
# print('【LR】【%d/%d】%s %d way %d shot ACC : %f' % (
# i, n_runs, dataset, n_ways, n_shot, float(np.mean(lr_acc_list))))
lr_acc_list.append(acc)

# ---- SVM train classifier
SVMclassifier = SVC(max_iter=1000).fit(X=X_aug, y=Y_aug)
predicts = SVMclassifier.predict(query_data)
acc = np.mean(predicts == query_label)
# print('【SVM】【%d/%d】%s %d way %d shot ACC : %f' % (
# i, n_runs, dataset, n_ways, n_shot, float(np.mean(svm_acc_list))))
svm_acc_list.append(acc)

            # ---- NN train classifier
NNclassifier = MLPClassifier(random_state=123, max_iter=500, hidden_layer_sizes=(128, 64)).fit(X=X_aug, y=Y_aug)
predicts = NNclassifier.predict(query_data)
acc = np.mean(predicts == query_label)
nn_acc_list.append(acc)
if args.method == 'Prototype':
return float(np.mean(proto_acc_list)), 1.96*np.std(proto_acc_list)/np.sqrt(args.n_runs)
else:
return float(np.mean(lr_acc_list)), float(np.mean(svm_acc_list)), float(np.mean(nn_acc_list)), \
1.96*np.std(lr_acc_list)/np.sqrt(args.n_runs), 1.96*np.std(svm_acc_list)/np.sqrt(args.n_runs), 1.96*np.std(nn_acc_list)/np.sqrt(args.n_runs)


if __name__ == '__main__':

parser = configargparse.ArgParser(description='Learning2Capture')
parser.add_argument('--dataset', type=str, default='mini', help='mini/tiered/cub')
parser.add_argument('--method', type=str, default="Prototype", help='DS/WS/Prototype')
parser.add_argument('--backbone', type=str, default='res18', help='res18/wrn/conv4')
parser.add_argument('--ways', type=int, default=5, help='N-way K-shot task setup')
parser.add_argument('--shots', type=int, default=1, help='N-way K-shot task setup {1/5}')
parser.add_argument('--topk', type=int, default=10, help='topk selection in DS=3/WS=10/prototype=10')
parser.add_argument('--num_latent', type=int, default=1, help='number of generated samples default=1')
parser.add_argument('--n_queries', type=int, default=15, help='number of query samples')
    parser.add_argument('--n_runs', type=int, default=600, help='number of evaluation runs (tasks)')
args = parser.parse_args()

print("----------{}-{}-{}W{}S-{}----------".format
(args.dataset, args.backbone, args.ways, args.shots, args.method))
best_acc = 0.

if args.method == 'Prototype':
proto_acc, proto_ci95 = main(args)
print('Prototype-based Classifier: {:.4f}({:.4f})'.format(proto_acc*100, proto_ci95*100))

else:
lr_acc, svm_acc, nn_acc, lr_ci95, svm_ci95, nn_ci95 = main(args)
print('LR-based Classifier: {:.4f}({:.4f})'.format(lr_acc*100, lr_ci95*100))
print('SVM-based Classifier: {:.4f}({:.4f})'.format(svm_acc*100, svm_ci95*100))
print('NN-based Classifier: {:.4f}({:.4f})'.format(nn_acc*100, nn_ci95*100))
23 changes: 23 additions & 0 deletions Proto_train.py
@@ -0,0 +1,23 @@
from torch.utils.data import DataLoader, TensorDataset
import torch
from learn2capture.prototype_loss import PrototypicalLoss


def proto_train(data, n_ways, n_shot, KL, topK):
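    # Evaluate a prototypical-network classifier: the augmented support set (original
    # shots plus generated samples) and the query set are concatenated into a single
    # batch, and PrototypicalLoss splits them again via its n_support argument.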

[X_aug, Y_aug, query_data, query_label] = data
X_aug, Y_aug, query_data, query_label = \
torch.from_numpy(X_aug), torch.from_numpy(Y_aug), torch.from_numpy(query_data), torch.from_numpy(query_label)

X_test, Y_test = torch.cat((X_aug, query_data), dim=0), torch.cat((Y_aug, query_label), dim=0)
test_loader = DataLoader(TensorDataset(X_test, Y_test.long()), batch_size=X_test.size(0), shuffle=False)

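    # Support size = original support (n_ways * n_shot) plus generated samples
    # (n_ways * topK * KL); everything beyond that index is treated as query.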
proto_loss = PrototypicalLoss(n_support=n_ways*n_shot+n_ways*topK*KL)

# test on query set
for i, (test_x, test_y) in enumerate(test_loader):
test_x, test_y = test_x.cuda().float(), test_y.cuda().float()

loss, query_acc = proto_loss(test_x, test_y.long())

return query_acc
4 changes: 4 additions & 0 deletions configs.py
@@ -0,0 +1,4 @@
save_dir = '..'
data_dir = {}
data_dir['CUB'] = './filelists/CUB/'
data_dir['miniImagenet'] = './filelists/miniImagenet/'