|
1 | 1 | from __future__ import absolute_import
|
2 | 2 |
|
3 | 3 | import os
|
| 4 | +from typing import Optional |
4 | 5 |
|
5 | 6 | import craft_text_detector.craft_utils as craft_utils
|
6 | 7 | import craft_text_detector.file_utils as file_utils
|
@@ -44,6 +45,8 @@ def __init__(
|
44 | 45 | long_size=1280,
|
45 | 46 | refiner=True,
|
46 | 47 | crop_type="poly",
|
| 48 | + weight_path_craft_net: Optional[str] = None, |
| 49 | + weight_path_refine_net: Optional[str] = None, |
47 | 50 | ):
|
48 | 51 | """
|
49 | 52 | Arguments:
|
@@ -72,22 +75,22 @@ def __init__(
|
72 | 75 | self.crop_type = crop_type
|
73 | 76 |
|
74 | 77 | # load craftnet
|
75 |
| - self.load_craftnet_model() |
| 78 | + self.load_craftnet_model(weight_path_craft_net) |
76 | 79 | # load refinernet if required
|
77 | 80 | if refiner:
|
78 |
| - self.load_refinenet_model() |
| 81 | + self.load_refinenet_model(weight_path_refine_net) |
79 | 82 |
|
80 |
| - def load_craftnet_model(self): |
| 83 | + def load_craftnet_model(self, weight_path: Optional[str] = None): |
81 | 84 | """
|
82 | 85 | Loads craftnet model
|
83 | 86 | """
|
84 |
| - self.craft_net = load_craftnet_model(self.cuda) |
| 87 | + self.craft_net = load_craftnet_model(self.cuda, weight_path=weight_path) |
85 | 88 |
|
86 |
| - def load_refinenet_model(self): |
| 89 | + def load_refinenet_model(self, weight_path: Optional[str] = None): |
87 | 90 | """
|
88 | 91 | Loads refinenet model
|
89 | 92 | """
|
90 |
| - self.refine_net = load_refinenet_model(self.cuda) |
| 93 | + self.refine_net = load_refinenet_model(self.cuda, weight_path=weight_path) |
91 | 94 |
|
92 | 95 | def unload_craftnet_model(self):
|
93 | 96 | """
|
|
0 commit comments