Skip to content

Commit

Permalink
fix image cls process
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 committed Dec 31, 2024
1 parent 903b522 commit 00a9ae1
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 22 deletions.
4 changes: 2 additions & 2 deletions paddlex/inference/components/transforms/image/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,11 @@ def apply(self, img):
x2 = min(w, x1 + cw)
y2 = min(h, y1 + ch)
coords = (x1, y1, x2, y2)
if coords == (0, 0, w, h):
img = F.slice(img, coords=coords)
if w < cw or h < ch:
raise ValueError(
f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
)
img = F.slice(img, coords=coords)
return {"img": img, "img_size": [img.shape[1], img.shape[0]]}


Expand Down
23 changes: 21 additions & 2 deletions paddlex/inference/models_new/common/vision/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.

import cv2
import numpy as np
from PIL import Image


def check_image_size(input_):
Expand All @@ -26,13 +28,30 @@ def check_image_size(input_):
raise TypeError(f"{input_} cannot represent a valid image size.")


def resize(im, target_size, interp):
def resize(im, target_size, interp, backend="cv2"):
"""resize image to target size"""
w, h = target_size
im = cv2.resize(im, (w, h), interpolation=interp)
resize_functions = {"cv2": _cv2_resize, "pil": _pil_resize}
resize_function = resize_functions.get(backend.lower(), _cv2_resize)
im = resize_function(im, (w, h), interp)
return im


def _cv2_resize(src, size, resample):
return cv2.resize(src, size, interpolation=resample)


def _pil_resize(src, size, resample, return_numpy=True):
if isinstance(src, np.ndarray):
pil_img = Image.fromarray(src)
else:
pil_img = src
pil_img = pil_img.resize(size, resample)
if return_numpy:
return np.asarray(pil_img)
return pil_img


def flip_h(im):
"""flip image horizontally"""
if len(im.shape) == 3:
Expand Down
58 changes: 44 additions & 14 deletions paddlex/inference/models_new/common/vision/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,28 @@

import numpy as np
import cv2
from PIL import Image

from . import funcs as F


class _BaseResize:
_INTERP_DICT = {
_CV2_INTERP_DICT = {
"NEAREST": cv2.INTER_NEAREST,
"LINEAR": cv2.INTER_LINEAR,
"CUBIC": cv2.INTER_CUBIC,
"BICUBIC": cv2.INTER_CUBIC,
"AREA": cv2.INTER_AREA,
"LANCZOS4": cv2.INTER_LANCZOS4,
}
_PIL_INTERP_DICT = {
"NEAREST": Image.NEAREST,
"BILINEAR": Image.BILINEAR,
"BICUBIC": Image.BICUBIC,
"BOX": Image.BOX,
"LANCZOS4": Image.LANCZOS,
}

def __init__(self, size_divisor, interp):
def __init__(self, size_divisor, interp, backend="cv2"):
super().__init__()

if size_divisor is not None:
Expand All @@ -43,12 +51,21 @@ def __init__(self, size_divisor, interp):
self.size_divisor = size_divisor

try:
interp = self._INTERP_DICT[interp]
interp = interp.upper()
if backend == "cv2":
interp = self._CV2_INTERP_DICT[interp]
elif backend == "pil":
interp = self._PIL_INTERP_DICT[interp]
else:
raise ValueError("backend must be `cv2` or `pil`")
except KeyError:
raise ValueError(
"`interp` should be one of {}.".format(self._INTERP_DICT.keys())
"`interp` should be one of {} or {}.".format(
self._CV2_INTERP_DICT.keys(), self._PIL_INTERP_DICT.keys()
)
)
self.interp = interp
self.backend = backend

@staticmethod
def _rescale_size(img_size, target_size):
Expand All @@ -62,7 +79,12 @@ class Resize(_BaseResize):
"""Resize the image."""

def __init__(
self, target_size, keep_ratio=False, size_divisor=None, interp="LINEAR"
self,
target_size,
keep_ratio=False,
size_divisor=None,
interp="LINEAR",
backend="cv2",
):
"""
Initialize the instance.
Expand All @@ -76,7 +98,7 @@ def __init__(
interp (str, optional): Interpolation method. Choices are 'NEAREST',
'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
"""
super().__init__(size_divisor=size_divisor, interp=interp)
super().__init__(size_divisor=size_divisor, interp=interp, backend=backend)

if isinstance(target_size, int):
target_size = [target_size, target_size]
Expand All @@ -102,7 +124,7 @@ def resize(self, img):
math.ceil(i / self.size_divisor) * self.size_divisor
for i in target_size
]
img = F.resize(img, target_size, interp=self.interp)
img = F.resize(img, target_size, interp=self.interp, backend=self.backend)
return img


Expand All @@ -112,7 +134,9 @@ class ResizeByLong(_BaseResize):
longest side.
"""

def __init__(self, target_long_edge, size_divisor=None, interp="LINEAR"):
def __init__(
self, target_long_edge, size_divisor=None, interp="LINEAR", backend="cv2"
):
"""
Initialize the instance.
Expand All @@ -123,7 +147,7 @@ def __init__(self, target_long_edge, size_divisor=None, interp="LINEAR"):
interp (str, optional): Interpolation method. Choices are 'NEAREST',
'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
"""
super().__init__(size_divisor=size_divisor, interp=interp)
super().__init__(size_divisor=size_divisor, interp=interp, backend=backend)
self.target_long_edge = target_long_edge

def __call__(self, imgs):
Expand All @@ -139,7 +163,9 @@ def resize(self, img):
h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor

img = F.resize(img, (w_resize, h_resize), interp=self.interp)
img = F.resize(
img, (w_resize, h_resize), interp=self.interp, backend=self.backend
)
return img


Expand All @@ -149,7 +175,9 @@ class ResizeByShort(_BaseResize):
shortest side.
"""

def __init__(self, target_short_edge, size_divisor=None, interp="LINEAR"):
def __init__(
self, target_short_edge, size_divisor=None, interp="LINEAR", backend="cv2"
):
"""
Initialize the instance.
Expand All @@ -160,7 +188,7 @@ def __init__(self, target_short_edge, size_divisor=None, interp="LINEAR"):
interp (str, optional): Interpolation method. Choices are 'NEAREST',
'LINEAR', 'CUBIC', 'AREA', and 'LANCZOS4'. Default: 'LINEAR'.
"""
super().__init__(size_divisor=size_divisor, interp=interp)
super().__init__(size_divisor=size_divisor, interp=interp, backend=backend)
self.target_short_edge = target_short_edge

def __call__(self, imgs):
Expand All @@ -176,7 +204,9 @@ def resize(self, img):
h_resize = math.ceil(h_resize / self.size_divisor) * self.size_divisor
w_resize = math.ceil(w_resize / self.size_divisor) * self.size_divisor

img = F.resize(img, (w_resize, h_resize), interp=self.interp)
img = F.resize(
img, (w_resize, h_resize), interp=self.interp, backend=self.backend
)
return img


Expand Down
12 changes: 10 additions & 2 deletions paddlex/inference/models_new/image_classification/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,18 @@ def build_resize(
assert resize_short or size
if resize_short:
op = ResizeByShort(
target_short_edge=resize_short, size_divisor=None, interp="LINEAR"
target_short_edge=resize_short,
size_divisor=None,
interp=interpolation,
backend=backend,
)
else:
op = Resize(target_size=size)
op = Resize(
target_size=size,
size_divisor=None,
interp=interpolation,
backend=backend,
)
return "Resize", op

@register("CropImage")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def crop(self, img):
x2 = min(w, x1 + cw)
y2 = min(h, y1 + ch)
coords = (x1, y1, x2, y2)
if coords == (0, 0, w, h):
img = F.slice(img, coords=coords)
if w < cw or h < ch:
raise ValueError(
f"Input image ({w}, {h}) smaller than the target size ({cw}, {ch})."
)
img = F.slice(img, coords=coords)
return img


Expand Down

0 comments on commit 00a9ae1

Please sign in to comment.