Skip to content

Commit

Permalink
Merge pull request #1250 from lemonviv/update-tedct-train
Browse files Browse the repository at this point in the history
Update the train file for TED CT Detection application
  • Loading branch information
nudles authored Dec 26, 2024
2 parents bc93719 + 71ad0e4 commit 30f06a2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 131 deletions.
119 changes: 0 additions & 119 deletions examples/healthcare/application/TED_CT_Detection/model.py

This file was deleted.

35 changes: 23 additions & 12 deletions examples/healthcare/application/TED_CT_Detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,10 @@
from PIL import Image

import sys
sys.path.append("../../..")

sys.path.append(".")
print(sys.path)

import examples.cnn.model.cnn as cnn
from examples.cnn.data import cifar10
import model as cpl
from healthcare.data import cifar10
from healthcare.models import tedct_net


def accuracy(pred, target):
Expand Down Expand Up @@ -60,6 +57,7 @@ def resize_dataset(x, image_size):

def run(
local_rank,
dir_path,
max_epoch,
batch_size,
sgd,
Expand All @@ -68,18 +66,19 @@ def run(
dist_option="plain",
spars=None,
):
dev = device.create_cuda_gpu_on(local_rank)
# dev = device.create_cuda_gpu_on(local_rank)
dev = device.get_default_device()
dev.SetRandSeed(0)
np.random.seed(0)

train_x, train_y, val_x, val_y = cifar10.load()
train_x, train_y, val_x, val_y = cifar10.load(dir_path)

num_channels = train_x.shape[1]
data_size = np.prod(train_x.shape[1 : train_x.ndim]).item()
num_classes = (np.max(train_y) + 1).item()

backbone = cnn.create_model(num_channels=num_channels, num_classes=num_classes)
model = cpl.create_model(backbone, prototype_count=10, lamb=0.5, temp=10)
backbone = tedct_net.create_cnn_model(num_channels=num_channels, num_classes=num_classes)
model = tedct_net.create_model(backbone, prototype_count=10, lamb=0.5, temp=10)

if backbone.dimension == 4:
tx = tensor.Tensor(
Expand Down Expand Up @@ -139,6 +138,12 @@ def run(

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train a CPL model")
parser.add_argument('-dir',
'--dir-path',
default="/tmp/cifar-10-batches-py",
type=str,
help='the directory to store the dataset',
dest='dir_path')
parser.add_argument(
"-m",
"--max-epoch",
Expand Down Expand Up @@ -187,5 +192,11 @@ def run(

sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5)
run(
args.device_id, args.max_epoch, args.batch_size, sgd, args.graph, args.verbosity
)
args.device_id,
args.dir_path,
args.max_epoch,
args.batch_size,
sgd,
args.graph,
args.verbosity
)

0 comments on commit 30f06a2

Please sign in to comment.