Skip to content

Commit e38bacb

Browse files
committed
for export
1 parent e074f5f commit e38bacb

File tree

6 files changed

+72
-20
lines changed

6 files changed

+72
-20
lines changed

deploy/py_infer/src/core/model/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def warmup(self):
106106
height, width = hw_list[0]
107107
warmup_shape = [(*other_shape, height, width)] # Only single input
108108

109-
dummy_tensor = [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(warmup_shape, self.input_dtype)]
110-
self.model.infer(dummy_tensor)
109+
# dummy_tensor = [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(warmup_shape, self.input_dtype)]
110+
# self.model.infer(dummy_tensor)
111111

112112
def __del__(self):
113113
if hasattr(self, "model") and self.model:

mindocr/losses/det_loss.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import os
23
from math import pi
34
from typing import Tuple, Union
45

@@ -10,6 +11,8 @@
1011
__all__ = ["DBLoss", "PSEDiceLoss", "EASTLoss", "FCELoss"]
1112
_logger = logging.getLogger(__name__)
1213

14+
OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)
15+
1316

1417
class DBLoss(nn.LossBase):
1518
"""
@@ -165,7 +168,13 @@ def construct(self, pred: Tensor, gt: Tensor, mask: Tensor) -> Tensor:
165168
neg_loss = (loss * negative).view(loss.shape[0], -1)
166169

167170
neg_vals, _ = ops.sort(neg_loss)
168-
neg_index = ops.stack((mnp.arange(loss.shape[0]), neg_vals.shape[1] - neg_count), axis=1)
171+
172+
if OFFLINE_MODE is None:
173+
neg_index = ops.stack((mnp.arange(loss.shape[0]), neg_vals.shape[1] - neg_count), axis=1)
174+
else:
175+
neg_index = ops.stack(
176+
(ops.arange(loss.shape[0], dtype=neg_count.dtype), neg_vals.shape[1] - neg_count), axis=1
177+
)
169178
min_neg_score = ops.expand_dims(ops.gather_nd(neg_vals, neg_index), axis=1)
170179

171180
neg_loss_mask = (neg_loss >= min_neg_score).astype(ms.float32) # filter values less than top k

mindocr/losses/rec_loss.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import numpy as np
24

35
import mindspore as ms
@@ -6,6 +8,8 @@
68

79
__all__ = ["CTCLoss", "AttentionLoss", "VisionLANLoss"]
810

11+
OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)
12+
913

1014
class CTCLoss(LossBase):
1115
"""
@@ -147,14 +151,21 @@ class AttentionLoss(LossBase):
147151
def __init__(self, reduction: str = "mean", ignore_index: int = 0) -> None:
148152
super().__init__()
149153
# ignore <GO> symbol, assume it is placed at 0th index
150-
self.criterion = nn.CrossEntropyLoss(reduction=reduction, ignore_index=ignore_index)
154+
if OFFLINE_MODE is None:
155+
self.criterion = nn.CrossEntropyLoss(reduction=reduction, ignore_index=ignore_index)
156+
else:
157+
self.reduction = reduction
158+
self.ignore_index = ignore_index
151159

152160
def construct(self, logits: Tensor, labels: Tensor) -> Tensor:
153161
labels = labels[:, 1:] # without <GO> symbol
154162
num_classes = logits.shape[-1]
155163
logits = ops.reshape(logits, (-1, num_classes))
156164
labels = ops.reshape(labels, (-1,))
157-
return self.criterion(logits, labels)
165+
if OFFLINE_MODE is None:
166+
return self.criterion(logits, labels)
167+
else:
168+
return ops.cross_entropy(logits, labels, reduction=self.reduction, ignore_index=self.ignore_index)
158169

159170

160171
class SARLoss(LossBase):

mindocr/models/necks/fpn.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import List, Tuple
23

34
from mindspore import Tensor, nn, ops
@@ -7,14 +8,20 @@
78
from ..utils.attention_cells import SEModule
89
from .asf import AdaptiveScaleFusion
910

11+
OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)
1012

11-
def _resize_nn(x: Tensor, scale: int = 0, shape: Tuple[int] = None):
12-
if scale == 1 or shape == x.shape[2:]:
13-
return x
1413

15-
if scale:
16-
shape = (x.shape[2] * scale, x.shape[3] * scale)
17-
return ops.ResizeNearestNeighbor(shape)(x)
14+
if OFFLINE_MODE is None:
15+
def _resize_nn(x: Tensor, scale: int = 0, shape: Tuple[int] = None):
16+
if scale == 1 or shape == x.shape[2:]:
17+
return x
18+
19+
if scale:
20+
shape = (x.shape[2] * scale, x.shape[3] * scale)
21+
return ops.ResizeNearestNeighbor(shape)(x)
22+
else:
23+
def _resize_nn(x: Tensor, shape: Tensor):
24+
return ops.ResizeNearestNeighborV2()(x, shape)
1825

1926

2027
class FPN(nn.Cell):
@@ -64,11 +71,18 @@ def construct(self, features: List[Tensor]) -> Tensor:
6471
for i, uc_op in enumerate(self.unify_channels):
6572
features[i] = uc_op(features[i])
6673

67-
for i in range(2, -1, -1):
68-
features[i] += _resize_nn(features[i + 1], shape=features[i].shape[2:])
74+
if OFFLINE_MODE is None:
75+
for i in range(2, -1, -1):
76+
features[i] += _resize_nn(features[i + 1], shape=features[i].shape[2:])
77+
78+
for i, out in enumerate(self.out):
79+
features[i] = _resize_nn(out(features[i]), shape=features[0].shape[2:])
80+
else:
81+
for i in range(2, -1, -1):
82+
features[i] += _resize_nn(features[i + 1], shape=ops.dyn_shape(features[i])[2:])
6983

70-
for i, out in enumerate(self.out):
71-
features[i] = _resize_nn(out(features[i]), shape=features[0].shape[2:])
84+
for i, out in enumerate(self.out):
85+
features[i] = _resize_nn(out(features[i]), shape=ops.dyn_shape(features[0])[2:])
7286

7387
return self.fuse(features[::-1]) # matching the reverse order of the original work
7488

mindocr/models/transforms/tps_spatial_transformer.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
import os
23
from typing import Optional, Tuple
34

45
import numpy as np
@@ -8,6 +9,8 @@
89
import mindspore.ops as ops
910
from mindspore import Tensor
1011

12+
OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)
13+
1114

1215
def grid_sample(input: Tensor, grid: Tensor, canvas: Optional[Tensor] = None) -> Tensor:
1316
output = ops.grid_sample(input, grid)
@@ -111,15 +114,22 @@ def __init__(
111114
self.target_coordinate_repr = Tensor(target_coordinate_repr, dtype=ms.float32)
112115
self.target_control_points = Tensor(target_control_points, dtype=ms.float32)
113116

117+
if OFFLINE_MODE is not None:
118+
self.matmul = ops.BatchMatMul()
119+
114120
def construct(
115121
self, input: Tensor, source_control_points: Tensor
116122
) -> Tuple[Tensor, Tensor]:
117123
batch_size = ops.shape(source_control_points)[0]
118124

119125
padding_matrix = ops.tile(self.padding_matrix, (batch_size, 1, 1))
120126
Y = ops.concat([source_control_points, padding_matrix], axis=1)
121-
mapping_matrix = ops.matmul(self.inverse_kernel, Y)
122-
source_coordinate = ops.matmul(self.target_coordinate_repr, mapping_matrix)
127+
if OFFLINE_MODE is None:
128+
mapping_matrix = ops.matmul(self.inverse_kernel, Y)
129+
source_coordinate = ops.matmul(self.target_coordinate_repr, mapping_matrix)
130+
else:
131+
mapping_matrix = self.matmul(self.inverse_kernel[None, ...], Y)
132+
source_coordinate = self.matmul(self.target_coordinate_repr[None, ...], mapping_matrix)
123133
grid = ops.reshape(
124134
source_coordinate,
125135
(-1, self.target_height, self.target_width, 2),

mindocr/models/utils/attention_cells.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import Optional, Tuple
23

34
import numpy as np
@@ -9,6 +10,8 @@
910

1011
__all__ = ["MultiHeadAttention", "PositionwiseFeedForward", "PositionalEncoding", "SEModule"]
1112

13+
OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)
14+
1215

1316
class MultiHeadAttention(nn.Cell):
1417
def __init__(
@@ -108,9 +111,14 @@ def __init__(
108111
self.pe = Tensor(pe, dtype=ms.float32)
109112

110113
def construct(self, input_tensor: Tensor) -> Tensor:
111-
input_tensor = (
112-
input_tensor + self.pe[:, : input_tensor.shape[1]]
113-
) # pe 1 5000 512
114+
if OFFLINE_MODE is None:
115+
input_tensor = (
116+
input_tensor + self.pe[:, : input_tensor.shape[1]]
117+
) # pe 1 5000 512
118+
else:
119+
input_tensor = (
120+
input_tensor + self.pe[:, : ops.dyn_shape(input_tensor)[1]]
121+
) # pe 1 5000 512
114122
return self.dropout(input_tensor)
115123

116124

0 commit comments

Comments
 (0)