Skip to content

Commit 3e8d4dd

Browse files
authored
[Model] Optimize image preprocess for vision model (#2981)
This PR improves the performance of image preprocessing for vision model. Generally, image storage formats include NCHW and NHWC, with NCHW being more suitable for GPU-based computations. Additionally, performance is enhanced by binding the for loop to GPU hardware threads.
1 parent 0034e3c commit 3e8d4dd

File tree

3 files changed

+148
-101
lines changed

3 files changed

+148
-101
lines changed

python/mlc_llm/model/llava/llava_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,12 @@ def embed(self, input_ids: Tensor) -> Tensor:
155155
return self.language_model.embed(input_ids)
156156

157157
def image_preprocess(self, pixel_values: Tensor) -> Tensor:
158-
# pixel_values shape is NHWC
158+
pixel_values = permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW
159159
pixel_values = self.image_processor.resize(
160-
pixel_values, {"shortest_edge": self.config.vision_config.image_size}
160+
pixel_values,
161+
{
162+
"shortest_edge": self.config.vision_config.image_size,
163+
},
161164
)
162165
pixel_values = self.image_processor.crop(
163166
pixel_values,
@@ -168,7 +171,6 @@ def image_preprocess(self, pixel_values: Tensor) -> Tensor:
168171
)
169172
pixel_values = self.image_processor.rescale(pixel_values)
170173
pixel_values = self.image_processor.normalize(pixel_values)
171-
pixel_values = permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW
172174
return pixel_values
173175

174176
def image_embed(self, pixel_values: Tensor) -> Tensor:

python/mlc_llm/model/phi3v/phi3v_model.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def embed(self, input_ids: Tensor):
219219

220220
# pylint: disable=protected-access
221221
def image_preprocess(self, pixel_values: Tensor, num_crops=16) -> Tensor:
222+
pixel_values = op.permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW
222223
pixel_values = self.image_processor.resize(pixel_values, params={"hd_transform": 336})
223224
new_h = tir.Var("new_h", "int64")
224225
new_w = tir.Var("new_w", "int64")
@@ -228,7 +229,7 @@ def image_preprocess(self, pixel_values: Tensor, num_crops=16) -> Tensor:
228229
.match_cast(
229230
pixel_values._expr,
230231
relax.TensorStructInfo(
231-
[pixel_values.shape[0], new_h, new_w, pixel_values.shape[3]], pixel_values.dtype
232+
[pixel_values.shape[0], pixel_values.shape[1], new_h, new_w], pixel_values.dtype
232233
),
233234
),
234235
"pixel_values",
@@ -246,16 +247,14 @@ def image_preprocess(self, pixel_values: Tensor, num_crops=16) -> Tensor:
246247
.match_cast(
247248
global_image._expr,
248249
relax.TensorStructInfo(
249-
[global_image.shape[0], 336, 336, global_image.shape[3]], global_image.dtype
250+
[global_image.shape[0], global_image.shape[1], 336, 336], global_image.dtype
250251
),
251252
),
252253
"global_image",
253254
)
254255

255-
global_image = op.permute_dims(global_image, axes=(0, 3, 1, 2))
256-
n, h, w, c = pixel_values.shape # pylint: disable=unused-variable
256+
n, c, h, w = pixel_values.shape # pylint: disable=unused-variable
257257
assert isinstance(h, tir.Mul) and isinstance(h.b, tir.IntImm) and h.b.value == 336
258-
pixel_values = op.permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW
259258
pixel_values = op.reshape(pixel_values, shape=(1, 3, h.a, 336, w // 336, 336))
260259
pixel_values = op.permute_dims(pixel_values, axes=(0, 2, 4, 1, 3, 5))
261260
pixel_values = op.reshape(pixel_values, shape=(-1, 3, 336, 336))

python/mlc_llm/model/vision/image_processing.py

Lines changed: 139 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,49 @@
77
from tvm.script import tir as T
88

99

10-
def _var(dtype):
11-
return T.alloc_buffer((1,), dtype, scope="local")
10+
def _var(dtype, size=1):
11+
return T.alloc_buffer((size,), dtype, scope="local")
1212

1313

1414
# pylint: disable=invalid-name,missing-docstring,no-else-return,too-many-locals,useless-parent-delegation
1515
class ImageProcessor(Module):
1616
def __init__(self):
1717
super().__init__()
1818

19-
def resize(self, image: Tensor, params):
19+
# pylint: disable=dangerous-default-value
20+
def apply_schedule(self, sch, block, bdx=32, tile=[32, 32]):
21+
loop_x, loop_y = sch.get_loops(block)[-2:]
22+
xo, xi = sch.split(loop_x, factors=[tile[0], None])
23+
yo, yi = sch.split(loop_y, factors=[tile[1], None])
24+
sch.reorder(xo, yo, xi, yi)
25+
t = sch.fuse(xo, yo)
26+
ty, tx = sch.split(t, factors=[None, bdx])
27+
sch.bind(ty, "threadIdx.y")
28+
sch.bind(tx, "threadIdx.x")
29+
30+
def resize(self, image: Tensor, params): # image layout:NCHW
31+
assert 4 == image.ndim, "image should be 4D data tensor"
32+
assert 3 == image.shape[1], "image layout should be NCHW"
33+
2034
def get_output_image_size(image: Tensor):
21-
if 4 == image.ndim:
22-
h = image.shape[1]
23-
w = image.shape[2]
24-
elif 3 == image.ndim:
25-
h = image.shape[0]
26-
w = image.shape[1]
27-
else:
28-
assert False, "not supported image shape"
35+
h = image.shape[2]
36+
w = image.shape[3]
2937

3038
if "height" in params and "width" in params:
3139
return (params["height"], params["width"])
3240
elif "shortest_edge" in params:
33-
short = tir.Select(w > h, w, h)
34-
long = tir.Select(w > h, h, w)
41+
short = tir.Select(w < h, w, h)
42+
long = tir.Select(w > h, w, h)
3543
requested_new_short = params["shortest_edge"]
3644
new_short, new_long = tir.generic.cast(
3745
requested_new_short, "int64"
38-
), tir.generic.cast(requested_new_short * tir.div(long, short), "int64")
46+
), tir.generic.cast(
47+
requested_new_short
48+
* tir.div(
49+
tir.generic.cast(long, "float32"), tir.generic.cast(short, "float32")
50+
),
51+
"int64",
52+
)
3953
ret_h = tir.Select(w <= h, new_long, new_short)
4054
ret_w = tir.Select(w <= h, new_short, new_long)
4155
return (ret_h, ret_w)
@@ -63,14 +77,15 @@ def get_output_image_size(image: Tensor):
6377
assert False, "not supported resize parameter"
6478

6579
(new_h, new_w) = get_output_image_size(image)
66-
if 3 == image.ndim:
67-
image = op.unsqueeze(image, 0)
68-
out = op.interpolate(image, (new_h, new_w), data_layout="NHWC", mode="bicubic")
80+
out = op.interpolate(image, (new_h, new_w), data_layout="NCHW", mode="bicubic")
6981
return out
7082

7183
# pylint: disable=too-many-arguments,too-many-locals
7284
def crop(self, image: Tensor, crop_size):
73-
def create_crop_func(dtype):
85+
assert 4 == image.ndim, "image should be 4D data tensor"
86+
assert 3 == image.shape[1], "image layout should be NCHW"
87+
88+
def create_crop_func(dtype): # , top, bottom, left, right):
7489
@T.prim_func
7590
def crop_func(
7691
image: T.handle,
@@ -82,59 +97,70 @@ def crop_func(
8297
):
8398
T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1})
8499
n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()
85-
image_buf = T.match_buffer(image, (n, h, w, c), dtype=dtype)
86-
out_buf = T.match_buffer(out, (n, bottom - top, right - left, c), dtype=dtype)
87-
with T.block("root"):
88-
for n_idx in T.thread_binding(n, thread="blockIdx.x"):
89-
for h_idx in range((bottom - top)):
90-
for w_idx in range((right - left)):
91-
for c_idx in range(c):
92-
with T.block("compute"):
93-
T.writes(out_buf[n_idx, h_idx, w_idx, c_idx])
94-
out_buf[n_idx, h_idx, w_idx, c_idx] = image_buf[
95-
n_idx, h_idx + top, w_idx + left, c_idx
96-
]
97-
98-
return crop_func
99-
100-
n, orig_height, orig_width, c = image.shape
101-
assert n == 1
100+
image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)
101+
out_buf = T.match_buffer(out, (n, c, bottom - top, right - left), dtype=dtype)
102+
out_h = bottom - top
103+
out_w = right - left
104+
for n_idx in T.thread_binding(n, thread="blockIdx.x"):
105+
for c_idx in T.thread_binding(c, thread="blockIdx.y"):
106+
for h_idx, w_idx in T.grid(out_h, out_w):
107+
with T.block("crop"):
108+
if (h_idx + T.int64(top)) < h and (w_idx + T.int64(left)) < w:
109+
T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])
110+
T.reads(image_buf[n_idx, c_idx, h_idx + top, w_idx + left])
111+
out_buf[n_idx, c_idx, h_idx, w_idx] = image_buf[
112+
n_idx, c_idx, h_idx + top, w_idx + left
113+
]
114+
115+
sch = tir.Schedule(crop_func)
116+
self.apply_schedule(sch, sch.get_block("crop"))
117+
return sch.mod["main"].with_attr("tir.is_scheduled", 1)
118+
119+
n, c, orig_height, orig_width = image.shape
102120
crop_height = crop_size["height"]
103121
crop_width = crop_size["width"]
104122

105123
top = (orig_height - crop_height) // 2
106-
bottom = top + crop_height
124+
bottom = orig_height - top
125+
107126
left = (orig_width - crop_width) // 2
108-
right = left + crop_width
109-
new_height = bottom - top
110-
new_width = right - left
127+
right = orig_width - left
128+
111129
out = op.tensor_ir_op(
112130
create_crop_func(image.dtype),
113131
"crop",
114132
[image, top, bottom, left, right],
115-
[Tensor.placeholder([n, new_height, new_width, c], image.dtype)],
133+
[Tensor.placeholder([n, c, crop_height, crop_width], image.dtype)],
116134
)
117135
return out
118136

119137
def rescale(self, image: Tensor, rescale_factor=1 / 255.0, o_dtype="float32"):
138+
assert 4 == image.ndim, "image should be 4D data tensor"
139+
assert 3 == image.shape[1], "image layout should be NCHW"
140+
120141
def create_rescale_func(rescale_factor, dtype, o_dtype):
121142
@T.prim_func
122143
def rescale_func(image: T.handle, out: T.handle):
123144
T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1})
124145
n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()
125-
image_buf = T.match_buffer(image, (n, h, w, c), dtype=dtype)
126-
out_buf = T.match_buffer(out, (n, h, w, c), dtype=o_dtype)
146+
image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)
147+
out_buf = T.match_buffer(out, (n, c, h, w), dtype=o_dtype)
148+
127149
for n_idx in T.thread_binding(n, thread="blockIdx.x"):
128-
for h_idx, w_idx, c_idx in T.grid(h, w, c):
129-
with T.block("compute"):
130-
T.reads(image_buf[n_idx, h_idx, w_idx, c_idx])
131-
T.writes(out_buf[n_idx, h_idx, w_idx, c_idx])
132-
out_buf[n_idx, h_idx, w_idx, c_idx] = (
133-
T.cast(image_buf[n_idx, h_idx, w_idx, c_idx], o_dtype)
134-
* rescale_factor
135-
)
150+
for c_idx in T.thread_binding(c, thread="blockIdx.y"):
151+
for h_idx, w_idx in T.grid(h, w):
152+
with T.block("rescale"):
153+
T.reads(image_buf[n_idx, c_idx, h_idx, w_idx])
154+
T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])
155+
if h_idx < h and w_idx < w:
156+
out_buf[n_idx, c_idx, h_idx, w_idx] = (
157+
T.cast(image_buf[n_idx, c_idx, h_idx, w_idx], o_dtype)
158+
* rescale_factor
159+
)
136160

137-
return rescale_func
161+
sch = tir.Schedule(rescale_func)
162+
self.apply_schedule(sch, sch.get_block("rescale"))
163+
return sch.mod["main"].with_attr("tir.is_scheduled", 1)
138164

139165
out = op.tensor_ir_op(
140166
create_rescale_func(rescale_factor, image.dtype, o_dtype),
@@ -145,35 +171,44 @@ def rescale_func(image: T.handle, out: T.handle):
145171
return out
146172

147173
def normalize(self, image: Tensor, o_dtype="float32"):
174+
assert 4 == image.ndim, "image should be 4D data tensor"
175+
assert 3 == image.shape[1], "image layout should be NCHW"
176+
148177
def create_normalize_func(dtype, o_dtype):
149178
@T.prim_func
150179
def normalize_func(image: T.handle, out: T.handle):
151-
T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1})
152180
n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()
153-
image_buf = T.match_buffer(image, (n, h, w, c), dtype=dtype)
154-
out_buf = T.match_buffer(out, (n, h, w, c), dtype=o_dtype)
155-
mean = _var(o_dtype)
156-
stddev = _var(o_dtype)
181+
image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)
182+
out_buf = T.match_buffer(out, (n, c, h, w), dtype=o_dtype)
183+
mean = _var(o_dtype, 3)
184+
stddev = _var(o_dtype, 3)
185+
157186
for n_idx in T.thread_binding(n, thread="blockIdx.x"):
158-
for h_idx, w_idx, c_idx in T.grid(h, w, c):
159-
with T.block("compute"):
160-
T.reads(image_buf[n_idx, h_idx, w_idx, c_idx])
161-
T.writes(out_buf[n_idx, h_idx, w_idx, c_idx])
162-
if 0 == c_idx:
163-
mean[0] = 0.48145466
164-
stddev[0] = 0.26862954
165-
elif 1 == c_idx:
166-
mean[0] = 0.4578275
167-
stddev[0] = 0.26130258
168-
elif 2 == c_idx:
169-
mean[0] = 0.40821073
170-
stddev[0] = 0.27577711
171-
172-
out_buf[n_idx, h_idx, w_idx, c_idx] = (
173-
T.cast(image_buf[n_idx, h_idx, w_idx, c_idx], o_dtype) - mean[0]
174-
) / stddev[0]
175-
176-
return normalize_func
187+
for c_idx in T.thread_binding(c, thread="blockIdx.y"):
188+
for h_idx, w_idx in T.grid(h, w):
189+
with T.block("normalize"):
190+
T.reads(
191+
image_buf[n_idx, c_idx, h_idx, w_idx],
192+
mean[c_idx],
193+
stddev[c_idx],
194+
)
195+
T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])
196+
with T.init():
197+
mean[0] = 0.48145466
198+
stddev[0] = 0.26862954
199+
mean[1] = 0.4578275
200+
stddev[1] = 0.26130258
201+
mean[2] = 0.40821073
202+
stddev[2] = 0.27577711
203+
if h_idx < h and w_idx < w:
204+
out_buf[n_idx, c_idx, h_idx, w_idx] = (
205+
T.cast(image_buf[n_idx, c_idx, h_idx, w_idx], o_dtype)
206+
- mean[c_idx]
207+
) / stddev[c_idx]
208+
209+
sch = tir.Schedule(normalize_func)
210+
self.apply_schedule(sch, sch.get_block("normalize"))
211+
return sch.mod["main"].with_attr("tir.is_scheduled", 1)
177212

178213
out = op.tensor_ir_op(
179214
create_normalize_func(image.dtype, o_dtype),
@@ -184,40 +219,51 @@ def normalize_func(image: T.handle, out: T.handle):
184219
return out
185220

186221
def pad(self, image: Tensor, dtype="uint8"):
222+
assert 4 == image.ndim, "image should be 4D data tensor"
223+
assert 3 == image.shape[1], "image layout should be NCHW"
224+
187225
def create_pad_func(l, r, fill=255):
188226
@T.prim_func
189227
def pad_func(image: T.handle, out: T.handle, t: T.int64(), b: T.int64()):
190228
T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1})
191229
n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()
192-
image_buf = T.match_buffer(image, (n, h, w, c), dtype=dtype)
193-
out_buf = T.match_buffer(out, (n, h + t + b, w + l + r, c), dtype=dtype)
230+
image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)
231+
out_buf = T.match_buffer(out, (n, c, h + t + b, w + l + r), dtype=dtype)
232+
out_h = h + t + b
233+
out_w = w + l + r
194234

195235
for n_idx in T.thread_binding(n, thread="blockIdx.x"):
196-
for h_idx, w_idx, c_idx in T.grid(h + t + b, w + l + r, c):
197-
with T.block("compute"):
198-
T.reads(image_buf[n_idx, h_idx, w_idx, c_idx])
199-
T.writes(out_buf[n_idx, h_idx, w_idx, c_idx])
200-
if h_idx < t or h_idx > h + b or w_idx < l or w_idx > w + r:
201-
out_buf[n_idx, h_idx, w_idx, c_idx] = fill
202-
else:
203-
out_buf[n_idx, h_idx, w_idx, c_idx] = image_buf[
204-
n_idx, h_idx - t, w_idx - l, c_idx
205-
]
206-
207-
return pad_func
208-
209-
h = image.shape[1]
236+
for c_idx in T.thread_binding(c, thread="blockIdx.y"):
237+
for h_idx, w_idx in T.grid(out_h, out_w):
238+
with T.block("pad"):
239+
T.reads(image_buf[n_idx, c_idx, h_idx, w_idx])
240+
T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])
241+
if h_idx < t or h_idx > h + b or w_idx < l or w_idx > w + r:
242+
out_buf[n_idx, c_idx, h_idx, w_idx] = fill
243+
else:
244+
out_buf[n_idx, c_idx, h_idx, w_idx] = image_buf[
245+
n_idx, c_idx, h_idx - t, w_idx - l
246+
]
247+
248+
sch = tir.Schedule(pad_func)
249+
self.apply_schedule(sch, sch.get_block("pad"))
250+
return sch.mod["main"].with_attr("tir.is_scheduled", 1)
251+
252+
h = image.shape[2]
210253
tar = tir.truncdiv(h + 335, 336) * 336
211254
t = tir.div(tar - h, 2)
212255
b = tar - h - t
213256
l = 0
214257
r = 0
215258

216-
n, h, w, c = image.shape
259+
n, c, h, w = image.shape
217260
out = op.tensor_ir_op(
218261
create_pad_func(l, r),
219262
"pad",
220263
[image, t, b],
221-
[Tensor.placeholder((n, tar, w, c), image.dtype)],
264+
[Tensor.placeholder((n, c, tar, w), image.dtype)],
222265
)
223266
return out
267+
268+
def preprocess(self, pixel_values):
269+
return pixel_values

0 commit comments

Comments
 (0)