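"""beitv2.py

Fine-tunes a pretrained BEiT v2 (base, patch16, 224) backbone from timm as a
4-class classifier on the COVID-19 Radiography dataset, optionally training
on only a fraction of the training split.

Usage:
    python beitv2.py <model_name> <train_ratio>
"""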
import sys
import warnings

warnings.filterwarnings("ignore")  # silence library warnings, incl. import-time ones

import random
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import random_split, Subset, DataLoader
import timm

from utils import CustomDataset, Trainer, JointTransform

# Seed every RNG in play so runs are reproducible.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
# CLI: python beitv2.py <model_name> <train_ratio>
#   model_name : run label passed to Trainer (presumably for logging/checkpoints)
#   train_ratio: fraction in (0, 1] of the training split to actually use
model_name = sys.argv[1]
ratio = float(sys.argv[2])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 224
num_classes = 4  # COVID, Lung Opacity, Normal, Viral Pneumonia

# BEiT v2 base (patch 16, 224x224) checkpoint pretrained on ImageNet-1k and
# fine-tuned on ImageNet-22k, with a fresh `num_classes`-way head.
model = timm.create_model(
    'beitv2_base_patch16_224.in1k_ft_in22k',
    pretrained=True,
    num_classes=num_classes,
)
criterion_s = None  # no segmentation loss in this run
criterion_c = nn.CrossEntropyLoss()
# Freeze the patch embedding and the first `freeze_up_to_layer` transformer
# blocks; only the later blocks and the new head are fine-tuned.
for param in model.patch_embed.parameters():
    param.requires_grad = False
freeze_up_to_layer = 8
for layer_index, layer in enumerate(model.blocks):
    if layer_index < freeze_up_to_layer:
        for param in layer.parameters():
            param.requires_grad = False
# Sanity check, uncomment to verify which parameters are frozen:
# for name, param in model.named_parameters():
#     print(f"{name} is {'frozen' if not param.requires_grad else 'not frozen'}")
# Images are resized to the backbone's 224x224 input and normalized with
# ImageNet statistics; masks are only resized and converted to tensors.
image_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
mask_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])
root_dir = "COVID-19_Radiography_Dataset"
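# JointTransform (from utils) couples the image and mask pipelines;
# flip=True presumably enables joint flip augmentation, used for training only.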
train_joint_transform = JointTransform(image_transform=image_transform, mask_transform=mask_transform, flip=True)
val_joint_transform = JointTransform(image_transform=image_transform, mask_transform=mask_transform, flip=False)
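# 80/20 train/validation split of the full dataset.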
dataset = CustomDataset(root_dir)
dataset_size = len(dataset)
train_size = int(dataset_size * 0.8)
val_size = dataset_size - train_size
batch_size = 64
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
if ratio < 1:
    # Keep only `ratio` of the training split; the rest is discarded.
    unused_size = int(train_size * (1 - ratio))
    train_dataset, _ = random_split(train_dataset, [train_size - unused_size, unused_size])

def base_dataset(ds):
    # random_split yields (possibly nested) Subsets; walk down to the underlying
    # CustomDataset, where `transform` is actually read. Note that both splits
    # share that dataset, so the transform assigned last wins for either split.
    while isinstance(ds, Subset):
        ds = ds.dataset
    return ds

base_dataset(train_dataset).transform = train_joint_transform
base_dataset(val_dataset).transform = val_joint_transform
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
model = model.to(device)

# Adam with a small learning rate and light weight decay, as is typical when
# fine-tuning a pretrained ViT-style backbone.
epochs = 20
learning_rate = 5e-5
weight_decay = 1e-6
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

trainer = Trainer(
    model=model,
    model_name=model_name,
    num_epochs=epochs,
    optimizer=optimizer,
    device=device,
    project_name="DLMI_PROJECT",
    criterion_segmentation=criterion_s,
    criterion_classification=criterion_c,
)
trainer.train(train_loader, val_loader)