Skip to content

Commit

Permalink
Fix training issue
Browse files Browse the repository at this point in the history
  • Loading branch information
jrrodri committed Sep 23, 2024
1 parent ce4b794 commit abd77a4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 57 deletions.
19 changes: 0 additions & 19 deletions abraia/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,32 +64,13 @@ def create_model(class_names, pretrained=True):
return model


def save_model(path, model, device='cpu'):
model.to(device)
src = os.path.join(tempdir, path)
os.makedirs(os.path.dirname(src), exist_ok=True)
torch.save(model.state_dict(), src)
multiple.upload_file(src, path)


def load_model(path, class_names):
dest = multiple.cache_file(path)
model = create_model(class_names, pretrained=False)
model.load_state_dict(torch.load(dest))
return model


def export_onnx(path, model, device='cpu'):
model.to(device)
dummy_input = torch.randn(1, 3, 224, 224)
src = os.path.join(tempdir, path)
os.makedirs(os.path.dirname(src), exist_ok=True)
torch.onnx.export(model, dummy_input, src, export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'])
onnx_model = onnx.load(src)
onnx.checker.check_model(onnx_model)
multiple.upload_file(src, path)


transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
Expand Down
41 changes: 3 additions & 38 deletions notebooks/training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -312,41 +312,6 @@
" sorted_items = sorted(items, key=os.path.getctime)\n",
" return sorted_items\n",
"\n",
"def build_model_name(model_name, task):\n",
" if task == 'segment':\n",
" model_name = f\"{model_name}-seg\"\n",
" if task == 'classify':\n",
" model_name = f\"{model_name}-cls\"\n",
" return model_name\n",
"\n",
"def train_model(dataset, task, batch=32, epochs=100, imgsz=640):\n",
" model_name = build_model_name('yolov8n', task)\n",
" model = YOLO(f\"{model_name}.pt\", verbose=False)\n",
" data = f\"{dataset}\" if task == 'classify' else f\"{dataset}/data.yaml\"\n",
" results = model.train(data=data, batch=batch, epochs=epochs, imgsz=imgsz)\n",
" metrics = model.val(data=data)\n",
" return model, model_name\n",
"\n",
"def save_model(model, model_name, dataset, task, classes, imgsz=640):\n",
" model_src = model.export(format=\"onnx\", device=\"cpu\")\n",
" multiple.upload_file(model_src, f\"{dataset}/{model_name}.onnx\")\n",
" multiple.save_json(f\"{dataset}/{model_name}.json\", {'task': task, 'inputShape': [1, 3, imgsz, imgsz], 'classes': classes})\n",
"\n",
"def run_model(model, src, task='segment'):\n",
" objects = []\n",
" results = model.predict(src, verbose=False)[0]\n",
" if results:\n",
" for box, mask in zip(results.boxes, results.masks):\n",
" class_id = int(box.cls)\n",
" label = results.names[class_id]\n",
" confidence = float(box.conf)\n",
" x1, y1, x2, y2 = box.xyxy.squeeze().tolist()\n",
" object = {'label': label, 'confidence': confidence, 'color': detect.get_color(class_id), 'box': [x1, y1, x2 - x1, y2 - y1]}\n",
" if task == 'segment':\n",
" object['polygon'] = [(x, y) for x, y in mask.xy[0]]\n",
" objects.append(object)\n",
" return objects\n",
"\n",
"def plot_results(src, results):\n",
" im = Image.open(src).convert('RGB')\n",
" detect.render_results(im, results)\n",
Expand Down Expand Up @@ -416,8 +381,8 @@
" label_status.value = 'Creating dataset...'\n",
" training.create_dataset(dataset, task, classes)\n",
" label_status.value = 'Training model...'\n",
" model, model_name = train_model(dataset, task, epochs=epochs, imgsz=imgz)\n",
" save_model(model, model_name, dataset, task, classes, imgsz=imgz)\n",
" model, model_name = training.train_model(dataset, task, epochs=epochs, imgsz=imgz)\n",
" training.save_model(model, model_name, dataset, task, classes, imgsz=imgz)\n",
" label_status.value = 'Model saved.'\n",
" folders = sorted_folders(f\"runs/{task}/\")\n",
" folder = folders[-1]\n",
Expand All @@ -428,7 +393,7 @@
" display(show_image(filename=f\"{folder}/val_batch0_pred.jpg\"))\n",
" # Random test...\n",
" src = glob.glob(f\"{dataset}/test/*/*.png\")[0]\n",
" results = run_model(model, src)\n",
" results = training.run_model(model, src)\n",
" plot_results(src, results)\n",
" button_train.disabled = False\n",
"\n",
Expand Down

0 comments on commit abd77a4

Please sign in to comment.