Showing 13 changed files with 699 additions and 0 deletions.
@@ -0,0 +1,165 @@
import os
import pickle
import numpy as np
import torch


# ========================================================
# Useful paths and constants

_cacheDir = "./cache"
_maxRuns = 10000
_min_examples = -1

# ========================================================
# Module-internal functions and variables

_randStates = None
_rsCfg = None


def _load_pickle(file):
    # Load a {class_label: [feature, ...]} pickle and stack it into 'data'/'labels' tensors
    with open(file, 'rb') as f:
        data = pickle.load(f)
        labels = [np.full(shape=len(data[key]), fill_value=key)
                  for key in data]
        data = [features for key in data for features in data[key]]
        dataset = dict()
        dataset['data'] = torch.FloatTensor(np.stack(data, axis=0))
        dataset['labels'] = torch.LongTensor(np.concatenate(labels))
        return dataset


# =========================================================
# Variables and functions callable from outside the module

data = None
labels = None
dsName = None

_datasetFeaturesFiles = {"miniImagenet": "./checkpoints/miniImagenet/novel_features.plk",
                         "CUB": "./checkpoints/CUB/novel_features.plk",
                         "tieredImagenet": "./checkpoints/tieredImagenet/novel_features.plk"}


def loadDataSet(dsname, datasetfeatfiles):
    # Load the pre-extracted features for 'dsname' and regroup them into a (classes, examples, dim) tensor
    if dsname not in datasetfeatfiles:
        raise NameError('Unknown dataset: {}'.format(dsname))

    global dsName, data, labels, _randStates, _rsCfg, _min_examples
    dsName = dsname
    _randStates = None
    _rsCfg = None

    # Loading data from files on disk
    dataset = _load_pickle(datasetfeatfiles[dsname])

    # Computing the minimum number of items available per class in the dataset
    _min_examples = dataset["labels"].shape[0]
    for i in range(dataset["labels"].shape[0]):
        if torch.where(dataset["labels"] == dataset["labels"][i])[0].shape[0] > 0:
            _min_examples = min(_min_examples, torch.where(
                dataset["labels"] == dataset["labels"][i])[0].shape[0])
    print("Guaranteed number of items per class: {:d}\n".format(_min_examples))

    # Generating data tensors
    data = torch.zeros((0, _min_examples, dataset["data"].shape[1]))
    labels = dataset["labels"].clone()
    while labels.shape[0] > 0:
        indices = torch.where(dataset["labels"] == labels[0])[0]
        data = torch.cat([data, dataset["data"][indices, :]
                          [:_min_examples].view(1, _min_examples, -1)], dim=0)
        indices = torch.where(labels != labels[0])[0]
        labels = labels[indices]
    print("Total of {:d} classes, {:d} elements each, with dimension {:d}\n".format(
        data.shape[0], data.shape[1], data.shape[2]))


def GenerateRun(iRun, cfg, regenRState=False, generate=True):
    # Sample one few-shot task (cfg['ways'] classes, cfg['shot'] + cfg['queries'] examples each) for run iRun
    global _randStates, data, _min_examples
    if not regenRState:
        np.random.set_state(_randStates[iRun])

    classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]]
    shuffle_indices = np.arange(_min_examples)
    dataset = None
    if generate:
        dataset = torch.zeros(
            (cfg['ways'], cfg['shot'] + cfg['queries'], data.shape[2]))
    for i in range(cfg['ways']):
        shuffle_indices = np.random.permutation(shuffle_indices)
        if generate:
            dataset[i] = data[classes[i], shuffle_indices,
                              :][:cfg['shot'] + cfg['queries']]

    return dataset


def ClassesInRun(iRun, cfg):
    # Return the class indices that run iRun draws from
    global _randStates, data
    np.random.set_state(_randStates[iRun])

    classes = np.random.permutation(np.arange(data.shape[0]))[:cfg["ways"]]
    return classes


def setRandomStates(cfg):
    # Build (or reload from cache) the NumPy random states so that runs are reproducible
    global _randStates, _maxRuns, _rsCfg
    if _rsCfg == cfg:
        return

    rsFile = os.path.join(_cacheDir, "RandStates_{}_s{}_q{}_w{}".format(
        dsName, cfg['shot'], cfg['queries'], cfg['ways']))
    if not os.path.exists(rsFile):
        print("{} does not exist, regenerating it...".format(rsFile))
        np.random.seed(0)
        _randStates = []
        for iRun in range(_maxRuns):
            _randStates.append(np.random.get_state())
            GenerateRun(iRun, cfg, regenRState=True, generate=False)
        torch.save(_randStates, rsFile)
    else:
        print("Reloading random states from file...")
        _randStates = torch.load(rsFile)
    _rsCfg = cfg


def GenerateRunSet(start=None, end=None, cfg=None):
    # Generate tasks for runs [start, end) as a tensor of shape (end - start, ways, shot + queries, dim)
    global dataset, _maxRuns
    if start is None:
        start = 0
    if end is None:
        end = _maxRuns
    if cfg is None:
        cfg = {"shot": 1, "ways": 5, "queries": 15}

    setRandomStates(cfg)
    print("Generating tasks from {} to {}".format(start, end))

    dataset = torch.zeros(
        (end - start, cfg['ways'], cfg['shot'] + cfg['queries'], data.shape[2]))
    for iRun in range(end - start):
        dataset[iRun] = GenerateRun(start + iRun, cfg)

    return dataset


# Small self-test when this module is run directly
if __name__ == "__main__":
    _datasetFeaturesFiles = {"miniImagenet": "./checkpoints/miniImagenet/novel_features.plk",
                             "CUB": "./checkpoints/CUB/novel_features.plk",
                             "tieredImagenet": "./checkpoints/tieredImagenet/novel_features.plk"}
    print("Testing Task loader for Few Shot Learning")
    loadDataSet('CUB', _datasetFeaturesFiles)

    cfg = {"shot": 1, "ways": 5, "queries": 15}
    setRandomStates(cfg)

    run10 = GenerateRun(10, cfg)
    print("First call:", run10[:2, :2, :2])

    run10 = GenerateRun(10, cfg)
    print("Second call:", run10[:2, :2, :2])

    ds = GenerateRunSet(start=2, end=12, cfg=cfg)
    print("Third call:", ds[8, :2, :2, :2])
    print(ds.size())
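A minimal usage sketch of this task loader (presumably saved as FSLTask.py, given the `import FSLTask` in the evaluation script below). The support/query split shown is a downstream convention rather than something the loader enforces, and the ./cache directory must already exist for the random-state cache:

import FSLTask

cfg = {"shot": 5, "ways": 5, "queries": 15}
FSLTask.loadDataSet("miniImagenet", FSLTask._datasetFeaturesFiles)
runs = FSLTask.GenerateRunSet(start=0, end=100, cfg=cfg)  # (100, ways, shot + queries, dim)
support = runs[:, :, :cfg["shot"], :]   # conventional split: first 'shot' examples per class
queries = runs[:, :, cfg["shot"]:, :]   # remaining 'queries' examples per class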
@@ -0,0 +1,160 @@
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from tqdm import tqdm
import os
from sample_method.WideSearch import WS
from sample_method.DeepSearch import DS
import FSLTask
import random
import pickle
import configargparse
from Proto_train import proto_train
from sklearn.neural_network import MLPClassifier

use_gpu = torch.cuda.is_available()

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def sample_case(ld_dict, shot, way=5, num_qry=15):
    # Sample a meta task: 'way' classes with 'shot' support and 'num_qry' query examples each
    sample_class = random.sample(list(ld_dict.keys()), way)
    train_input = []
    test_input = []
    test_label = []
    train_label = []
    for each_class in sample_class:
        total_samples = shot + num_qry
        if len(ld_dict[each_class]) < total_samples:
            total_samples = len(ld_dict[each_class])

        samples = random.sample(ld_dict[each_class], total_samples)
        train_label += [each_class] * len(samples[:shot])
        test_label += [each_class] * len(samples[shot:])
        train_input += samples[:shot]
        test_input += samples[shot:]
    train_input = np.array(train_input).astype(np.float32)
    test_input = np.array(test_input).astype(np.float32)
    return train_input, test_input, train_label, test_label


def main(args):
    # ---- data loading
    beta = 0.5
    if_tukey_transform = True
    if_sample = True

    _datasetFeaturesFiles = "../checkpoints/{}_{}/features.plk".format(args.dataset, args.backbone)

    with open(_datasetFeaturesFiles, 'rb') as f:
        myfeatures = pickle.load(f)

    novel_feature = myfeatures[2]
    # base_feature = myfeatures[4]

    # ---- classification for each task
    lr_acc_list, svm_acc_list, nn_acc_list, proto_acc_list = [], [], [], []

    for i in tqdm(range(args.n_runs)):  # one few-shot task per run

        support_data, query_data, support_label, query_label = \
            sample_case(ld_dict=novel_feature, shot=args.shots, way=args.ways, num_qry=args.n_queries)

        support_label = np.array(support_label).reshape((args.ways, -1)).T.reshape(-1)
        support_data = np.array(support_data).reshape((args.ways, args.shots, -1)).transpose(1, 0, 2).reshape(args.ways * args.shots, -1)

        query_label = np.array(query_label).reshape((args.ways, -1)).T.reshape(-1)
        query_data = np.array(query_data).reshape((args.ways, args.n_queries, -1)).transpose(1, 0, 2).reshape(args.ways * args.n_queries, -1)

        # ---- Tukey's transform
        if if_tukey_transform:
            support_data = np.power(support_data, beta)
            query_data = np.power(query_data, beta)

        # ---- feature sampling
        if if_sample:
            if args.method == 'WS' or args.method == 'Prototype':
                # sample additional training data with Wide Search
                sampled_data, sampled_label = WS(args, support_data, support_label, query_data)
            if args.method == 'DS':
                # sample additional training data with Deep Search
                sampled_data, sampled_label = DS(args, support_data, support_label, query_data)

            X_aug = np.concatenate([support_data, sampled_data])
            Y_aug = np.concatenate([support_label, sampled_label])

        else:
            X_aug, Y_aug = support_data, support_label

        if args.method == 'Prototype':
            proto_test_acc = proto_train(
                [X_aug, Y_aug, query_data, query_label],
                args.ways, args.shots, args.num_latent, args.topk)
            proto_acc_list.append(proto_test_acc.item())

        else:
            # ---- LR classifier
            LRclassifier = LogisticRegression(max_iter=1000).fit(X=X_aug, y=Y_aug)
            predicts = LRclassifier.predict(query_data)
            acc = np.mean(predicts == query_label)
            lr_acc_list.append(acc)

            # ---- SVM classifier
            SVMclassifier = SVC(max_iter=1000).fit(X=X_aug, y=Y_aug)
            predicts = SVMclassifier.predict(query_data)
            acc = np.mean(predicts == query_label)
            svm_acc_list.append(acc)

            # ---- NN classifier
            NNclassifier = MLPClassifier(random_state=123, max_iter=500, hidden_layer_sizes=(128, 64)).fit(X=X_aug, y=Y_aug)
            predicts = NNclassifier.predict(query_data)
            acc = np.mean(predicts == query_label)
            nn_acc_list.append(acc)

    if args.method == 'Prototype':
        return float(np.mean(proto_acc_list)), 1.96 * np.std(proto_acc_list) / np.sqrt(args.n_runs)
    else:
        return float(np.mean(lr_acc_list)), float(np.mean(svm_acc_list)), float(np.mean(nn_acc_list)), \
            1.96 * np.std(lr_acc_list) / np.sqrt(args.n_runs), 1.96 * np.std(svm_acc_list) / np.sqrt(args.n_runs), 1.96 * np.std(nn_acc_list) / np.sqrt(args.n_runs)


if __name__ == '__main__':

    parser = configargparse.ArgParser(description='Learning2Capture')
    parser.add_argument('--dataset', type=str, default='mini', help='mini/tiered/cub')
    parser.add_argument('--method', type=str, default="Prototype", help='DS/WS/Prototype')
    parser.add_argument('--backbone', type=str, default='res18', help='res18/wrn/conv4')
    parser.add_argument('--ways', type=int, default=5, help='N-way K-shot task setup')
    parser.add_argument('--shots', type=int, default=1, help='N-way K-shot task setup {1/5}')
    parser.add_argument('--topk', type=int, default=10, help='top-k selection: DS=3 / WS=10 / Prototype=10')
    parser.add_argument('--num_latent', type=int, default=1, help='number of generated samples, default=1')
    parser.add_argument('--n_queries', type=int, default=15, help='number of query samples')
    parser.add_argument('--n_runs', type=int, default=600, help='number of evaluation runs (tasks)')
    args = parser.parse_args()

    print("----------{}-{}-{}W{}S-{}----------".format(
        args.dataset, args.backbone, args.ways, args.shots, args.method))

    if args.method == 'Prototype':
        proto_acc, proto_ci95 = main(args)
        print('Prototype-based Classifier: {:.4f}({:.4f})'.format(proto_acc * 100, proto_ci95 * 100))
    else:
        lr_acc, svm_acc, nn_acc, lr_ci95, svm_ci95, nn_ci95 = main(args)
        print('LR-based Classifier: {:.4f}({:.4f})'.format(lr_acc * 100, lr_ci95 * 100))
        print('SVM-based Classifier: {:.4f}({:.4f})'.format(svm_acc * 100, svm_ci95 * 100))
        print('NN-based Classifier: {:.4f}({:.4f})'.format(nn_acc * 100, nn_ci95 * 100))
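The parenthesised number printed next to each accuracy is the half-width of a 95% confidence interval over the args.n_runs tasks, computed inline above as 1.96 * std / sqrt(n_runs). A small equivalent helper (the name ci95 is hypothetical, not part of the script):

import numpy as np

def ci95(acc_list):
    # Half-width of a 95% confidence interval for the mean of acc_list
    return 1.96 * np.std(acc_list) / np.sqrt(len(acc_list))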
@@ -0,0 +1,23 @@
from torch.utils.data import DataLoader, TensorDataset
import torch
from learn2capture.prototype_loss import PrototypicalLoss


def proto_train(data, n_ways, n_shot, KL, topK):
    # Evaluate a prototype-based classifier on one task from augmented support data and query data
    [X_aug, Y_aug, query_data, query_label] = data
    X_aug, Y_aug, query_data, query_label = \
        torch.from_numpy(X_aug), torch.from_numpy(Y_aug), torch.from_numpy(query_data), torch.from_numpy(query_label)

    X_test, Y_test = torch.cat((X_aug, query_data), dim=0), torch.cat((Y_aug, query_label), dim=0)
    test_loader = DataLoader(TensorDataset(X_test, Y_test.long()), batch_size=X_test.size(0), shuffle=False)

    proto_loss = PrototypicalLoss(n_support=n_ways * n_shot + n_ways * topK * KL)

    # test on the query set (a single batch holding the whole task)
    for i, (test_x, test_y) in enumerate(test_loader):
        test_x, test_y = test_x.cuda().float(), test_y.cuda().float()

        loss, query_acc = proto_loss(test_x, test_y.long())

    return query_acc
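For orientation, the n_support handed to PrototypicalLoss counts the real support shots plus the features generated by the sampling step; whether the loss interprets it per class or per task depends on the learn2capture.prototype_loss implementation, which is not shown here. A hypothetical worked example using the evaluation script's default arguments:

n_ways, n_shot = 5, 1    # --ways, --shots
KL, topK = 1, 10         # --num_latent, --topk
n_support = n_ways * n_shot + n_ways * topK * KL  # 5 + 50 = 55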
@@ -0,0 +1,4 @@
save_dir = '..'
data_dir = {}
data_dir['CUB'] = './filelists/CUB/'
data_dir['miniImagenet'] = './filelists/miniImagenet/'