[Model] Optimize image preprocess for vision model (#2981)
This PR improves the performance of image preprocessing for vision models. Images are generally stored in either NHWC or NCHW layout, and NCHW is the better fit for GPU computation, so the pipeline now converts to NCHW once up front. Performance is further improved by binding the preprocessing loops to GPU hardware threads.
mengshyu authored Oct 16, 2024
1 parent 0034e3c commit 3e8d4dd
Showing 3 changed files with 148 additions and 101 deletions.
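For background: in NHWC the three channel values of each pixel sit next to each other in memory, while in NCHW each channel is a contiguous HxW plane, which maps naturally onto GPU kernels that parallelize over (batch, channel) blocks. The models below therefore permute once at the start of preprocessing and keep every later op in NCHW. A minimal NumPy sketch of that one-time layout change (illustration only, not code from this commit):

    import numpy as np

    # Illustration only: one 224x224 RGB image with a batch dim, in NHWC.
    image_nhwc = np.zeros((1, 224, 224, 3), dtype="uint8")

    # Permute once to NCHW; all later kernels index (n, c, h, w).
    image_nchw = image_nhwc.transpose(0, 3, 1, 2)
    assert image_nchw.shape == (1, 3, 224, 224)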
8 changes: 5 additions & 3 deletions python/mlc_llm/model/llava/llava_model.py
@@ -155,9 +155,12 @@ def embed(self, input_ids: Tensor) -> Tensor:
return self.language_model.embed(input_ids)

def image_preprocess(self, pixel_values: Tensor) -> Tensor:
# pixel_values shape is NHWC
pixel_values = permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW
pixel_values = self.image_processor.resize(
pixel_values, {"shortest_edge": self.config.vision_config.image_size}
pixel_values,
{
"shortest_edge": self.config.vision_config.image_size,
},
)
pixel_values = self.image_processor.crop(
pixel_values,
@@ -168,7 +171,6 @@ def image_preprocess(self, pixel_values: Tensor) -> Tensor:
)
pixel_values = self.image_processor.rescale(pixel_values)
pixel_values = self.image_processor.normalize(pixel_values)
pixel_values = permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW
return pixel_values

def image_embed(self, pixel_values: Tensor) -> Tensor:
9 changes: 4 additions & 5 deletions python/mlc_llm/model/phi3v/phi3v_model.py
@@ -219,6 +219,7 @@ def embed(self, input_ids: Tensor):

# pylint: disable=protected-access
def image_preprocess(self, pixel_values: Tensor, num_crops=16) -> Tensor:
pixel_values = op.permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW
pixel_values = self.image_processor.resize(pixel_values, params={"hd_transform": 336})
new_h = tir.Var("new_h", "int64")
new_w = tir.Var("new_w", "int64")
@@ -228,7 +229,7 @@ def image_preprocess(self, pixel_values: Tensor, num_crops=16) -> Tensor:
.match_cast(
pixel_values._expr,
relax.TensorStructInfo(
[pixel_values.shape[0], new_h, new_w, pixel_values.shape[3]], pixel_values.dtype
[pixel_values.shape[0], pixel_values.shape[1], new_h, new_w], pixel_values.dtype
),
),
"pixel_values",
@@ -246,16 +247,14 @@ def image_preprocess(self, pixel_values: Tensor, num_crops=16) -> Tensor:
.match_cast(
global_image._expr,
relax.TensorStructInfo(
[global_image.shape[0], 336, 336, global_image.shape[3]], global_image.dtype
[global_image.shape[0], global_image.shape[1], 336, 336], global_image.dtype
),
),
"global_image",
)

global_image = op.permute_dims(global_image, axes=(0, 3, 1, 2))
n, h, w, c = pixel_values.shape # pylint: disable=unused-variable
n, c, h, w = pixel_values.shape # pylint: disable=unused-variable
assert isinstance(h, tir.Mul) and isinstance(h.b, tir.IntImm) and h.b.value == 336
pixel_values = op.permute_dims(pixel_values, axes=(0, 3, 1, 2)) # NHWC -> NCHW
pixel_values = op.reshape(pixel_values, shape=(1, 3, h.a, 336, w // 336, 336))
pixel_values = op.permute_dims(pixel_values, axes=(0, 2, 4, 1, 3, 5))
pixel_values = op.reshape(pixel_values, shape=(-1, 3, 336, 336))
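The reshape/permute/reshape sequence above splits the HD-transformed image into a grid of 336x336 crops, now operating directly on NCHW data. A NumPy sketch with hypothetical sizes (a 672x1008 image, i.e. a 2x3 grid of crops; not code from this commit):

    import numpy as np

    x = np.zeros((1, 3, 672, 1008), dtype="float32")  # hypothetical NCHW input
    n, c, h, w = x.shape
    x = x.reshape(1, 3, h // 336, 336, w // 336, 336)
    x = x.transpose(0, 2, 4, 1, 3, 5)   # bring grid positions out front
    crops = x.reshape(-1, 3, 336, 336)  # one row per 336x336 crop
    assert crops.shape == (6, 3, 336, 336)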
232 changes: 139 additions & 93 deletions python/mlc_llm/model/vision/image_processing.py
@@ -7,35 +7,49 @@
from tvm.script import tir as T


def _var(dtype):
return T.alloc_buffer((1,), dtype, scope="local")
def _var(dtype, size=1):
return T.alloc_buffer((size,), dtype, scope="local")


# pylint: disable=invalid-name,missing-docstring,no-else-return,too-many-locals,useless-parent-delegation
class ImageProcessor(Module):
def __init__(self):
super().__init__()

def resize(self, image: Tensor, params):
# pylint: disable=dangerous-default-value
def apply_schedule(self, sch, block, bdx=32, tile=[32, 32]):
loop_x, loop_y = sch.get_loops(block)[-2:]
xo, xi = sch.split(loop_x, factors=[tile[0], None])
yo, yi = sch.split(loop_y, factors=[tile[1], None])
sch.reorder(xo, yo, xi, yi)
t = sch.fuse(xo, yo)
ty, tx = sch.split(t, factors=[None, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
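This helper is the "binding the for loop to GPU hardware threads" part of the change: it tiles the last two loops of a block, fuses the outer tile loops, and binds the result to a 32x32 thread block (the n and c loops in the kernels below are bound to blockIdx.x/blockIdx.y separately). A plain-Python emulation of the resulting index mapping under the defaults bdx=32, tile=[32, 32] (illustration only, not TVM code):

    import math

    # Emulates what one (n, c) thread block computes after apply_schedule.
    def simulate(h, w, body, bdx=32, tile=(32, 32)):
        xi_extent = math.ceil(h / tile[0])  # inner extent after splitting h
        yi_extent = math.ceil(w / tile[1])  # inner extent after splitting w
        for ty in range(tile[0] * tile[1] // bdx):  # threadIdx.y
            for tx in range(bdx):                   # threadIdx.x
                xo, yo = divmod(ty * bdx + tx, tile[1])  # un-fuse tile index
                for xi in range(xi_extent):
                    for yi in range(yi_extent):
                        h_idx = xo * xi_extent + xi
                        w_idx = yo * yi_extent + yi
                        # The padded iteration space is why the kernels
                        # below carry `if h_idx < h and w_idx < w` guards.
                        if h_idx < h and w_idx < w:
                            body(h_idx, w_idx)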

def resize(self, image: Tensor, params): # image layout:NCHW
assert 4 == image.ndim, "image should be 4D data tensor"
assert 3 == image.shape[1], "image layout should be NCHW"

def get_output_image_size(image: Tensor):
if 4 == image.ndim:
h = image.shape[1]
w = image.shape[2]
elif 3 == image.ndim:
h = image.shape[0]
w = image.shape[1]
else:
assert False, "not supported image shape"
h = image.shape[2]
w = image.shape[3]

if "height" in params and "width" in params:
return (params["height"], params["width"])
elif "shortest_edge" in params:
short = tir.Select(w > h, w, h)
long = tir.Select(w > h, h, w)
short = tir.Select(w < h, w, h)
long = tir.Select(w > h, w, h)
requested_new_short = params["shortest_edge"]
new_short, new_long = tir.generic.cast(
requested_new_short, "int64"
), tir.generic.cast(requested_new_short * tir.div(long, short), "int64")
), tir.generic.cast(
requested_new_short
* tir.div(
tir.generic.cast(long, "float32"), tir.generic.cast(short, "float32")
),
"int64",
)
ret_h = tir.Select(w <= h, new_long, new_short)
ret_w = tir.Select(w <= h, new_short, new_long)
return (ret_h, ret_w)
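A worked example of the shortest_edge arithmetic (my own numbers): resizing a 480x640 (h x w) image with shortest_edge=336 gives short=480, long=640, new_short=336, and new_long = int(336 * 640/480) = 448, so the output size is (336, 448). Two fixes land here: the original short/long selects were inverted, and the new float32 casts matter because tir.div on int64 extents would truncate 640/480 to 1. A plain-Python restatement (hypothetical helper, not in the codebase):

    def shortest_edge_size(h, w, shortest_edge):
        # Hypothetical reference for the TIR expressions above.
        short, long = (w, h) if w < h else (h, w)
        new_short = shortest_edge
        new_long = int(shortest_edge * (long / short))
        return (new_long, new_short) if w <= h else (new_short, new_long)

    assert shortest_edge_size(480, 640, 336) == (336, 448)
    assert shortest_edge_size(640, 480, 336) == (448, 336)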
@@ -63,14 +77,15 @@ def get_output_image_size(image: Tensor):
assert False, "not supported resize parameter"

(new_h, new_w) = get_output_image_size(image)
if 3 == image.ndim:
image = op.unsqueeze(image, 0)
out = op.interpolate(image, (new_h, new_w), data_layout="NHWC", mode="bicubic")
out = op.interpolate(image, (new_h, new_w), data_layout="NCHW", mode="bicubic")
return out

# pylint: disable=too-many-arguments,too-many-locals
def crop(self, image: Tensor, crop_size):
def create_crop_func(dtype):
assert 4 == image.ndim, "image should be 4D data tensor"
assert 3 == image.shape[1], "image layout should be NCHW"

def create_crop_func(dtype): # , top, bottom, left, right):
@T.prim_func
def crop_func(
image: T.handle,
@@ -82,59 +97,70 @@ def crop_func(
):
T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1})
n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()
image_buf = T.match_buffer(image, (n, h, w, c), dtype=dtype)
out_buf = T.match_buffer(out, (n, bottom - top, right - left, c), dtype=dtype)
with T.block("root"):
for n_idx in T.thread_binding(n, thread="blockIdx.x"):
for h_idx in range((bottom - top)):
for w_idx in range((right - left)):
for c_idx in range(c):
with T.block("compute"):
T.writes(out_buf[n_idx, h_idx, w_idx, c_idx])
out_buf[n_idx, h_idx, w_idx, c_idx] = image_buf[
n_idx, h_idx + top, w_idx + left, c_idx
]

return crop_func

n, orig_height, orig_width, c = image.shape
assert n == 1
image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)
out_buf = T.match_buffer(out, (n, c, bottom - top, right - left), dtype=dtype)
out_h = bottom - top
out_w = right - left
for n_idx in T.thread_binding(n, thread="blockIdx.x"):
for c_idx in T.thread_binding(c, thread="blockIdx.y"):
for h_idx, w_idx in T.grid(out_h, out_w):
with T.block("crop"):
if (h_idx + T.int64(top)) < h and (w_idx + T.int64(left)) < w:
T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])
T.reads(image_buf[n_idx, c_idx, h_idx + top, w_idx + left])
out_buf[n_idx, c_idx, h_idx, w_idx] = image_buf[
n_idx, c_idx, h_idx + top, w_idx + left
]

sch = tir.Schedule(crop_func)
self.apply_schedule(sch, sch.get_block("crop"))
return sch.mod["main"].with_attr("tir.is_scheduled", 1)

n, c, orig_height, orig_width = image.shape
crop_height = crop_size["height"]
crop_width = crop_size["width"]

top = (orig_height - crop_height) // 2
bottom = top + crop_height
bottom = orig_height - top

left = (orig_width - crop_width) // 2
right = left + crop_width
new_height = bottom - top
new_width = right - left
right = orig_width - left

out = op.tensor_ir_op(
create_crop_func(image.dtype),
"crop",
[image, top, bottom, left, right],
[Tensor.placeholder([n, new_height, new_width, c], image.dtype)],
[Tensor.placeholder([n, c, crop_height, crop_width], image.dtype)],
)
return out
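A worked example of the center-crop bounds (hypothetical sizes, not from the diff): cropping a 1x3x336x448 image to 336x336 gives

    orig_height, orig_width = 336, 448      # hypothetical input
    crop_height = crop_width = 336
    top = (orig_height - crop_height) // 2  # 0
    bottom = orig_height - top              # 336
    left = (orig_width - crop_width) // 2   # 56
    right = orig_width - left               # 392
    assert (bottom - top, right - left) == (336, 336)

so the kernel copies rows 0..335 and columns 56..391 of each channel plane into the output.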

def rescale(self, image: Tensor, rescale_factor=1 / 255.0, o_dtype="float32"):
assert 4 == image.ndim, "image should be 4D data tensor"
assert 3 == image.shape[1], "image layout should be NCHW"

def create_rescale_func(rescale_factor, dtype, o_dtype):
@T.prim_func
def rescale_func(image: T.handle, out: T.handle):
T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1})
n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()
image_buf = T.match_buffer(image, (n, h, w, c), dtype=dtype)
out_buf = T.match_buffer(out, (n, h, w, c), dtype=o_dtype)
image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)
out_buf = T.match_buffer(out, (n, c, h, w), dtype=o_dtype)

for n_idx in T.thread_binding(n, thread="blockIdx.x"):
for h_idx, w_idx, c_idx in T.grid(h, w, c):
with T.block("compute"):
T.reads(image_buf[n_idx, h_idx, w_idx, c_idx])
T.writes(out_buf[n_idx, h_idx, w_idx, c_idx])
out_buf[n_idx, h_idx, w_idx, c_idx] = (
T.cast(image_buf[n_idx, h_idx, w_idx, c_idx], o_dtype)
* rescale_factor
)
for c_idx in T.thread_binding(c, thread="blockIdx.y"):
for h_idx, w_idx in T.grid(h, w):
with T.block("rescale"):
T.reads(image_buf[n_idx, c_idx, h_idx, w_idx])
T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])
if h_idx < h and w_idx < w:
out_buf[n_idx, c_idx, h_idx, w_idx] = (
T.cast(image_buf[n_idx, c_idx, h_idx, w_idx], o_dtype)
* rescale_factor
)

return rescale_func
sch = tir.Schedule(rescale_func)
self.apply_schedule(sch, sch.get_block("rescale"))
return sch.mod["main"].with_attr("tir.is_scheduled", 1)

out = op.tensor_ir_op(
create_rescale_func(rescale_factor, image.dtype, o_dtype),
@@ -145,35 +171,44 @@ def rescale_func(image: T.handle, out: T.handle):
return out

def normalize(self, image: Tensor, o_dtype="float32"):
assert 4 == image.ndim, "image should be 4D data tensor"
assert 3 == image.shape[1], "image layout should be NCHW"

def create_normalize_func(dtype, o_dtype):
@T.prim_func
def normalize_func(image: T.handle, out: T.handle):
T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1})
n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()
image_buf = T.match_buffer(image, (n, h, w, c), dtype=dtype)
out_buf = T.match_buffer(out, (n, h, w, c), dtype=o_dtype)
mean = _var(o_dtype)
stddev = _var(o_dtype)
image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)
out_buf = T.match_buffer(out, (n, c, h, w), dtype=o_dtype)
mean = _var(o_dtype, 3)
stddev = _var(o_dtype, 3)

for n_idx in T.thread_binding(n, thread="blockIdx.x"):
for h_idx, w_idx, c_idx in T.grid(h, w, c):
with T.block("compute"):
T.reads(image_buf[n_idx, h_idx, w_idx, c_idx])
T.writes(out_buf[n_idx, h_idx, w_idx, c_idx])
if 0 == c_idx:
mean[0] = 0.48145466
stddev[0] = 0.26862954
elif 1 == c_idx:
mean[0] = 0.4578275
stddev[0] = 0.26130258
elif 2 == c_idx:
mean[0] = 0.40821073
stddev[0] = 0.27577711

out_buf[n_idx, h_idx, w_idx, c_idx] = (
T.cast(image_buf[n_idx, h_idx, w_idx, c_idx], o_dtype) - mean[0]
) / stddev[0]

return normalize_func
for c_idx in T.thread_binding(c, thread="blockIdx.y"):
for h_idx, w_idx in T.grid(h, w):
with T.block("normalize"):
T.reads(
image_buf[n_idx, c_idx, h_idx, w_idx],
mean[c_idx],
stddev[c_idx],
)
T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])
with T.init():
mean[0] = 0.48145466
stddev[0] = 0.26862954
mean[1] = 0.4578275
stddev[1] = 0.26130258
mean[2] = 0.40821073
stddev[2] = 0.27577711
if h_idx < h and w_idx < w:
out_buf[n_idx, c_idx, h_idx, w_idx] = (
T.cast(image_buf[n_idx, c_idx, h_idx, w_idx], o_dtype)
- mean[c_idx]
) / stddev[c_idx]

sch = tir.Schedule(normalize_func)
self.apply_schedule(sch, sch.get_block("normalize"))
return sch.mod["main"].with_attr("tir.is_scheduled", 1)

out = op.tensor_ir_op(
create_normalize_func(image.dtype, o_dtype),
@@ -184,40 +219,51 @@ def normalize_func(image: T.handle, out: T.handle):
return out
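Together, rescale and normalize compute the usual CLIP-style pixel transform; the per-channel constants hard-coded above are the CLIP image mean and standard deviation. A NumPy reference for what the two kernels produce (illustration only, not code from this commit):

    import numpy as np

    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype="float32")
    std = np.array([0.26862954, 0.26130258, 0.27577711], dtype="float32")

    # Illustration only: uint8 NCHW batch -> normalized float32.
    def reference(image_nchw):
        x = image_nchw.astype("float32") / 255.0  # rescale step
        return (x - mean[None, :, None, None]) / std[None, :, None, None]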

def pad(self, image: Tensor, dtype="uint8"):
assert 4 == image.ndim, "image should be 4D data tensor"
assert 3 == image.shape[1], "image layout should be NCHW"

def create_pad_func(l, r, fill=255):
@T.prim_func
def pad_func(image: T.handle, out: T.handle, t: T.int64(), b: T.int64()):
T.func_attr({"op_pattern": 8, "tir.noalias": True, "tir.is_scheduled": 1})
n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()
image_buf = T.match_buffer(image, (n, h, w, c), dtype=dtype)
out_buf = T.match_buffer(out, (n, h + t + b, w + l + r, c), dtype=dtype)
image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)
out_buf = T.match_buffer(out, (n, c, h + t + b, w + l + r), dtype=dtype)
out_h = h + t + b
out_w = w + l + r

for n_idx in T.thread_binding(n, thread="blockIdx.x"):
for h_idx, w_idx, c_idx in T.grid(h + t + b, w + l + r, c):
with T.block("compute"):
T.reads(image_buf[n_idx, h_idx, w_idx, c_idx])
T.writes(out_buf[n_idx, h_idx, w_idx, c_idx])
if h_idx < t or h_idx > h + b or w_idx < l or w_idx > w + r:
out_buf[n_idx, h_idx, w_idx, c_idx] = fill
else:
out_buf[n_idx, h_idx, w_idx, c_idx] = image_buf[
n_idx, h_idx - t, w_idx - l, c_idx
]

return pad_func

h = image.shape[1]
for c_idx in T.thread_binding(c, thread="blockIdx.y"):
for h_idx, w_idx in T.grid(out_h, out_w):
with T.block("pad"):
T.reads(image_buf[n_idx, c_idx, h_idx, w_idx])
T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])
if h_idx < t or h_idx > h + b or w_idx < l or w_idx > w + r:
out_buf[n_idx, c_idx, h_idx, w_idx] = fill
else:
out_buf[n_idx, c_idx, h_idx, w_idx] = image_buf[
n_idx, c_idx, h_idx - t, w_idx - l
]

sch = tir.Schedule(pad_func)
self.apply_schedule(sch, sch.get_block("pad"))
return sch.mod["main"].with_attr("tir.is_scheduled", 1)

h = image.shape[2]
tar = tir.truncdiv(h + 335, 336) * 336
t = tir.div(tar - h, 2)
b = tar - h - t
l = 0
r = 0

n, h, w, c = image.shape
n, c, h, w = image.shape
out = op.tensor_ir_op(
create_pad_func(l, r),
"pad",
[image, t, b],
[Tensor.placeholder((n, tar, w, c), image.dtype)],
[Tensor.placeholder((n, c, tar, w), image.dtype)],
)
return out
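A worked example of the padding arithmetic (my own numbers): the height is padded up to the next multiple of 336, split as evenly as possible between top and bottom, and the new rows are filled with 255:

    h = 500                       # hypothetical input height
    tar = (h + 335) // 336 * 336  # next multiple of 336 -> 672
    t = (tar - h) // 2            # 86 rows added on top
    b = tar - h - t               # 86 rows added on bottom
    assert (t, b, tar) == (86, 86, 672)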

def preprocess(self, pixel_values):
return pixel_values
