7
7
from tvm .script import tir as T
8
8
9
9
10
def _var(dtype, size=1):
    """Allocate a 1-D TIR buffer of `size` elements in "local" scope.

    Used as scratch storage (e.g. per-channel mean/stddev tables) inside
    TIR prim_funcs in this module.
    """
    return T.alloc_buffer((size,), dtype, scope="local")
12
12
13
13
14
14
# pylint: disable=invalid-name,missing-docstring,no-else-return,too-many-locals,useless-parent-delegation
15
15
class ImageProcessor (Module ):
16
16
def __init__ (self ):
17
17
super ().__init__ ()
18
18
19
- def resize (self , image : Tensor , params ):
19
+ # pylint: disable=dangerous-default-value
20
+ def apply_schedule (self , sch , block , bdx = 32 , tile = [32 , 32 ]):
21
+ loop_x , loop_y = sch .get_loops (block )[- 2 :]
22
+ xo , xi = sch .split (loop_x , factors = [tile [0 ], None ])
23
+ yo , yi = sch .split (loop_y , factors = [tile [1 ], None ])
24
+ sch .reorder (xo , yo , xi , yi )
25
+ t = sch .fuse (xo , yo )
26
+ ty , tx = sch .split (t , factors = [None , bdx ])
27
+ sch .bind (ty , "threadIdx.y" )
28
+ sch .bind (tx , "threadIdx.x" )
29
+
30
+ def resize (self , image : Tensor , params ): # image layout:NCHW
31
+ assert 4 == image .ndim , "image should be 4D data tensor"
32
+ assert 3 == image .shape [1 ], "image layout should be NCHW"
33
+
20
34
def get_output_image_size (image : Tensor ):
21
- if 4 == image .ndim :
22
- h = image .shape [1 ]
23
- w = image .shape [2 ]
24
- elif 3 == image .ndim :
25
- h = image .shape [0 ]
26
- w = image .shape [1 ]
27
- else :
28
- assert False , "not supported image shape"
35
+ h = image .shape [2 ]
36
+ w = image .shape [3 ]
29
37
30
38
if "height" in params and "width" in params :
31
39
return (params ["height" ], params ["width" ])
32
40
elif "shortest_edge" in params :
33
- short = tir .Select (w > h , w , h )
34
- long = tir .Select (w > h , h , w )
41
+ short = tir .Select (w < h , w , h )
42
+ long = tir .Select (w > h , w , h )
35
43
requested_new_short = params ["shortest_edge" ]
36
44
new_short , new_long = tir .generic .cast (
37
45
requested_new_short , "int64"
38
- ), tir .generic .cast (requested_new_short * tir .div (long , short ), "int64" )
46
+ ), tir .generic .cast (
47
+ requested_new_short
48
+ * tir .div (
49
+ tir .generic .cast (long , "float32" ), tir .generic .cast (short , "float32" )
50
+ ),
51
+ "int64" ,
52
+ )
39
53
ret_h = tir .Select (w <= h , new_long , new_short )
40
54
ret_w = tir .Select (w <= h , new_short , new_long )
41
55
return (ret_h , ret_w )
@@ -63,14 +77,15 @@ def get_output_image_size(image: Tensor):
63
77
assert False , "not supported resize parameter"
64
78
65
79
(new_h , new_w ) = get_output_image_size (image )
66
- if 3 == image .ndim :
67
- image = op .unsqueeze (image , 0 )
68
- out = op .interpolate (image , (new_h , new_w ), data_layout = "NHWC" , mode = "bicubic" )
80
+ out = op .interpolate (image , (new_h , new_w ), data_layout = "NCHW" , mode = "bicubic" )
69
81
return out
70
82
71
83
# pylint: disable=too-many-arguments,too-many-locals
72
84
def crop (self , image : Tensor , crop_size ):
73
- def create_crop_func (dtype ):
85
+ assert 4 == image .ndim , "image should be 4D data tensor"
86
+ assert 3 == image .shape [1 ], "image layout should be NCHW"
87
+
88
+ def create_crop_func (dtype ): # , top, bottom, left, right):
74
89
@T .prim_func
75
90
def crop_func (
76
91
image : T .handle ,
@@ -82,59 +97,70 @@ def crop_func(
82
97
):
83
98
T .func_attr ({"op_pattern" : 8 , "tir.noalias" : True , "tir.is_scheduled" : 1 })
84
99
n , c , h , w = T .int64 (), T .int64 (), T .int64 (), T .int64 ()
85
- image_buf = T .match_buffer (image , (n , h , w , c ), dtype = dtype )
86
- out_buf = T .match_buffer (out , (n , bottom - top , right - left , c ), dtype = dtype )
87
- with T .block ("root" ):
88
- for n_idx in T .thread_binding (n , thread = "blockIdx.x" ):
89
- for h_idx in range ((bottom - top )):
90
- for w_idx in range ((right - left )):
91
- for c_idx in range (c ):
92
- with T .block ("compute" ):
93
- T .writes (out_buf [n_idx , h_idx , w_idx , c_idx ])
94
- out_buf [n_idx , h_idx , w_idx , c_idx ] = image_buf [
95
- n_idx , h_idx + top , w_idx + left , c_idx
96
- ]
97
-
98
- return crop_func
99
-
100
- n , orig_height , orig_width , c = image .shape
101
- assert n == 1
100
+ image_buf = T .match_buffer (image , (n , c , h , w ), dtype = dtype )
101
+ out_buf = T .match_buffer (out , (n , c , bottom - top , right - left ), dtype = dtype )
102
+ out_h = bottom - top
103
+ out_w = right - left
104
+ for n_idx in T .thread_binding (n , thread = "blockIdx.x" ):
105
+ for c_idx in T .thread_binding (c , thread = "blockIdx.y" ):
106
+ for h_idx , w_idx in T .grid (out_h , out_w ):
107
+ with T .block ("crop" ):
108
+ if (h_idx + T .int64 (top )) < h and (w_idx + T .int64 (left )) < w :
109
+ T .writes (out_buf [n_idx , c_idx , h_idx , w_idx ])
110
+ T .reads (image_buf [n_idx , c_idx , h_idx + top , w_idx + left ])
111
+ out_buf [n_idx , c_idx , h_idx , w_idx ] = image_buf [
112
+ n_idx , c_idx , h_idx + top , w_idx + left
113
+ ]
114
+
115
+ sch = tir .Schedule (crop_func )
116
+ self .apply_schedule (sch , sch .get_block ("crop" ))
117
+ return sch .mod ["main" ].with_attr ("tir.is_scheduled" , 1 )
118
+
119
+ n , c , orig_height , orig_width = image .shape
102
120
crop_height = crop_size ["height" ]
103
121
crop_width = crop_size ["width" ]
104
122
105
123
top = (orig_height - crop_height ) // 2
106
- bottom = top + crop_height
124
+ bottom = orig_height - top
125
+
107
126
left = (orig_width - crop_width ) // 2
108
- right = left + crop_width
109
- new_height = bottom - top
110
- new_width = right - left
127
+ right = orig_width - left
128
+
111
129
out = op .tensor_ir_op (
112
130
create_crop_func (image .dtype ),
113
131
"crop" ,
114
132
[image , top , bottom , left , right ],
115
- [Tensor .placeholder ([n , new_height , new_width , c ], image .dtype )],
133
+ [Tensor .placeholder ([n , c , crop_height , crop_width ], image .dtype )],
116
134
)
117
135
return out
118
136
119
137
def rescale (self , image : Tensor , rescale_factor = 1 / 255.0 , o_dtype = "float32" ):
138
+ assert 4 == image .ndim , "image should be 4D data tensor"
139
+ assert 3 == image .shape [1 ], "image layout should be NCHW"
140
+
120
141
def create_rescale_func (rescale_factor , dtype , o_dtype ):
121
142
@T .prim_func
122
143
def rescale_func (image : T .handle , out : T .handle ):
123
144
T .func_attr ({"op_pattern" : 8 , "tir.noalias" : True , "tir.is_scheduled" : 1 })
124
145
n , c , h , w = T .int64 (), T .int64 (), T .int64 (), T .int64 ()
125
- image_buf = T .match_buffer (image , (n , h , w , c ), dtype = dtype )
126
- out_buf = T .match_buffer (out , (n , h , w , c ), dtype = o_dtype )
146
+ image_buf = T .match_buffer (image , (n , c , h , w ), dtype = dtype )
147
+ out_buf = T .match_buffer (out , (n , c , h , w ), dtype = o_dtype )
148
+
127
149
for n_idx in T .thread_binding (n , thread = "blockIdx.x" ):
128
- for h_idx , w_idx , c_idx in T .grid (h , w , c ):
129
- with T .block ("compute" ):
130
- T .reads (image_buf [n_idx , h_idx , w_idx , c_idx ])
131
- T .writes (out_buf [n_idx , h_idx , w_idx , c_idx ])
132
- out_buf [n_idx , h_idx , w_idx , c_idx ] = (
133
- T .cast (image_buf [n_idx , h_idx , w_idx , c_idx ], o_dtype )
134
- * rescale_factor
135
- )
150
+ for c_idx in T .thread_binding (c , thread = "blockIdx.y" ):
151
+ for h_idx , w_idx in T .grid (h , w ):
152
+ with T .block ("rescale" ):
153
+ T .reads (image_buf [n_idx , c_idx , h_idx , w_idx ])
154
+ T .writes (out_buf [n_idx , c_idx , h_idx , w_idx ])
155
+ if h_idx < h and w_idx < w :
156
+ out_buf [n_idx , c_idx , h_idx , w_idx ] = (
157
+ T .cast (image_buf [n_idx , c_idx , h_idx , w_idx ], o_dtype )
158
+ * rescale_factor
159
+ )
136
160
137
- return rescale_func
161
+ sch = tir .Schedule (rescale_func )
162
+ self .apply_schedule (sch , sch .get_block ("rescale" ))
163
+ return sch .mod ["main" ].with_attr ("tir.is_scheduled" , 1 )
138
164
139
165
out = op .tensor_ir_op (
140
166
create_rescale_func (rescale_factor , image .dtype , o_dtype ),
@@ -145,35 +171,44 @@ def rescale_func(image: T.handle, out: T.handle):
145
171
return out
146
172
147
173
def normalize (self , image : Tensor , o_dtype = "float32" ):
174
+ assert 4 == image .ndim , "image should be 4D data tensor"
175
+ assert 3 == image .shape [1 ], "image layout should be NCHW"
176
+
148
177
def create_normalize_func (dtype , o_dtype ):
149
178
@T .prim_func
150
179
def normalize_func (image : T .handle , out : T .handle ):
151
- T .func_attr ({"op_pattern" : 8 , "tir.noalias" : True , "tir.is_scheduled" : 1 })
152
180
n , c , h , w = T .int64 (), T .int64 (), T .int64 (), T .int64 ()
153
- image_buf = T .match_buffer (image , (n , h , w , c ), dtype = dtype )
154
- out_buf = T .match_buffer (out , (n , h , w , c ), dtype = o_dtype )
155
- mean = _var (o_dtype )
156
- stddev = _var (o_dtype )
181
+ image_buf = T .match_buffer (image , (n , c , h , w ), dtype = dtype )
182
+ out_buf = T .match_buffer (out , (n , c , h , w ), dtype = o_dtype )
183
+ mean = _var (o_dtype , 3 )
184
+ stddev = _var (o_dtype , 3 )
185
+
157
186
for n_idx in T .thread_binding (n , thread = "blockIdx.x" ):
158
- for h_idx , w_idx , c_idx in T .grid (h , w , c ):
159
- with T .block ("compute" ):
160
- T .reads (image_buf [n_idx , h_idx , w_idx , c_idx ])
161
- T .writes (out_buf [n_idx , h_idx , w_idx , c_idx ])
162
- if 0 == c_idx :
163
- mean [0 ] = 0.48145466
164
- stddev [0 ] = 0.26862954
165
- elif 1 == c_idx :
166
- mean [0 ] = 0.4578275
167
- stddev [0 ] = 0.26130258
168
- elif 2 == c_idx :
169
- mean [0 ] = 0.40821073
170
- stddev [0 ] = 0.27577711
171
-
172
- out_buf [n_idx , h_idx , w_idx , c_idx ] = (
173
- T .cast (image_buf [n_idx , h_idx , w_idx , c_idx ], o_dtype ) - mean [0 ]
174
- ) / stddev [0 ]
175
-
176
- return normalize_func
187
+ for c_idx in T .thread_binding (c , thread = "blockIdx.y" ):
188
+ for h_idx , w_idx in T .grid (h , w ):
189
+ with T .block ("normalize" ):
190
+ T .reads (
191
+ image_buf [n_idx , c_idx , h_idx , w_idx ],
192
+ mean [c_idx ],
193
+ stddev [c_idx ],
194
+ )
195
+ T .writes (out_buf [n_idx , c_idx , h_idx , w_idx ])
196
+ with T .init ():
197
+ mean [0 ] = 0.48145466
198
+ stddev [0 ] = 0.26862954
199
+ mean [1 ] = 0.4578275
200
+ stddev [1 ] = 0.26130258
201
+ mean [2 ] = 0.40821073
202
+ stddev [2 ] = 0.27577711
203
+ if h_idx < h and w_idx < w :
204
+ out_buf [n_idx , c_idx , h_idx , w_idx ] = (
205
+ T .cast (image_buf [n_idx , c_idx , h_idx , w_idx ], o_dtype )
206
+ - mean [c_idx ]
207
+ ) / stddev [c_idx ]
208
+
209
+ sch = tir .Schedule (normalize_func )
210
+ self .apply_schedule (sch , sch .get_block ("normalize" ))
211
+ return sch .mod ["main" ].with_attr ("tir.is_scheduled" , 1 )
177
212
178
213
out = op .tensor_ir_op (
179
214
create_normalize_func (image .dtype , o_dtype ),
@@ -184,40 +219,51 @@ def normalize_func(image: T.handle, out: T.handle):
184
219
return out
185
220
186
221
def pad (self , image : Tensor , dtype = "uint8" ):
222
+ assert 4 == image .ndim , "image should be 4D data tensor"
223
+ assert 3 == image .shape [1 ], "image layout should be NCHW"
224
+
187
225
def create_pad_func (l , r , fill = 255 ):
188
226
@T .prim_func
189
227
def pad_func (image : T .handle , out : T .handle , t : T .int64 (), b : T .int64 ()):
190
228
T .func_attr ({"op_pattern" : 8 , "tir.noalias" : True , "tir.is_scheduled" : 1 })
191
229
n , c , h , w = T .int64 (), T .int64 (), T .int64 (), T .int64 ()
192
- image_buf = T .match_buffer (image , (n , h , w , c ), dtype = dtype )
193
- out_buf = T .match_buffer (out , (n , h + t + b , w + l + r , c ), dtype = dtype )
230
+ image_buf = T .match_buffer (image , (n , c , h , w ), dtype = dtype )
231
+ out_buf = T .match_buffer (out , (n , c , h + t + b , w + l + r ), dtype = dtype )
232
+ out_h = h + t + b
233
+ out_w = w + l + r
194
234
195
235
for n_idx in T .thread_binding (n , thread = "blockIdx.x" ):
196
- for h_idx , w_idx , c_idx in T .grid (h + t + b , w + l + r , c ):
197
- with T .block ("compute" ):
198
- T .reads (image_buf [n_idx , h_idx , w_idx , c_idx ])
199
- T .writes (out_buf [n_idx , h_idx , w_idx , c_idx ])
200
- if h_idx < t or h_idx > h + b or w_idx < l or w_idx > w + r :
201
- out_buf [n_idx , h_idx , w_idx , c_idx ] = fill
202
- else :
203
- out_buf [n_idx , h_idx , w_idx , c_idx ] = image_buf [
204
- n_idx , h_idx - t , w_idx - l , c_idx
205
- ]
206
-
207
- return pad_func
208
-
209
- h = image .shape [1 ]
236
+ for c_idx in T .thread_binding (c , thread = "blockIdx.y" ):
237
+ for h_idx , w_idx in T .grid (out_h , out_w ):
238
+ with T .block ("pad" ):
239
+ T .reads (image_buf [n_idx , c_idx , h_idx , w_idx ])
240
+ T .writes (out_buf [n_idx , c_idx , h_idx , w_idx ])
241
+ if h_idx < t or h_idx > h + b or w_idx < l or w_idx > w + r :
242
+ out_buf [n_idx , c_idx , h_idx , w_idx ] = fill
243
+ else :
244
+ out_buf [n_idx , c_idx , h_idx , w_idx ] = image_buf [
245
+ n_idx , c_idx , h_idx - t , w_idx - l
246
+ ]
247
+
248
+ sch = tir .Schedule (pad_func )
249
+ self .apply_schedule (sch , sch .get_block ("pad" ))
250
+ return sch .mod ["main" ].with_attr ("tir.is_scheduled" , 1 )
251
+
252
+ h = image .shape [2 ]
210
253
tar = tir .truncdiv (h + 335 , 336 ) * 336
211
254
t = tir .div (tar - h , 2 )
212
255
b = tar - h - t
213
256
l = 0
214
257
r = 0
215
258
216
- n , h , w , c = image .shape
259
+ n , c , h , w = image .shape
217
260
out = op .tensor_ir_op (
218
261
create_pad_func (l , r ),
219
262
"pad" ,
220
263
[image , t , b ],
221
- [Tensor .placeholder ((n , tar , w , c ), image .dtype )],
264
+ [Tensor .placeholder ((n , c , tar , w ), image .dtype )],
222
265
)
223
266
return out
267
+
268
+ def preprocess (self , pixel_values ):
269
+ return pixel_values
0 commit comments