Commit 2c4f0a5

Author: previtus (committed)

Added (rather simplistic) k-fold crossval to see how that goes

1 parent 53ab69e, commit 2c4f0a5

6 files changed: +138 -36 lines

Dataset.py  (+17 -3)

@@ -1,5 +1,7 @@
 import DataLoader, DataPreprocesser, Debugger
 import DatasetInstance_OurAerial, DatasetInstance_ONERA
+import numpy as np
+
 
 class Dataset(object):
     """
@@ -42,12 +44,24 @@ def init_from_stable_datasets(self):
         print("Dataset loaded with", len(self.data[0]), "images.")
 
         # Shuffle
-        self.data = self.shuffle_thyself(self.data)
+        #self.data = self.shuffle_thyself(self.data)
 
         # Split into training, validation and test:
-        self.train, self.val, self.test = self.datasetInstance.split_train_val_test(self.data)
-        self.train_paths, self.val_paths, self.test_paths = self.datasetInstance.split_train_val_test(self.paths)
+
+        K = self.settings.TestDataset_K_Folds
+        test_fold = self.settings.TestDataset_Fold_Index
+        print("K-Fold crossval: [",test_fold,"from",K,"]")
+        self.train, self.val, self.test = self.datasetInstance.split_train_val_test_KFOLDCROSSVAL(self.data, test_fold=test_fold, K=K)
+        self.paths = np.asarray(self.paths)
+        self.train_paths, self.val_paths, self.test_paths = self.datasetInstance.split_train_val_test_KFOLDCROSSVAL(self.paths, test_fold=test_fold, K=K)
+
         print("Has ", len(self.train[0]), "train, ", len(self.val[0]), "val, ", len(self.test[0]), "test, ")
+        #print("Has ", len(self.train_paths[0]), "train_paths, ", len(self.val_paths[0]), "val_paths, ", len(self.test_paths[0]), "test_paths, ")
+
+        #print("Revert...")
+        #self.train, self.val, self.test = self.datasetInstance.split_train_val_test(self.data)
+        #self.train_paths, self.val_paths, self.test_paths = self.datasetInstance.split_train_val_test(self.paths)
+        #print("Has ", len(self.train[0]), "train, ", len(self.val[0]), "val, ", len(self.test[0]), "test, ")
 
         # preprocess the dataset
         self.train, self.val, self.test = self.dataPreprocesser.process_dataset(self.train, self.val, self.test)
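With shuffling disabled, the folds become contiguous, deterministic slices of the dataset, and the fold index plus K come straight from the settings object; in this scheme the validation set is simply the held-out test fold. A minimal standalone sketch (not the project's code; N, K and test_fold below are illustrative) of how a fold index maps onto sample index ranges:

def kfold_slices(N, K):
    # contiguous folds over an unshuffled dataset, same int(N / K) step as the commit;
    # any remainder when N is not divisible by K is left out of every fold
    jump_by = int(N / K)
    return [(fold * jump_by, min(fold * jump_by + jump_by, N)) for fold in range(K)]

N, K, test_fold = 1000, 10, 3   # illustrative sizes
for fold, (start, end) in enumerate(kfold_slices(N, K)):
    role = "test/val" if fold == test_fold else "train"
    print("fold", fold, "covers samples [", start, ":", end, "] ->", role)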

DatasetInstance_OurAerial.py  (+49)

@@ -147,6 +147,55 @@ def __init__(self, settings, dataLoader, variant = "256"):
         self.DEBUG_TURN_OFF_BALANCING = True
 
 
+
+    def split_train_val_test_KFOLDCROSSVAL(self, data, test_fold = 0, K = 4):
+        lefts, rights, labels = data
+
+        # we want the held-out fold to move across the dataset (the rest stays as train;
+        # a separate val split would also be possible, but here it can be empty)
+
+        # split [0 - end] into K folds, use one as the test set and the rest as the train set
+        N = len(lefts)
+        jump_by = int(N / K)
+
+        test_L = np.empty(((0,) + lefts.shape[1:]), lefts.dtype)
+        train_L = np.empty(((0,) + lefts.shape[1:]), lefts.dtype)
+        test_R = np.empty(((0,) + rights.shape[1:]), rights.dtype)
+        train_R = np.empty(((0,) + rights.shape[1:]), rights.dtype)
+        test_V = np.empty(((0,) + labels.shape[1:]), labels.dtype)
+        train_V = np.empty(((0,) + labels.shape[1:]), labels.dtype)
+
+        data_start = 0
+        for fold_index in range(K):
+            data_until = data_start + jump_by
+            if data_until > N:
+                data_until = N
+
+            fold_L = lefts[data_start:data_until]
+            fold_R = rights[data_start:data_until]
+            fold_V = labels[data_start:data_until]
+
+            #print("fold_L.shape", fold_L.shape)
+
+            if fold_index == test_fold:
+                # add to the test set
+                test_L = np.append(test_L, fold_L, 0)
+                test_R = np.append(test_R, fold_R, 0)
+                test_V = np.append(test_V, fold_V, 0)
+            else:
+                # add to the train set
+                train_L = np.append(train_L, fold_L, 0)
+                train_R = np.append(train_R, fold_R, 0)
+                train_V = np.append(train_V, fold_V, 0)
+
+            data_start += jump_by
+
+        train = [train_L, train_R, train_V]
+        test = [test_L, test_R, test_V]
+        val = test  # hmmm
+
+        return train, val, test
+
+
     def split_train_val_test(self, data):
         lefts, rights, labels = data
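The new method builds the train and test splits by appending each fold slice to pre-allocated empty arrays. A more compact equivalent sketch using index arrays is shown below; it is not the committed code, and np.array_split distributes any remainder across folds instead of dropping the trailing N % K samples, so the two only agree exactly when N is divisible by K.

import numpy as np

def split_kfold(data, test_fold=0, K=4):
    # data is the same (lefts, rights, labels) triple the method above expects
    lefts, rights, labels = data
    folds = np.array_split(np.arange(len(lefts)), K)
    test_idx = folds[test_fold]
    train_idx = np.concatenate([f for i, f in enumerate(folds) if i != test_fold])
    train = [lefts[train_idx], rights[train_idx], labels[train_idx]]
    test = [lefts[test_idx], rights[test_idx], labels[test_idx]]
    return train, test, test   # val is reused as the test fold, mirroring the commit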

Evaluator.py  (+38 -13)

@@ -43,24 +43,27 @@ def try_all_thresholds(self, predicted, labels, range_values = [0.0, 0.5, 1.0],
         ys_recalls = []
         ys_precisions = []
         ys_accuracies = []
+        ys_f1s = []
         for thr in range_values: #np.arange(0.0,1.0,0.01):
             xs.append(thr)
             print("threshold=",thr)
             #_, recall, precision, accuracy = self.calculate_metrics(predicted, labels, threshold=thr)
             if "NoChange" in title_txt:
                 print("from the position of NoChange class instead...")
-                recall, precision, accuracy, f1 = self.calculate_recall_precision_accuracy_NOCHANGECLASS(predicted, labels, threshold=thr)
+                recall, precision, accuracy, f1 = self.calculate_recall_precision_accuracy_NOCHANGECLASS(predicted, labels, threshold=thr, need_f1=True)
             else:
-                recall, precision, accuracy, f1 = self.calculate_recall_precision_accuracy(predicted, labels, threshold=thr)
+                recall, precision, accuracy, f1 = self.calculate_recall_precision_accuracy(predicted, labels, threshold=thr, need_f1=True)
 
             ys_recalls.append(recall)
             ys_precisions.append(precision)
             ys_accuracies.append(accuracy)
+            ys_f1s.append(f1)
 
         print("xs", len(xs), xs)
         print("ys_recalls", len(ys_recalls), ys_recalls)
         print("ys_precisions", len(ys_precisions), ys_precisions)
         print("ys_accuracies", len(ys_accuracies), ys_accuracies)
+        print("ys_f1s", len(ys_f1s), ys_f1s)
 
         if title_txt == "":
             plt.title('Changing the threshold values')
@@ -72,6 +75,7 @@ def try_all_thresholds(self, predicted, labels, range_values = [0.0, 0.5, 1.0],
         plt.plot(xs, ys_recalls, '-o', label="Recall")
         plt.plot(xs, ys_precisions, '-o', label="Precision")
         plt.plot(xs, ys_accuracies, '-o', label="Accuracy")
+        plt.plot(xs, ys_f1s, '-o', label="f1")
         plt.legend()
 
         plt.ylim(0.0, 1.0)
@@ -83,6 +87,25 @@ def try_all_thresholds(self, predicted, labels, range_values = [0.0, 0.5, 1.0],
         if show:
             plt.show()
 
+        plt.close()
+
+    def calculate_f1(self, predictions, ground_truths, threshold = 0.5):
+        if len(predictions.shape) > 1:
+            predictions_copy = np.array(predictions)
+        else:
+            predictions_copy = np.array([predictions])
+
+        for image in predictions_copy:
+            image[image >= threshold] = 1
+            image[image < threshold] = 0
+
+        arr_predictions = predictions_copy.flatten()
+        arr_gts = ground_truths.flatten()
+
+        sklearn_f1 = sklearn.metrics.f1_score(arr_gts, arr_predictions)
+
+        return sklearn_f1
+
     def calculate_recall_precision_accuracy(self, predictions, ground_truths, threshold = 0.5, need_f1=False):
         if len(predictions.shape) > 1:
             predictions_copy = np.array(predictions)
@@ -303,8 +326,9 @@ def calculate_metrics_fast(self, predictions, ground_truths, threshold = 0.5, ve
         return predictions_thresholded, recall, precision, accuracy
 
     # select thr which maximizes the f1 score
-    def metrics_autothr_f1_max(self, predictions, ground_truths, verbose=2):
-        range_values = np.arange(0.0, 1.0, 0.01)
+    def metrics_autothr_f1_max(self, predictions, ground_truths, jump_by = 0.1):
+        # force it to select something 'sensible' for the threshold ...
+        range_values = np.arange(0.1, 0.9, jump_by)
 
         xs = []
         ys_recalls = []
@@ -313,17 +337,18 @@ def metrics_autothr_f1_max(self, predictions, ground_truths, verbose=2):
         ys_f1s = []
         for thr in range_values:
             xs.append(thr)
-            print("threshold=", thr)
-
-            recall, precision, accuracy, f1 = self.calculate_recall_precision_accuracy(predictions, ground_truths, threshold=thr)
+            print("auto threshold=", thr)
 
-            ys_recalls.append(recall)
-            ys_precisions.append(precision)
-            ys_accuracies.append(accuracy)
+            f1 = self.calculate_f1(predictions, ground_truths, threshold=thr)
             ys_f1s.append(f1)
 
         max_f1_idx = np.argmax(ys_f1s)
-        selected_thr = xs[max_f1_idx]
+        best_thr = xs[max_f1_idx]
+
+        selected_recall, selected_precision, selected_accuracy, _ = self.calculate_recall_precision_accuracy(predictions, ground_truths, threshold=best_thr, need_f1=False)
+        selected_f1 = ys_f1s[max_f1_idx]
+
+        print("Selecting threshold as", best_thr, "as it maximizes the f1 score getting", selected_f1,
+              "(other scores are: recall", selected_recall, ", precision", selected_precision, ", acc", selected_accuracy, ")")
 
-        print("Selecting threshold as", selected_thr, "as it maximizes the f1 score getting", ys_f1s[max_f1_idx],
-              "(other scores are: recall", ys_recalls[max_f1_idx], ", precision", ys_precisions[max_f1_idx], ", acc", ys_accuracies[max_f1_idx], ")")
+        return best_thr, selected_recall, selected_precision, selected_accuracy, selected_f1
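The new calculate_f1 binarizes the predicted probability maps at the given threshold and hands the flattened arrays to sklearn.metrics.f1_score, and metrics_autothr_f1_max now simply picks the threshold in [0.1, 0.9) with the highest F1. A self-contained sketch of that sweep on synthetic data (independent of the Evaluator class; array shapes and random inputs are illustrative):

import numpy as np
import sklearn.metrics

def best_f1_threshold(probabilities, ground_truths, jump_by=0.1):
    # sweep the same 'sensible' range as the commit and return the argmax-F1 threshold
    gts = ground_truths.flatten()
    thresholds = np.arange(0.1, 0.9, jump_by)
    f1s = [sklearn.metrics.f1_score(gts, (probabilities.flatten() >= thr).astype(int))
           for thr in thresholds]
    best = int(np.argmax(f1s))
    return thresholds[best], f1s[best]

probabilities = np.random.rand(4, 256, 256)                      # fake per-pixel change scores
ground_truths = (np.random.rand(4, 256, 256) > 0.9).astype(int)  # fake binary change masks
thr, f1 = best_f1_threshold(probabilities, ground_truths)
print("best threshold", thr, "with f1", f1)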

Model2_SiamUnet_Encoder.py  (+16 -17)

@@ -63,7 +63,6 @@ def __init__(self, settings, dataset):
         BACKBONE = 'resnet34'
         BACKBONE = 'resnet50' #batch 16
         #BACKBONE = 'resnet101' #batch 8
-        #BACKBONE = 'seresnext50' #trying batch 16 as well
         custom_weights_file = "imagenet"
 
         #weights from imagenet finetuned on aerial data specific task - will it work? will it break?
@@ -105,8 +104,6 @@ def train(self, show=True, save=False):
         print("label images (train)")
         self.debugger.explore_set_stats(train_V)
 
-        added_plots = []
-
         from albumentations.core.transforms_interface import DualTransform
         class RandomRotate90x1(DualTransform):
             def apply(self, img, factor=0, **params):
@@ -188,16 +185,16 @@ def get_params(self):
            del augmented1
            del augmented2
 
-            if False:
-                # for sake of showing:
-                aug_lefts_tmp, aug_rights_tmp = self.dataPreprocesser.postprocess_images(np.asarray(aug_lefts), np.asarray(aug_rights))
+        if False:
+            # for sake of showing:
+            aug_lefts_tmp, aug_rights_tmp = self.dataPreprocesser.postprocess_images(np.asarray(aug_lefts), np.asarray(aug_rights))
 
-                #self.debugger.viewTripples(aug_lefts, aug_rights, aug_ys)
-                by = 5
-                off = i * by
-                while off < len(aug_lefts):
-                    self.debugger.viewTripples(aug_lefts_tmp, aug_rights_tmp, aug_ys, how_many=by, off=off)
-                    off += by
+            #self.debugger.viewTripples(aug_lefts, aug_rights, aug_ys)
+            by = 5
+            off = i * by
+            while off < len(aug_lefts):
+                self.debugger.viewTripples(aug_lefts_tmp, aug_rights_tmp, aug_ys, how_many=by, off=off)
+                off += by
 
         aug_lefts = np.asarray(aug_lefts)
         aug_rights = np.asarray(aug_rights)
@@ -337,8 +334,8 @@ def test(self, evaluator, show = True, save = False):
         print("indices:", misclassified_indices)
         misclassified_indices = misclassified_indices[0]
 
-        for ind in misclassified_indices:
-            print("idx", ind, ":", predicted_classlabels[ind]," != ",test_classlabels[ind])
+        #for ind in misclassified_indices:
+        #    print("idx", ind, ":", predicted_classlabels[ind]," != ",test_classlabels[ind])
 
 
         print("MASK EVALUATION")
@@ -370,7 +367,7 @@ def test(self, evaluator, show = True, save = False):
         print("predicted images (test)")
         self.debugger.explore_set_stats(predicted)
 
-
+        """
         if Tile_Based_Evaluation:
             print("Misclassified samples (in total", len(misclassified_indices),"):")
             if show:
@@ -381,7 +378,7 @@ def test(self, evaluator, show = True, save = False):
                     #self.debugger.viewTripples(test_L, test_R, test_V, how_many=4, off=off)
                     self.debugger.viewQuadrupples(test_L[misclassified_indices], test_R[misclassified_indices], test_V[misclassified_indices], predicted[misclassified_indices], how_many=by, off=off, show=show,save=save)
                     off += by
-
+        """
 
         if show:
             off = 0
@@ -398,7 +395,9 @@ def test(self, evaluator, show = True, save = False):
             until_n = min(by*8, len(test_L))
             while off < until_n:
                 #self.debugger.viewTripples(test_L, test_R, test_V, how_many=4, off=off)
-                self.debugger.viewQuadrupples(test_L, test_R, test_V, predicted, how_many=by, off=off, show=show,save=save, name=self.save_plot_path+"quad"+str(off))
+                kfold_txt = "KFold_" + str(self.settings.TestDataset_Fold_Index) + "z" + str(self.settings.TestDataset_K_Folds)
+
+                self.debugger.viewQuadrupples(test_L, test_R, test_V, predicted, how_many=by, off=off, show=show,save=save, name=self.save_plot_path+"quad"+str(off)+kfold_txt)
                 off += by
 
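Both the saved quadruple plots here and the checkpoint name in main.py embed the same "KFold_<fold>z<K>" tag. A tiny hypothetical helper (not present in the repository) could keep the two in sync:

def kfold_tag(settings):
    # e.g. fold 3 of 10 -> "KFold_3z10"
    return "KFold_" + str(settings.TestDataset_Fold_Index) + "z" + str(settings.TestDataset_K_Folds)

# usage in this file:  name = self.save_plot_path + "quad" + str(off) + kfold_tag(self.settings)
# usage in main.py:    "..._[" + kfold_tag(settings) + "].h5"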

Model2_builder.py  (+1 -1)

@@ -226,7 +226,7 @@ def build_siamese_unet(backbone, classes, skip_connection_layers,
     branch_a = branch_a_outputs[0]
     branch_b = branch_b_outputs[0]
 
-    x = Concatenate()([branch_a, branch_b]) # both inputs, in theory 8x8x512 + 8x8x512 -> 8x8x1024
+    x = Concatenate(name="concatHighLvlFeat")([branch_a, branch_b]) # both inputs, in theory 8x8x512 + 8x8x512 -> 8x8x1024
 
     skip_connection_outputs_a = branch_a_outputs[1:]
     skip_connection_outputs_b = branch_b_outputs[1:]
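Naming the Concatenate layer is a one-line change, but it makes the merged high-level features addressable later, for example to build a feature-extraction sub-model from a trained network. A minimal sketch, assuming the standard Keras API used elsewhere in this project (siamese_model stands for whatever build_siamese_unet returned):

from keras.models import Model

def high_level_feature_extractor(siamese_model):
    # look the layer up by the name given in build_siamese_unet and expose its output
    concat = siamese_model.get_layer("concatHighLvlFeat")
    return Model(inputs=siamese_model.inputs, outputs=concat.output)

# features = high_level_feature_extractor(model).predict([left_batch, right_batch])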

main.py  (+17 -2)

@@ -19,6 +19,13 @@ def main(args):
     print(args)
 
     settings = Settings.Settings(args)
+
+    # We already did these
+    # ResNet50 and indices: 5, 2, 7, 3 (doing ? r.n.)
+    settings.TestDataset_Fold_Index = 999 # can be 0 to 9 (K-1)
+    settings.TestDataset_K_Folds = 10
+    assert settings.TestDataset_Fold_Index < settings.TestDataset_K_Folds
+
     dataset = Dataset.Dataset(settings)
     evaluator = Evaluator.Evaluator(settings)
 
@@ -29,7 +36,7 @@ def main(args):
     #dataset.dataset
     model = ModelHandler.ModelHandler(settings, dataset)
 
-    #model.model.train(show=show,save=save)
+    model.model.train(show=show,save=save)
 
     # Model 2 ...
 
@@ -43,11 +50,19 @@ def main(args):
     # - class weights changed ?
     # - ... any other special cool thing ...
 
+    # K-Fold_Crossval:
+    kfold_txt = "KFold_"+str(settings.TestDataset_Fold_Index)+"z"+str(settings.TestDataset_K_Folds)
+    print(kfold_txt)
+
+    # resnet 101 approx 5-6 hours (per fold - might be a bit less ...)
+    # resnet 50 approx 3-4 hours
+    model.model.save("/scratch/ruzicka/python_projects_large/ChangeDetectionProject_files/weightsModel2_cleanManual_100ep_ImagenetWgenetW_resnet50-8batch_Augmentation1to1_ClassWeights1to3_["+kfold_txt+"].h5")
+
     # Next = train Resnet50 on the same dataset without the whole STRIP2 (to have some large Test images)
 
     #model.model.load("/scratch/ruzicka/python_projects_large/ChangeDetectionProject_files/weightsModel2_cleanManual_100ep_ImagenetWgenetW_seresnext50-8batch_Augmentation1to1_ClassWeights1to3.h5")
 
-    model.model.load("/scratch/ruzicka/python_projects_large/ChangeDetectionProject_files/weightsModel2_cleanManual-noStrip2_100ep_ImagenetWgenetW_resnet50-16batch_Augmentation1to1_ClassWeights1to3.h5")
+    #model.model.load("/scratch/ruzicka/python_projects_large/ChangeDetectionProject_files/weightsModel2_cleanManual-noStrip2_100ep_ImagenetWgenetW_resnet50-16batch_Augmentation1to1_ClassWeights1to3.h5")
     #model.model.load("/scratch/ruzicka/python_projects_large/ChangeDetectionProject_files/weightsModel2_cleanManual_100ep_ImagenetWgenetW_resnet101-8batch_Augmentation1to1_ClassWeights1to3.h5")
 
     # Senet154 crashed, 10hrs train + Imagenet weights + Data Aug 1:1 + Class weight 1:3
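The fold index is hard-coded per run (the placeholder 999 will trip the assert until a real fold index 0 to 9 is set), so a full cross-validation means launching the script once per fold. A hypothetical driver loop, not part of the commit, with run_single_fold standing in for the per-fold body of main():

def run_all_folds(args, K=10):
    results = []
    for fold in range(K):
        settings = Settings.Settings(args)
        settings.TestDataset_K_Folds = K
        settings.TestDataset_Fold_Index = fold
        assert settings.TestDataset_Fold_Index < settings.TestDataset_K_Folds
        results.append(run_single_fold(settings))   # hypothetical per-fold entry point
    return results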
