Skip to content

Commit ae6e097

Browse files
committed
Updated Code
For MICCAI 2024
1 parent eb6367e commit ae6e097

File tree

12 files changed

+1029
-0
lines changed

12 files changed

+1029
-0
lines changed

data/.DS_Store

6 KB
Binary file not shown.

data/loader.py

+206
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# encoding: utf-8
2+
"""
3+
Read images and corresponding labels.
4+
"""
5+
6+
import torch
7+
from torch.utils.data import Dataset
8+
from torchvision import transforms
9+
import pandas as pd
10+
import numpy as np
11+
from PIL import Image
12+
import os
13+
14+
from PIL import ImageFile
15+
ImageFile.LOAD_TRUNCATED_IMAGES = True
16+
17+
#############################################################
18+
##### Dataset with memory bank and contrastive samples. #####
19+
#############################################################
20+
21+
22+
class ISIC_InstanceSample(Dataset):
    """ISIC/APTOS dataset that also yields contrastive sample indices.

    For each item it returns the image, its one-hot label, its index, and a
    vector of memory-bank indices: one (or ``p``) positive index followed by
    ``k`` negative indices, for use with a contrastive (CCD) loss.

    Args:
        root_dir: directory holding the ``<id_code>.png`` images.
        csv_file: CSV with ``id_code`` and integer ``diagnosis`` columns.
        CCD_mode: 'sup' (negatives = other classes) or 'unsup'
            (negatives = every other sample, same class included).
        transform: optional transform applied to each PIL image.
        p: number of positives drawn when ``mode == 'multi_pos'``.
        k: number of negatives drawn per item.
        mode: positive-sampling strategy: 'exact' | 'relax' | 'multi_pos'.
        is_sample: when False, skip contrastive sampling entirely.
        percent: fraction of the negative pool to keep (0 < percent < 1 trims).
    """

    def __init__(self, root_dir, csv_file, CCD_mode, transform=None, p=10, k=4096,
                 mode='exact', is_sample=True, percent=1.0):
        super(ISIC_InstanceSample, self).__init__()

        self.p = p
        self.k = k
        self.mode = mode
        self.CCD_mode = CCD_mode
        self.is_sample = is_sample

        file = pd.read_csv(csv_file)

        self.root_dir = root_dir
        self.images = file['id_code'].values  # image name
        self.labels = file['diagnosis'].values.astype(int)  # scalar label
        n_classes = len(np.unique(self.labels))
        # one hot. [num_images, num_classes]
        self.labels = np.eye(n_classes)[self.labels.reshape(-1)]
        self.transform = transform

        print('Total # images:{}, labels:{}'.format(
            len(self.images), len(self.labels)))

        num_samples = len(self.images)
        label = np.argmax(self.labels, axis=1)

        # cls_positive[c] = indices of all samples belonging to class c.
        self.cls_positive = [[] for _ in range(n_classes)]
        for i in range(num_samples):
            self.cls_positive[label[i]].append(i)

        # cls_negative[c] = indices of all samples NOT belonging to class c.
        self.cls_negative = [[] for _ in range(n_classes)]
        for i in range(n_classes):
            for j in range(n_classes):
                if j == i:
                    continue
                self.cls_negative[i].extend(self.cls_positive[j])

        self.cls_positive = [np.asarray(self.cls_positive[i])
                             for i in range(n_classes)]
        self.cls_negative = [np.asarray(self.cls_negative[i])
                             for i in range(n_classes)]

        # Keep the untrimmed per-class index lists for external consumers
        # (used as class anchors by the trainer).
        self.class_index = self.cls_positive

        if 0 < percent < 1:
            # Randomly trim every class's negative pool to the same size.
            n = int(len(self.cls_negative[0]) * percent)
            self.cls_negative = [np.random.permutation(self.cls_negative[i])[0:n]
                                 for i in range(n_classes)]

        self.cls_positive = np.asarray(self.cls_positive, dtype=object)
        self.cls_negative = np.asarray(self.cls_negative, dtype=object)

    def __getitem__(self, index):
        """Return one sample.

        Returns:
            ``(img, target, index)`` when ``is_sample`` is False, otherwise
            ``(img, one_hot_label, index, sample_idx)`` where ``sample_idx``
            holds the positive index/indices followed by ``k`` negatives.
        """
        image_name = os.path.join(self.root_dir, self.images[index]+'.png')
        img = Image.open(image_name).convert('RGB')
        # argmax over this sample's one-hot row only; the original scanned the
        # whole [N, C] label matrix on every fetch.
        target = int(np.argmax(self.labels[index]))
        label = self.labels[index]

        if self.transform is not None:
            img = self.transform(img)

        if not self.is_sample:
            return img, target, index

        # sample contrastive examples
        if self.mode == 'exact':
            pos_idx = index
        elif self.mode == 'relax':
            pos_idx = np.random.choice(self.cls_positive[target], 1)[0]
        elif self.mode == 'multi_pos':
            pos_idx = np.random.choice(
                self.cls_positive[target], self.p, replace=False)
        else:
            raise NotImplementedError(self.mode)

        if self.CCD_mode == "sup":
            # Sample with replacement only when fewer than k negatives exist.
            replace = self.k > len(self.cls_negative[target])
            neg_idx = np.random.choice(
                self.cls_negative[target], self.k, replace=replace)
        elif self.CCD_mode == "unsup":
            # Unsupervised: treat every other sample (same class included,
            # self excluded) as a negative.
            pos_others = np.setdiff1d(self.cls_positive[target], ([index]))
            all_negative = np.hstack(
                (pos_others, self.cls_negative[target]))
            neg_idx = np.random.choice(all_negative, self.k, replace=True)
        else:
            # An unknown CCD_mode previously fell through to an undefined
            # ``neg_idx`` NameError; fail explicitly instead.
            raise NotImplementedError(self.CCD_mode)

        if self.mode == 'exact' or self.mode == 'relax':
            sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
        else:  # 'multi_pos' — pos_idx is already an array of p indices
            sample_idx = np.hstack((pos_idx, neg_idx))

        return img, label, index, sample_idx

    def __len__(self):
        return len(self.images)
119+
120+
121+
def load_dataset(args, p=10, k=4096, mode='exact', is_sample=True, percent=1.0):
    """Build the (train, test) dataset pair for the configured split.

    The train set is an ``ISIC_InstanceSample`` (with contrastive sampling)
    whose transform yields two augmented views per image; the test set is a
    plain ``ISIC_Dataset`` with deterministic preprocessing.
    """
    csv_file_train = f"{args.csv_file_path}{args.dataset}/{args.split}_train.csv"
    csv_file_test = f"{args.csv_file_path}{args.dataset}/{args.split}_test.csv"

    # ImageNet normalization statistics, shared by both pipelines.
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    train_transform = TransformTwice(transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomAffine(degrees=10, translate=(0.02, 0.02)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = ISIC_InstanceSample(root_dir=args.root_path,
                                    csv_file=csv_file_train,
                                    CCD_mode=args.CCD_mode,
                                    transform=train_transform,
                                    p=p,
                                    k=k,
                                    mode=mode,
                                    is_sample=is_sample,
                                    percent=percent)

    test_set = ISIC_Dataset(root_dir=args.root_path,
                            csv_file=csv_file_test, transform=test_transform)

    return train_set, test_set
154+
155+
156+
# Dataset without memory bank
class ISIC_Dataset(Dataset):
    """Plain image/label dataset read from a CSV listing.

    Args:
        root_dir: path to image directory.
        csv_file: path to the file containing images with corresponding labels.
        transform: optional transform to be applied on a sample.
    """

    def __init__(self, root_dir, csv_file, transform=None):
        super(ISIC_Dataset, self).__init__()
        records = pd.read_csv(csv_file)

        self.root_dir = root_dir
        self.images = records['id_code'].values  # image names
        scalar_labels = records['diagnosis'].values.astype(int)
        self.n_classes = len(np.unique(scalar_labels))
        # One-hot encode the scalar labels: [num_images, num_classes].
        self.labels = np.eye(self.n_classes)[scalar_labels.reshape(-1)]
        self.transform = transform

        print('Total # images:{}, labels:{}'.format(
            len(self.images), len(self.labels)))

    def __getitem__(self, index):
        """Return ``(image, one_hot_label)`` for the given index."""
        image_path = os.path.join(self.root_dir, self.images[index]+'.png')
        sample = Image.open(image_path).convert('RGB')
        target = self.labels[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        return len(self.images)
197+
198+
199+
class TransformTwice:
    """Apply the wrapped transform twice, yielding two (stochastic) views."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        # Two independent invocations: with random augmentations this
        # produces two different views of the same input.
        return self.transform(inp), self.transform(inp)

logs/aptos/.DS_Store

6 KB
Binary file not shown.

models/.DS_Store

6 KB
Binary file not shown.

models/model.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torchvision.models import densenet121, DenseNet121_Weights
5+
6+
7+
class DenseNet121(nn.Module):
    """ImageNet-pretrained DenseNet-121 with a bottleneck FC layer.

    The stock classifier is replaced by ``fc_layer`` (backbone features ->
    ``hidden_units``) followed by ``classifier`` (``hidden_units`` ->
    ``out_size``); ``forward`` returns both the hidden feature vector and
    the class logits so the feature can feed auxiliary losses.
    """

    def __init__(self, hidden_units, out_size, drop_rate=0):
        super(DenseNet121, self).__init__()
        self.densenet121 = densenet121(
            weights=DenseNet121_Weights.IMAGENET1K_V1)
        in_features = self.densenet121.classifier.in_features

        # Bottleneck + new classification head.
        self.densenet121.fc_layer = nn.Linear(in_features, hidden_units)
        self.densenet121.classifier = nn.Linear(hidden_units, out_size)

        self.drop_rate = drop_rate
        self.drop_layer = nn.Dropout(p=drop_rate)

    def forward(self, x):
        fmaps = F.relu(self.densenet121.features(x), inplace=True)
        pooled = F.adaptive_avg_pool2d(fmaps, (1, 1)).view(fmaps.size(0), -1)

        if self.drop_rate > 0:
            pooled = self.drop_layer(pooled)

        feature = self.densenet121.fc_layer(pooled)
        logits = self.densenet121.classifier(feature)
        return feature, logits

train.py

+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from utils.utils import get_labels_frequency, set_logger
2+
from utils.trainer import fit
3+
from models.model import DenseNet121
4+
from data.loader import load_dataset
5+
from torch.utils.data import DataLoader
6+
import torch.backends.cudnn as cudnn
7+
import torch
8+
import numpy as np
9+
import random
10+
import logging
11+
import sys
12+
import os
13+
import argparse
14+
import warnings
15+
warnings.simplefilter('ignore')
16+
17+
18+
def get_args(argv=None):
    """Parse the command-line configuration for training.

    Args:
        argv: optional list of argument strings. Defaults to ``None``, in
            which case ``sys.argv[1:]`` is parsed — fully backward
            compatible with the old zero-argument call. Passing an explicit
            list makes the parser usable programmatically and testable.

    Returns:
        argparse.Namespace with all training hyper-parameters.
    """
    parser = argparse.ArgumentParser()
    # Data locations
    parser.add_argument('--root_path', type=str,
                        default='../Datasets/APTOS/APTOS_images/train_images')
    parser.add_argument('--csv_file_path', type=str, default='../CSVs/')
    parser.add_argument("--logdir", type=str, required=False,
                        default="./logs/aptos/", help="Log directory path")
    parser.add_argument('--dataset', type=str, default='aptos')
    parser.add_argument('--split', type=str, default='split1')

    parser.add_argument('--n_distill', type=int, default=20,
                        help='start to use the kld loss')

    # Contrastive (NCE/CCD) sampling options
    parser.add_argument('--mode', default='exact', type=str,
                        choices=['exact', 'relax', 'multi_pos'])
    parser.add_argument('--nce_p', default=1, type=int,
                        help='number of positive samples for NCE')
    parser.add_argument('--nce_k', default=4096, type=int,
                        help='number of negative samples for NCE')
    parser.add_argument('--nce_t', default=0.07, type=float,
                        help='temperature parameter for softmax')
    parser.add_argument('--nce_m', default=0.5, type=float,
                        help='momentum for non-parametric updates')
    parser.add_argument('--CCD_mode', type=str,
                        default="sup", choices=['sup', 'unsup'])
    parser.add_argument('--rel_weight', type=float, default=25,
                        help='whether use the relation loss')
    parser.add_argument('--ccd_weight', type=float,
                        default=0.1, help='whether use the CCD loss')

    parser.add_argument('--anchor_type', type=str,
                        default="center", choices=['center', 'class'])
    parser.add_argument('--class_anchor', default=30, type=int,
                        help='number of anchors in each class')

    # Model / teacher dimensions
    parser.add_argument('--feat_dim', type=int, default=128,
                        help='reduced feature dimension')
    parser.add_argument('--s_dim', type=int, default=128,
                        help='feature dim of the student model')
    parser.add_argument('--t_dim', type=int, default=128,
                        help='feature dim of the EMA teacher')
    parser.add_argument('--n_data', type=int, default=3662,
                        help='total number of training samples.')
    parser.add_argument('--t_decay', type=float,
                        default=0.99, help='ema_decay')

    # Optimization
    parser.add_argument('--epochs', type=int, default=80,
                        help='maximum epoch number to train')
    parser.add_argument('--batch_size', type=int,
                        default=64, help='batch_size per gpu')
    parser.add_argument('--drop_rate', type=int,
                        default=0, help='dropout rate')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--seed', type=int, default=2024, help='random seed')

    parser.add_argument('--optimizer', type=str, default='adam', help='optim')
    parser.add_argument('--scheduler', type=str,
                        default='OneCycleLR', help='sch_str')
    parser.add_argument('--device', type=str, default='cuda:0', help='device')

    # Consistency (mean-teacher) schedule
    parser.add_argument('--consistency', type=float,
                        default=1, help='consistency')
    parser.add_argument('--consistency_rampup', type=float,
                        default=30, help='consistency_rampup')

    args = parser.parse_args(argv)
    return args
86+
87+
# Function to set the seed for all random number generators to ensure reproducibility
def set_seed(seed):
    """Seed python/numpy/torch RNGs and force deterministic cuDNN kernels."""
    cudnn.benchmark = False
    cudnn.deterministic = True
    # Same seed, same order as before: python, numpy, torch CPU, torch CUDA.
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
97+
98+
99+
if __name__ == "__main__":
    # Get arguments
    args = get_args()

    # Set seed BEFORE any data loading so shuffling/augmentation is reproducible.
    set_seed(args.seed)

    # Set Logger (creates the log directory on first run).
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = set_logger(args)
    logger.info(args)

    # Loading Data: train set carries the contrastive class index lists,
    # test set exposes the class count.
    train_ds, test_ds = load_dataset(args, p=args.nce_p, mode=args.mode)
    n_classes = test_ds.n_classes
    class_index = train_ds.class_index
    print(n_classes)

    # Give each DataLoader worker a distinct but deterministic seed.
    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    train_dl = DataLoader(train_ds, batch_size=args.batch_size,
                          shuffle=True, num_workers=12, pin_memory=True,
                          worker_init_fn=worker_init_fn)

    test_dl = DataLoader(test_ds, batch_size=args.batch_size,
                         shuffle=False, num_workers=12, pin_memory=True,
                         worker_init_fn=worker_init_fn)
    # Inverse-frequency class weights for the (imbalanced) training split.
    freq = get_labels_frequency(args.csv_file_path + args.dataset +
                                '/' + args.split + '_train.csv', 'diagnosis', 'id_code')
    freq = freq.values
    weights = freq.sum() / freq
    print(weights)

    # Loading Models: identical student/teacher architectures.
    student = DenseNet121(hidden_units=args.feat_dim,
                          out_size=n_classes, drop_rate=args.drop_rate)
    teacher = DenseNet121(hidden_units=args.feat_dim,
                          out_size=n_classes, drop_rate=args.drop_rate)

    # Teacher is updated outside autograd (EMA — presumably, given t_decay;
    # confirm in utils.trainer.fit), so detach its parameters from the graph.
    for param in teacher.parameters():
        param.detach_()

    # Fit the model
    fit(student, teacher, train_dl, test_dl, weights,
        class_index, logger, args, device=args.device)

utils/.DS_Store

6 KB
Binary file not shown.

0 commit comments

Comments
 (0)