Dev gdino sam zzl #2808

Merged
merged 9 commits into from Jan 14, 2025

Changes from 8 commits
2 changes: 2 additions & 0 deletions paddlex/inference/models_new/__init__.py
@@ -34,6 +34,8 @@
from .ts_classification import TSClsPredictor
from .image_unwarping import WarpPredictor
from .image_multilabel_classification import MLClasPredictor
from .open_vocabulary_detection import OVDetPredictor
from .open_vocabulary_segmentation import OVSegPredictor


# from .table_recognition import TablePredictor
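The two new exports make the open-vocabulary predictors importable from the package root of models_new. A quick sanity check (a sketch, assuming the package layout introduced in this PR):

from paddlex.inference.models_new import OVDetPredictor, OVSegPredictor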
1 change: 1 addition & 0 deletions paddlex/inference/models_new/common/tokenizer/__init__.py
@@ -14,3 +14,4 @@

from .tokenizer_utils import PretrainedTokenizer
from .gpt_tokenizer import GPTTokenizer
from .bert_tokenizer import BertTokenizer
629 changes: 629 additions & 0 deletions paddlex/inference/models_new/common/tokenizer/bert_tokenizer.py

Large diffs are not rendered by default.

77 changes: 77 additions & 0 deletions paddlex/inference/models_new/common/tokenizer/tokenizer_utils.py
@@ -2046,3 +2046,80 @@ def decode_token(
        return new_text, read_offset, len(all_input_ids)
    else:
        return "", prefix_offset, read_offset

def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False

def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False

def _is_symbol(char):
    """Checks whether `char` is a symbol character."""
    cp = ord(char)
    if unicodedata.category(char).startswith("S") or (
        cp in [0x00AD, 0x00B2, 0x00BA, 0x3007, 0x00B5, 0x00D8, 0x014B, 0x01B1]
    ):
        return True
    return False

def _is_whitespace(char):
    """
    Checks whether `char` is a whitespace character.
    """
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False

def convert_to_unicode(text):
    """
    Converts `text` to Unicode (if it's not already), assuming utf-8 input.
    Args:
        text (str|bytes): Text to be converted to unicode.
    Returns:
        str: Converted text.
    """
    if isinstance(text, str):
        return text
    elif isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    else:
        raise ValueError("Unsupported string type: %s" % (type(text)))

def whitespace_tokenize(text):
    """
    Runs basic whitespace cleaning and splitting on a piece of text.
    Args:
        text (str): Text to be tokenized.
    Returns:
        list(str): Token list.
    """
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens
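A minimal sketch of how the new helpers behave, for quick review of the hunk above (it assumes the functions are importable from the module this diff touches; the underscore-prefixed helper is module-private, so this is illustration only, not public API):

from paddlex.inference.models_new.common.tokenizer.tokenizer_utils import (
    convert_to_unicode,
    whitespace_tokenize,
    _is_punctuation,
)

# bytes are decoded as UTF-8 ("ignore" on errors); str passes through unchanged
text = convert_to_unicode(b"Hello, open vocabulary!")
# splitting happens on whitespace only, so punctuation stays attached to tokens
print(whitespace_tokenize(text))            # ['Hello,', 'open', 'vocabulary!']
# ASCII symbols like "^" and "`" count as punctuation despite their Unicode category
print([_is_punctuation(c) for c in ",^`"])  # [True, True, True]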

15 changes: 15 additions & 0 deletions paddlex/inference/models_new/open_vocabulary_detection/__init__.py
@@ -0,0 +1,15 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .predictor import OVDetPredictor
152 changes: 152 additions & 0 deletions paddlex/inference/models_new/open_vocabulary_detection/predictor.py
@@ -0,0 +1,152 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Union, Dict, List, Tuple, Optional, Callable
import numpy as np
import inspect

from ....utils.func_register import FuncRegister
from ....modules.open_vocabulary_detection.model_list import MODELS
from ...common.batch_sampler import ImageBatchSampler
from ...common.reader import ReadImage
from .processors import (
    GroundingDINOProcessor,
    GroundingDINOPostProcessor,
)
from ..common import StaticInfer
from ..base import BasicPredictor
from ..object_detection.result import DetResult


class OVDetPredictor(BasicPredictor):

    entities = MODELS

    _FUNC_MAP = {}
    register = FuncRegister(_FUNC_MAP)

    def __init__(self, *args, thresholds: Optional[Union[Dict, float]] = None, **kwargs):
        """Initializes OVDetPredictor.
        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            thresholds (Optional[Union[Dict, float]], optional): Threshold(s) for filtering out
                low-confidence predictions; a dict may hold multiple named thresholds, while a
                float is stored under the key "threshold". Defaults to None.
            **kwargs: Arbitrary keyword arguments passed to the superclass.
        """
        super().__init__(*args, **kwargs)
        if isinstance(thresholds, float):
            thresholds = {"threshold": thresholds}
        self.thresholds = thresholds
        self.pre_ops, self.infer, self.post_op = self._build()

    def _build_batch_sampler(self):
        return ImageBatchSampler()

    def _get_result_class(self):
        return DetResult

    def _build(self):
        # build model preprocess ops
        pre_ops = [ReadImage(format="RGB")]
        for cfg in self.config["Preprocess"]:
            tf_key = cfg["type"]
            func = self._FUNC_MAP[tf_key]
            cfg.pop("type")
            args = cfg
            op = func(self, **args) if args else func(self)
            if op:
                pre_ops.append(op)

        # build infer
        infer = StaticInfer(
            model_dir=self.model_dir,
            model_prefix=self.MODEL_FILE_PREFIX,
            option=self.pp_option,
        )

        # build postprocess op
        post_op = self.build_postprocess(pre_ops=pre_ops)

        return pre_ops, infer, post_op

    def process(self, batch_data: List[Any], prompt: str, thresholds: Optional[dict] = None):
        """
        Process a batch of data through the preprocessing, inference, and postprocessing.

        Args:
            batch_data (List[Any]): A batch of input data (e.g., image file paths).
            prompt (str): Text prompt for open vocabulary detection.
            thresholds (Optional[dict]): Thresholds used for postprocessing.

        Returns:
            dict: A dictionary containing the input paths, raw images, and detected boxes
                for the batch. Keys are 'input_path', 'input_img', and 'boxes'.
        """
        image_paths = batch_data
        src_images = self.pre_ops[0](batch_data)
        datas = src_images
        # preprocess
        for pre_op in self.pre_ops[1:-1]:
            datas = pre_op(datas)

        # use model-specific preprocessor to format batch inputs
        batch_inputs = self.pre_ops[-1](datas, prompt)

        # do infer
        batch_preds = self.infer(batch_inputs)

        # postprocess
        current_thresholds = self._parse_current_thresholds(
            self.post_op, self.thresholds, thresholds
        )
        boxes = self.post_op(
            *batch_preds, prompt=prompt, src_images=src_images, **current_thresholds
        )

        return {
            "input_path": image_paths,
            "input_img": src_images,
            "boxes": boxes,
        }

    def _parse_current_thresholds(self, func, init_thresholds, process_thresholds):
        assert isinstance(func, Callable)
        thr2val = {}
        for name, param in inspect.signature(func).parameters.items():
            if "threshold" in name:
                thr2val[name] = None
        if init_thresholds is not None:
            thr2val.update(init_thresholds)
        if process_thresholds is not None:
            thr2val.update(process_thresholds)
        return thr2val

    def build_postprocess(self, **kwargs):
        if "GroundingDINO" in self.model_name:
            pre_ops = kwargs.get("pre_ops")
            return GroundingDINOPostProcessor(
                tokenizer=pre_ops[-1].tokenizer,
                box_threshold=self.config["box_threshold"],
                text_threshold=self.config["text_threshold"],
            )
        else:
            raise NotImplementedError

    @register("GroundingDINOProcessor")
    def build_grounding_dino_preprocessor(self, text_max_words=256, target_size=(800, 1333)):
        return GroundingDINOProcessor(
            model_dir=self.model_dir,
            text_max_words=text_max_words,
            target_size=target_size,
        )
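Reviewer note on the threshold plumbing: _parse_current_thresholds merges three sources with increasing precedence. Every parameter of the postprocessor whose name contains "threshold" is seeded with None, constructor-time thresholds then override the seeds, and per-call thresholds override everything. A standalone sketch of that precedence (the stand-in postprocessor is hypothetical, not a PaddleX API):

import inspect
from typing import Callable, Optional

def parse_current_thresholds(
    func: Callable, init_thresholds: Optional[dict], process_thresholds: Optional[dict]
) -> dict:
    # seed every parameter whose name contains "threshold" with None
    thr2val = {n: None for n in inspect.signature(func).parameters if "threshold" in n}
    if init_thresholds is not None:  # constructor-time values override the seeds
        thr2val.update(init_thresholds)
    if process_thresholds is not None:  # per-call values win
        thr2val.update(process_thresholds)
    return thr2val

def fake_post_op(preds, box_threshold=0.3, text_threshold=0.25):  # hypothetical stand-in
    pass

print(parse_current_thresholds(fake_post_op, {"box_threshold": 0.4}, {"box_threshold": 0.5}))
# {'box_threshold': 0.5, 'text_threshold': None}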
15 changes: 15 additions & 0 deletions paddlex/inference/models_new/open_vocabulary_detection/processors/__init__.py
@@ -0,0 +1,15 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .groundingdino_processors import GroundingDINOProcessor, GroundingDINOPostProcessor