From f45e5ea755bdaa29b3ddb8bcf33e16f9f9468639 Mon Sep 17 00:00:00 2001
From: Muhammad Mudassar
Date: Sun, 3 Nov 2024 11:53:06 +0000
Subject: [PATCH 1/5] added requirements.txt.

---
 .../object_detector/detectron2/requirements.txt    | 13 +++++++++++++
 1 file changed, 13 insertions(+)
 create mode 100644 examples/object_detector/detectron2/requirements.txt

diff --git a/examples/object_detector/detectron2/requirements.txt b/examples/object_detector/detectron2/requirements.txt
new file mode 100644
index 0000000000..d788fbd136
--- /dev/null
+++ b/examples/object_detector/detectron2/requirements.txt
@@ -0,0 +1,13 @@
+opencv-python==4.10.0.84
+python-multipart==0.0.9
+torch==2.2.0
+torchvision==0.17.0
+transformers==4.44.2
+numpy==1.24.4
+torchserve==0.12.0
+torch-model-archiver==0.12.0
+torch-workflow-archiver==0.2.15
+pillow==11.0.0
+pillow-avif-plugin==1.4.6
+pillow-jxl-plugin==1.2.8
+pillow_heif==0.20.0
\ No newline at end of file

From f81c0354b8a60c2a367c4a3be940acc11592ec6a Mon Sep 17 00:00:00 2001
From: Muhammad Mudassar
Date: Sun, 3 Nov 2024 12:02:31 +0000
Subject: [PATCH 2/5] added detectron2 handler file.

---
 .../detectron2/detectron2-handler.py          | 265 ++++++++++++++++++
 1 file changed, 265 insertions(+)
 create mode 100644 examples/object_detector/detectron2/detectron2-handler.py

diff --git a/examples/object_detector/detectron2/detectron2-handler.py b/examples/object_detector/detectron2/detectron2-handler.py
new file mode 100644
index 0000000000..78ed8c2b4c
--- /dev/null
+++ b/examples/object_detector/detectron2/detectron2-handler.py
@@ -0,0 +1,265 @@
+import io
+import json
+import time
+import torch
+import logging
+import numpy as np
+from os import path
+from detectron2.config import get_cfg
+from PIL import Image, UnidentifiedImageError
+from detectron2.engine import DefaultPredictor
+from detectron2.utils.logger import setup_logger
+try:
+    import pillow_heif
+    import pillow_avif
+    import pillow_jxl
+    # Register openers for extended formats
+    pillow_heif.register_heif_opener()
+    # For pillow_avif and pillow_jxl, openers are registered upon import
+except ImportError as e:
+    raise ImportError(
+        "Please install 'pillow-heif', 'pillow-avif-plugin', and 'pillow-jxl-plugin' to handle extended image formats. "
+        f"Missing package error: {e}"
+    )
+########################################################################################################################################
+setup_logger()
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+########################################################################################################################################
+class ModelHandler:
+    """
+    A base ModelHandler implementation for loading and running Detectron2 models with TorchServe.
+    Compatible with both CPU and GPU.
+    """
+    def __init__(self):
+        """
+        Initialize the ModelHandler instance.
+        """
+        self.error = None
+        self._context = None
+        self._batch_size = 0
+        self.initialized = False
+        self.predictor = None
+        self.model_file = "model.pth"
+        self.config_file = "config.yaml"
+        self.device = "cpu"
+        if torch.cuda.is_available():
+            self.device = "cuda"
+            logger.info("Using GPU for inference.")
+        else:
+            logger.info("Using CPU for inference.")
+
+    def initialize(self, context):
+        """
+        Load the model and initialize the predictor.
+        Args:
+            context (Context): Initial context contains model server system properties.
+ """ + logger.info("Initializing model...") + + self._context = context + self._batch_size = context.system_properties.get("batch_size", 1) + model_dir = context.system_properties.get("model_dir") + model_path = path.join(model_dir, self.model_file) + config_path = path.join(model_dir, self.config_file) + logger.debug(f"Checking model file: {model_path} exists: {path.exists(model_path)}") + logger.debug(f"Checking config file: {config_path} exists: {path.exists(config_path)}") + if not path.exists(model_path): + error_msg = f"Model file {model_path} does not exist." + logger.error(error_msg) + self.error = error_msg + self.initialized = False + return + if not path.exists(config_path): + error_msg = f"Config file {config_path} does not exist." + logger.error(error_msg) + self.error = error_msg + self.initialized = False + return + try: + cfg = get_cfg() + cfg.merge_from_file(config_path) + cfg.MODEL.WEIGHTS = model_path + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 + cfg.MODEL.DEVICE = self.device + self.predictor = DefaultPredictor(cfg) + logger.info("Predictor initialized successfully.") + if self.predictor is None: + raise RuntimeError("Predictor initialization failed, the predictor is None.") + self.initialized = True + logger.info("Model initialization complete.") + except Exception as e: + error_msg = "Error during model initialization" + logger.exception(error_msg) + self.error = str(e) + self.initialized = False + + def preprocess(self, batch): + """ + Transform raw input into model input data. + + Args: + batch (List[Dict]): List of raw requests, should match batch size. + + Returns: + List[np.ndarray]: List of preprocessed images. + """ + logger.info(f"Pre-processing started for a batch of {len(batch)}.") + + images = [] + for idx, request in enumerate(batch): + request_body = request.get("body") + if request_body is None: + error_msg = f"Request {idx} does not contain 'body'." + logger.error(error_msg) + raise ValueError(error_msg) + try: + image_stream = io.BytesIO(request_body) + try: + pil_image = Image.open(image_stream) + pil_image = pil_image.convert("RGB") + img = np.array(pil_image) + img = img[:, :, ::-1] + except UnidentifiedImageError as e: + error_msg = f"Failed to identify image for request {idx}. Error: {e}" + logger.error(error_msg) + raise ValueError(error_msg) + except Exception as e: + error_msg = f"Failed to decode image for request {idx}. Error: {e}" + logger.error(error_msg) + raise ValueError(error_msg) + images.append(img) + except Exception as e: + logger.exception(f"Error preprocessing request {idx}") + raise e + logger.info(f"Pre-processing finished for a batch of {len(batch)}.") + return images + + def inference(self, model_input): + """ + Perform inference on the model input. + + Args: + model_input (List[np.ndarray]): List of preprocessed images. + + Returns: + List[Dict]: List of inference outputs. + """ + logger.info(f"Inference started for a batch of {len(model_input)}.") + + outputs = [] + for idx, image in enumerate(model_input): + try: + logger.debug(f"Processing image {idx}: shape={image.shape}, dtype={image.dtype}") + output = self.predictor(image) + outputs.append(output) + except Exception as e: + logger.exception(f"Error during inference on image {idx}") + raise e + logger.info(f"Inference finished for a batch of {len(model_input)}.") + return outputs + + def postprocess(self, inference_outputs): + """ + Post-process the inference outputs to a serializable format. + + Args: + inference_outputs (List[Dict]): List of inference outputs. 
+ + Returns: + List[str]: List of JSON strings containing predictions. + """ + start_time = time.time() + logger.info(f"Post-processing started at {start_time} for a batch of {len(inference_outputs)}.") + responses = [] + for idx, output in enumerate(inference_outputs): + try: + predictions = output["instances"].to("cpu") + logger.debug(f"Available prediction fields: {predictions.get_fields().keys()}") + response = {} + if predictions.has("pred_classes"): + classes = predictions.pred_classes.numpy().tolist() + response["classes"] = classes + if predictions.has("pred_boxes"): + boxes = predictions.pred_boxes.tensor.numpy().tolist() + response["boxes"] = boxes + if predictions.has("scores"): + scores = predictions.scores.numpy().tolist() + response["scores"] = scores + if predictions.has("pred_masks"): + response["masks_present"] = True + responses.append(json.dumps(response)) + except Exception as e: + logger.exception(f"Error during post-processing of output {idx}") + raise e + elapsed_time = time.time() - start_time + logger.info(f"Post-processing finished for a batch of {len(inference_outputs)} in {elapsed_time:.2f} seconds.") + + return responses + + def handle(self, data, context): + """ + Entry point for TorchServe to interact with the ModelHandler. + + Args: + data (List[Dict]): Input data. + context (Context): Model server context. + + Returns: + List[str]: List of predictions. + """ + logger.info("Handling request...") + start_time = time.time() + if not self.initialized: + self.initialize(context) + if not self.initialized: + error_message = f"Model failed to initialize: {self.error}" + logger.error(error_message) + return [error_message] + + if data is None: + error_message = "No data received for inference." + logger.error(error_message) + return [error_message] + + try: + preprocess_start = time.time() + model_input = self.preprocess(data) + preprocess_time = time.time() - preprocess_start + + inference_start = time.time() + model_output = self.inference(model_input) + inference_time = time.time() - inference_start + + postprocess_start = time.time() + output = self.postprocess(model_output) + postprocess_time = time.time() - postprocess_start + + total_time = time.time() - start_time + logger.info( + f"Handling request finished in {total_time:.2f} seconds. " + f"(Preprocess: {preprocess_time:.2f}s, " + f"Inference: {inference_time:.2f}s, " + f"Postprocess: {postprocess_time:.2f}s)" + ) + return output + except Exception as e: + error_message = f"Error in handling request: {str(e)}" + logger.exception(error_message) + return [error_message] +######################################################################################################################################## +_service = ModelHandler() + +def handle(data, context): + """ + Entry point for TorchServe to interact with the ModelHandler. + + Args: + data (List[Dict]): Input data. + context (Context): Model server context. + + Returns: + List[str]: List of predictions. + """ + return _service.handle(data, context) +######################################################################################################################################## \ No newline at end of file From 3dd5ce5c3210266f29090b4e69d588464723eb0b Mon Sep 17 00:00:00 2001 From: Muhammad Mudassar Date: Wed, 11 Dec 2024 22:24:20 +0000 Subject: [PATCH 3/5] added readme.md file. 
--- examples/object_detector/detectron2/README.md | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 examples/object_detector/detectron2/README.md diff --git a/examples/object_detector/detectron2/README.md b/examples/object_detector/detectron2/README.md new file mode 100644 index 0000000000..adf922141a --- /dev/null +++ b/examples/object_detector/detectron2/README.md @@ -0,0 +1,67 @@ +# Object Detection using TorchServe and Detectron2 + +## Overview + +This folder leverages **TorchServe** to deploy a Detectron2-based object detection model using a custom handler. It provides scalable and efficient object detection capabilities with support for both CPU and GPU environments. + +--- + +## Table of Contents + +1. [Pre-requirements](#pre-requirements) +2. [Installation](#installation) +3. [Usage](#usage) +4. [Documentation](#documentation) +5. [Contributors](#contributors) + +--- + +## Pre-requirements + +- **Python 3.8 or higher** (tested on Python 3.10.15). + +--- + +## Installation + +Follow these steps to set up the project: + +1. Clone the repository: + + ```bash + git clone https://github.com/pytorch/serve.git + ``` + +2. Make sure the terminal's current directory is set to the folder where this README file is located: + + ```bash + cd serve/examples/object_detector/detectron2 + ``` + +3. Install dependencies: + + ```bash + pip install -r requirements.txt + pip install git+https://github.com/facebookresearch/detectron2.git && pip install numpy==1.21.6 + ``` + +--- + +## Usage + +Refer to the [Documentation](#documentation) for detailed usage instructions. + +--- + +## Documentation + +For detailed information on using TorchServe and Detectron2 for object detection, refer to the documentation provided in the [Upstart Commerce Blog](https://upstartcommerce.com/blogs/). + +--- + +## Contributors + +- **[Muhammad Mudassar](https://github.com/Mudassar-MLE)** + - [LinkedIn](https://www.linkedin.com/in/muhammad-mudassar-a65645192/) + - [Email](mailto:mmudassards@gmail.com) +--- From 65c6715604c2f0ff6136bffd1f22cf16f1dac0d5 Mon Sep 17 00:00:00 2001 From: Muhammad Mudassar Date: Thu, 19 Dec 2024 04:41:33 +0000 Subject: [PATCH 4/5] updated readme file. --- examples/object_detector/detectron2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/object_detector/detectron2/README.md b/examples/object_detector/detectron2/README.md index adf922141a..85bc5a46ab 100644 --- a/examples/object_detector/detectron2/README.md +++ b/examples/object_detector/detectron2/README.md @@ -55,7 +55,7 @@ Refer to the [Documentation](#documentation) for detailed usage instructions. ## Documentation -For detailed information on using TorchServe and Detectron2 for object detection, refer to the documentation provided in the [Upstart Commerce Blog](https://upstartcommerce.com/blogs/). +For detailed information on using TorchServe and Detectron2 for object detection, refer to the documentation provided in the [Upstart Commerce Blog](https://upstartcommerce.com/optimizing-pytorch-model-serving-at-scale-with-torchserve/). --- From 715d1be3785116957661d9daad610fc9b1ead1c6 Mon Sep 17 00:00:00 2001 From: Muhammad Mudassar Date: Wed, 8 Jan 2025 18:03:46 +0000 Subject: [PATCH 5/5] updated readme file and used time from utility. 
--- examples/object_detector/detectron2/README.md | 121 ++++++++---------- .../detectron2/detectron2-handler.py | 30 +---- 2 files changed, 61 insertions(+), 90 deletions(-) diff --git a/examples/object_detector/detectron2/README.md b/examples/object_detector/detectron2/README.md index 85bc5a46ab..c0b297e345 100644 --- a/examples/object_detector/detectron2/README.md +++ b/examples/object_detector/detectron2/README.md @@ -1,67 +1,54 @@ -# Object Detection using TorchServe and Detectron2 - -## Overview - -This folder leverages **TorchServe** to deploy a Detectron2-based object detection model using a custom handler. It provides scalable and efficient object detection capabilities with support for both CPU and GPU environments. - ---- - -## Table of Contents - -1. [Pre-requirements](#pre-requirements) -2. [Installation](#installation) -3. [Usage](#usage) -4. [Documentation](#documentation) -5. [Contributors](#contributors) - ---- - -## Pre-requirements - -- **Python 3.8 or higher** (tested on Python 3.10.15). - ---- - -## Installation - -Follow these steps to set up the project: - -1. Clone the repository: - - ```bash - git clone https://github.com/pytorch/serve.git - ``` - -2. Make sure the terminal's current directory is set to the folder where this README file is located: - - ```bash - cd serve/examples/object_detector/detectron2 - ``` - -3. Install dependencies: - - ```bash - pip install -r requirements.txt - pip install git+https://github.com/facebookresearch/detectron2.git && pip install numpy==1.21.6 - ``` - ---- - -## Usage - -Refer to the [Documentation](#documentation) for detailed usage instructions. - ---- - -## Documentation - -For detailed information on using TorchServe and Detectron2 for object detection, refer to the documentation provided in the [Upstart Commerce Blog](https://upstartcommerce.com/optimizing-pytorch-model-serving-at-scale-with-torchserve/). - ---- - -## Contributors - -- **[Muhammad Mudassar](https://github.com/Mudassar-MLE)** - - [LinkedIn](https://www.linkedin.com/in/muhammad-mudassar-a65645192/) - - [Email](mailto:mmudassards@gmail.com) ---- +# Object Detection using torchvision's pretrained fast-rcnn model + +* Download the pre-trained fast-rcnn object detection model's state_dict from the following URL : + +https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth + +```bash +wget https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth +``` + +* Create a model archive file and serve the fastrcnn model in TorchServe using below commands + + ```bash + torch-model-archiver --model-name fastrcnn --version 1.0 --model-file examples/object_detector/fast-rcnn/model.py --serialized-file fasterrcnn_resnet50_fpn_coco-258fb6c6.pth --handler object_detector --extra-files examples/object_detector/index_to_name.json + mkdir model_store + mv fastrcnn.mar model_store/ + torchserve --start --model-store model_store --models fastrcnn=fastrcnn.mar --disable-token-auth --enable-model-api + curl http://127.0.0.1:8080/predictions/fastrcnn -T examples/object_detector/detectron2/person.jpg + ``` +* Note : The objects detected have scores greater than "0.5". This threshold value is set in object_detector handler. 
+ +* Output + +```json +[ + { + "person": [ + 362.34539794921875, + 161.9876251220703, + 515.53662109375, + 385.2342834472656 + ], + "score": 0.9977679252624512 + }, + { + "handbag": [ + 67.37423706054688, + 277.63787841796875, + 111.6810073852539, + 400.26470947265625 + ], + "score": 0.9925485253334045 + }, + { + "handbag": [ + 228.7159423828125, + 145.87753295898438, + 303.5065612792969, + 231.10513305664062 + ], + "score": 0.9921919703483582 + } +] +``` diff --git a/examples/object_detector/detectron2/detectron2-handler.py b/examples/object_detector/detectron2/detectron2-handler.py index 78ed8c2b4c..9bb7edb005 100644 --- a/examples/object_detector/detectron2/detectron2-handler.py +++ b/examples/object_detector/detectron2/detectron2-handler.py @@ -1,11 +1,11 @@ import io import json -import time import torch import logging import numpy as np from os import path from detectron2.config import get_cfg +from ts.handler_utils.timer import timed from PIL import Image, UnidentifiedImageError from detectron2.engine import DefaultPredictor from detectron2.utils.logger import setup_logger @@ -94,6 +94,7 @@ def initialize(self, context): self.error = str(e) self.initialized = False + @timed def preprocess(self, batch): """ Transform raw input into model input data. @@ -135,6 +136,7 @@ def preprocess(self, batch): logger.info(f"Pre-processing finished for a batch of {len(batch)}.") return images + @timed def inference(self, model_input): """ Perform inference on the model input. @@ -158,7 +160,7 @@ def inference(self, model_input): raise e logger.info(f"Inference finished for a batch of {len(model_input)}.") return outputs - + @timed def postprocess(self, inference_outputs): """ Post-process the inference outputs to a serializable format. @@ -169,8 +171,7 @@ def postprocess(self, inference_outputs): Returns: List[str]: List of JSON strings containing predictions. """ - start_time = time.time() - logger.info(f"Post-processing started at {start_time} for a batch of {len(inference_outputs)}.") + logger.info(f"Post-processing for a batch of {len(inference_outputs)}.") responses = [] for idx, output in enumerate(inference_outputs): try: @@ -192,11 +193,11 @@ def postprocess(self, inference_outputs): except Exception as e: logger.exception(f"Error during post-processing of output {idx}") raise e - elapsed_time = time.time() - start_time - logger.info(f"Post-processing finished for a batch of {len(inference_outputs)} in {elapsed_time:.2f} seconds.") + logger.info(f"Post-processing finished for a batch of {len(inference_outputs)}.") return responses + @timed def handle(self, data, context): """ Entry point for TorchServe to interact with the ModelHandler. @@ -209,7 +210,6 @@ def handle(self, data, context): List[str]: List of predictions. """ logger.info("Handling request...") - start_time = time.time() if not self.initialized: self.initialize(context) if not self.initialized: @@ -223,25 +223,9 @@ def handle(self, data, context): return [error_message] try: - preprocess_start = time.time() model_input = self.preprocess(data) - preprocess_time = time.time() - preprocess_start - - inference_start = time.time() model_output = self.inference(model_input) - inference_time = time.time() - inference_start - - postprocess_start = time.time() output = self.postprocess(model_output) - postprocess_time = time.time() - postprocess_start - - total_time = time.time() - start_time - logger.info( - f"Handling request finished in {total_time:.2f} seconds. 
" - f"(Preprocess: {preprocess_time:.2f}s, " - f"Inference: {inference_time:.2f}s, " - f"Postprocess: {postprocess_time:.2f}s)" - ) return output except Exception as e: error_message = f"Error in handling request: {str(e)}"