Skip to content
This repository was archived by the owner on Dec 2, 2022. It is now read-only.

Commit 725af71

Browse files
TanjaBayerTanja Bayer
and
Tanja Bayer
authoredMay 9, 2022
Enable package to load model from local path (#53)
* Use headless version of opencv * Provide possibility to load net from local path * Remove headless again for merge to official repo Co-authored-by: Tanja Bayer <tanja.bayer@widas.de>
1 parent c856736 commit 725af71

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed
 

‎.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
result*
66
weights*
77
.vscode
8+
.pypirc
89

910
# Byte-compiled / optimized / DLL files
1011
__pycache__/

‎craft_text_detector/__init__.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import absolute_import
22

33
import os
4+
from typing import Optional
45

56
import craft_text_detector.craft_utils as craft_utils
67
import craft_text_detector.file_utils as file_utils
@@ -44,6 +45,8 @@ def __init__(
4445
long_size=1280,
4546
refiner=True,
4647
crop_type="poly",
48+
weight_path_craft_net: Optional[str] = None,
49+
weight_path_refine_net: Optional[str] = None,
4750
):
4851
"""
4952
Arguments:
@@ -72,22 +75,22 @@ def __init__(
7275
self.crop_type = crop_type
7376

7477
# load craftnet
75-
self.load_craftnet_model()
78+
self.load_craftnet_model(weight_path_craft_net)
7679
# load refinernet if required
7780
if refiner:
78-
self.load_refinenet_model()
81+
self.load_refinenet_model(weight_path_refine_net)
7982

80-
def load_craftnet_model(self):
83+
def load_craftnet_model(self, weight_path: Optional[str] = None):
8184
"""
8285
Loads craftnet model
8386
"""
84-
self.craft_net = load_craftnet_model(self.cuda)
87+
self.craft_net = load_craftnet_model(self.cuda, weight_path=weight_path)
8588

86-
def load_refinenet_model(self):
89+
def load_refinenet_model(self, weight_path: Optional[str] = None):
8790
"""
8891
Loads refinenet model
8992
"""
90-
self.refine_net = load_refinenet_model(self.cuda)
93+
self.refine_net = load_refinenet_model(self.cuda, weight_path=weight_path)
9194

9295
def unload_craftnet_model(self):
9396
"""

0 commit comments

Comments
 (0)
This repository has been archived.