Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

yowo for paddlex #2784

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions api_examples/pipelines/test_video_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="video_detection")
output = pipeline.predict("./test_samples/HorseRiding.avi")

for res in output:
print(res)
res.print() ## 打印预测的结构化输出
res.save_to_video("./output/") ## 保存结果可视化视频
res.save_to_json("./output/") ## 保存预测的结构化输出
40 changes: 40 additions & 0 deletions paddlex/configs/modules/video_detection/YOWO.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
Global:
model: YOWO
mode: check_dataset # check_dataset/train/evaluate/predict
dataset_dir: "/paddle/dataset/paddlex/video_det/video_det_examples"
device: gpu:0
output: "output"

CheckDataset:
convert:
enable: False
src_dataset_type: null
split:
enable: False
train_percent: null
val_percent: null

Train:
num_classes: 24
epochs_iters: 5
batch_size: 8
learning_rate: 0.0001
pretrain_weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/YOWO_pretrain.pdparams"
resume_path: null
log_interval: 10
eval_interval: 1
save_interval: 1

Evaluate:
weight_path: "output/best_model/best_model.pdparams"
log_interval: 1

Export:
weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/YOWO_pretrain.pdparams"

Predict:
batch_size: 1
model_dir: "output/best_model/inference"
input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/videos/demo_video/HorseRiding.avi"
kernel_option:
run_mode: paddle
9 changes: 9 additions & 0 deletions paddlex/configs/pipelines/video_detection.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pipeline_name: video_detection

SubModules:
VideoDetection:
module_name: video_detection
model_name: YOWO
model_dir: null
batch_size: 1
topk: 1
2 changes: 2 additions & 0 deletions paddlex/inference/models_new/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@
# from .table_recognition import TablePredictor
# from .general_recognition import ShiTuRecPredictor
from .anomaly_detection import UadPredictor

# from .face_recognition import FaceRecPredictor
from .multilingual_speech_recognition import WhisperPredictor
from .video_classification import VideoClasPredictor
from .video_detection import VideoDetPredictor


def _create_hp_predictor(
Expand Down
15 changes: 15 additions & 0 deletions paddlex/inference/models_new/video_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .predictor import VideoDetPredictor
117 changes: 117 additions & 0 deletions paddlex/inference/models_new/video_detection/predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Union, Dict, List, Tuple
from ....utils.func_register import FuncRegister
from ....modules.video_detection.model_list import MODELS
from ...common.batch_sampler import VideoBatchSampler
from ...common.reader import ReadVideo
from ..common import (
ToBatch,
StaticInfer,
)
from ..base import BasicPredictor
from .processors import ResizeVideo, Image2Array, NormalizeVideo, DetVideoPostProcess
from .result import DetVideoResult


class VideoDetPredictor(BasicPredictor):

entities = MODELS

_FUNC_MAP = {}
register = FuncRegister(_FUNC_MAP)

def __init__(self, topk: Union[int, None] = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pre_tfs, self.infer, self.post_op = self._build()

def _build_batch_sampler(self):
return VideoBatchSampler()

def _get_result_class(self):
return DetVideoResult

def _build(self):
pre_tfs = {}
for cfg in self.config["PreProcess"]["transform_ops"]:
tf_key = list(cfg.keys())[0]
assert tf_key in self._FUNC_MAP
func = self._FUNC_MAP[tf_key]
args = cfg.get(tf_key, {})
name, op = func(self, **args) if args else func(self)
if op:
pre_tfs[name] = op

infer = StaticInfer(
model_dir=self.model_dir,
model_prefix=self.MODEL_FILE_PREFIX,
option=self.pp_option,
)
post_op = {}
for cfg in self.config["PostProcess"]["transform_ops"]:
tf_key = list(cfg.keys())[0]
assert tf_key in self._FUNC_MAP
func = self._FUNC_MAP[tf_key]
args = cfg.get(tf_key, {})
if tf_key == "DetVideoPostProcess":
args["label_list"] = self.config["label_list"]
name, op = func(self, **args) if args else func(self)
if op:
post_op[name] = op

return pre_tfs, infer, post_op

def process(self, batch_data):
batch_raw_videos = self.pre_tfs["ReadVideo"](videos=batch_data)
batch_videos = self.pre_tfs["ResizeVideo"](videos=batch_raw_videos)
batch_videos = self.pre_tfs["Image2Array"](videos=batch_videos)
x = self.pre_tfs["NormalizeVideo"](videos=batch_videos)
num_seg = len(x[0])
pred_seg = []
for i in range(num_seg):
batch_preds = self.infer(x=[x[0][i]])
pred_seg.append(batch_preds)
batch_bboxes = self.post_op["DetVideoPostProcess"](preds=[pred_seg])
return {
"input_path": batch_data,
"result": batch_bboxes,
}

@register("ReadVideo")
def build_readvideo(self, num_seg=8):
return "ReadVideo", ReadVideo(backend="opencv", num_seg=num_seg)

@register("ResizeVideo")
def build_resize(self, target_size=224):
return "ResizeVideo", ResizeVideo(
target_size=target_size,
)

@register("Image2Array")
def build_image2array(self, data_format="tchw"):
return "Image2Array", Image2Array(data_format="tchw")

@register("NormalizeVideo")
def build_normalize(
self,
scale=255.0,
):
return "NormalizeVideo", NormalizeVideo(scale=scale)

@register("DetVideoPostProcess")
def build_postprocess(self, nms_thresh=0.5, score_thresh=0.4, label_list=[]):
return "DetVideoPostProcess", DetVideoPostProcess(
nms_thresh=nms_thresh, score_thresh=score_thresh, label_list=label_list
)
Loading