|
14 | 14 | # under the License. |
15 | 15 |
|
16 | 16 | """ ``cord.client`` provides a simple Python client that allows you |
17 | | -to query project resources through the Cord API. |
| 17 | +to query project resources through the Cord REST API. |
18 | 18 |
|
19 | 19 | Here is a simple example for instantiating the client and obtaining project info: |
20 | 20 |
|
|
32 | 32 |
|
33 | 33 | import sys |
34 | 34 | import logging |
| 35 | +import json |
| 36 | +import base64 |
35 | 37 |
|
36 | 38 | from cord.configs import CordConfig |
37 | 39 | from cord.http.querier import Querier |
38 | 40 | from cord.orm.project import Project |
39 | 41 | from cord.orm.label_blurb import Label |
| 42 | +from cord.orm.model import Model, ModelInferenceParams |
40 | 43 |
|
41 | 44 | # Logging configuration |
42 | 45 | logging.basicConfig(stream=sys.stdout, |
@@ -146,5 +149,49 @@ def save_label_blurb(self, uid, label): |
146 | 149 | CorruptedLabelError: If a blurb is corrupted (e.g. if the frame labels have more frames than the video). |
147 | 150 | """ |
148 | 151 | label = Label(label) |
149 | | - return self._querier.basic_setter(Label, uid, label) |
| 152 | + return self._querier.basic_setter(Label, uid, payload=label) |
| 153 | + |
| 154 | + def model_inference(self, |
| 155 | + uid, |
| 156 | + file_path, |
| 157 | + conf_thresh=0.6, |
| 158 | + iou_thresh=0.3, |
| 159 | + device="cuda", |
| 160 | + detection_frame_range=None, |
| 161 | + ): |
| 162 | + """ |
| 163 | + Run inference with model trained on the platform. |
| 164 | +
|
| 165 | + Args: |
| 166 | + uid: A model_iteration_hash (uid) string. |
| 167 | + file_path: Local file path to image or video |
| 168 | + conf_thresh: Confidence threshold (default 0.6) |
| 169 | + iou_thresh: Intersection over union threshold (default 0.3) |
| 170 | + device: Device (CPU or CUDA, default is CUDA) |
| 171 | + detection_frame_range: Detection frame range (optional, if video) |
| 172 | +
|
| 173 | + Returns: |
| 174 | + Inference results: A dict of inference results. |
| 175 | +
|
| 176 | + Raises: |
| 177 | + AuthenticationError: If the project API key is invalid. |
| 178 | + AuthorisationError: If access to the specified resource is restricted. |
| 179 | + ResourceNotFoundError: If no model exists by the specified model_iteration_hash (uid). |
| 180 | + UnknownError: If an error occurs while running inference. |
| 181 | + FileTypeNotSupportedError: If the file type is not supported for inference (has to be an image or video) |
| 182 | + MustSetDetectionRangeError: If a detection range is not set for video inference |
| 183 | + """ |
| 184 | + if detection_frame_range is None: |
| 185 | + detection_frame_range = [] |
| 186 | + |
| 187 | + file = open(file_path, 'rb').read() |
| 188 | + |
| 189 | + params = ModelInferenceParams({ |
| 190 | + 'file': base64.b64encode(file).decode('utf-8'), |
| 191 | + 'conf_thresh': conf_thresh, |
| 192 | + 'iou_thresh': iou_thresh, |
| 193 | + 'device': device, |
| 194 | + 'detection_frame_range': detection_frame_range, |
| 195 | + }) |
150 | 196 |
|
| 197 | + return self._querier.basic_setter(Model, uid, payload=params) |
0 commit comments