Skip to content

Commit 0c86cf6

Browse files
committed
added model inference
1 parent 84915de commit 0c86cf6

File tree

6 files changed

+135
-7
lines changed

6 files changed

+135
-7
lines changed

cord/client.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# under the License.
1515

1616
""" ``cord.client`` provides a simple Python client that allows you
17-
to query project resources through the Cord API.
17+
to query project resources through the Cord REST API.
1818
1919
Here is a simple example for instantiating the client and obtaining project info:
2020
@@ -32,11 +32,14 @@
3232

3333
import sys
3434
import logging
35+
import json
36+
import base64
3537

3638
from cord.configs import CordConfig
3739
from cord.http.querier import Querier
3840
from cord.orm.project import Project
3941
from cord.orm.label_blurb import Label
42+
from cord.orm.model import Model, ModelInferenceParams
4043

4144
# Logging configuration
4245
logging.basicConfig(stream=sys.stdout,
@@ -146,5 +149,49 @@ def save_label_blurb(self, uid, label):
146149
CorruptedLabelError: If a blurb is corrupted (e.g. if the frame labels have more frames than the video).
147150
"""
148151
label = Label(label)
149-
return self._querier.basic_setter(Label, uid, label)
152+
return self._querier.basic_setter(Label, uid, payload=label)
153+
154+
def model_inference(self,
155+
uid,
156+
file_path,
157+
conf_thresh=0.6,
158+
iou_thresh=0.3,
159+
device="cuda",
160+
detection_frame_range=None,
161+
):
162+
"""
163+
Run inference with model trained on the platform.
164+
165+
Args:
166+
uid: A model_iteration_hash (uid) string.
167+
file_path: Local file path to image or video
168+
conf_thresh: Confidence threshold (default 0.6)
169+
iou_thresh: Intersection over union threshold (default 0.3)
170+
device: Device (CPU or CUDA, default is CUDA)
171+
detection_frame_range: Detection frame range (optional, if video)
172+
173+
Returns:
174+
Inference results: A dict of inference results.
175+
176+
Raises:
177+
AuthenticationError: If the project API key is invalid.
178+
AuthorisationError: If access to the specified resource is restricted.
179+
ResourceNotFoundError: If no model exists by the specified model_iteration_hash (uid).
180+
UnknownError: If an error occurs while running inference.
181+
FileTypeNotSupportedError: If the file type is not supported for inference (has to be an image or video)
182+
MustSetDetectionRangeError: If a detection range is not set for video inference
183+
"""
184+
if detection_frame_range is None:
185+
detection_frame_range = []
186+
187+
file = open(file_path, 'rb').read()
188+
189+
params = ModelInferenceParams({
190+
'file': base64.b64encode(file).decode('utf-8'),
191+
'conf_thresh': conf_thresh,
192+
'iou_thresh': iou_thresh,
193+
'device': device,
194+
'detection_frame_range': detection_frame_range,
195+
})
150196

197+
return self._querier.basic_setter(Model, uid, payload=params)

cord/configs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
from abc import ABCMeta
2020
import cord.exceptions
2121

22-
CORD_ENDPOINT = 'https://api.cord.tech/public'
22+
CORD_ENDPOINT = 'http://127.0.0.1:8000/public'
2323
_CORD_PROJECT_ID = 'CORD_PROJECT_ID'
2424
_CORD_API_KEY = 'CORD_API_KEY'
2525

26-
READ_TIMEOUT = 15 # In seconds
27-
WRITE_TIMEOUT = 30 # In seconds
26+
READ_TIMEOUT = 120 # In seconds
27+
WRITE_TIMEOUT = 120 # In seconds
2828
CONNECT_TIMEOUT = 15 # In seconds
2929

3030
log = logging.getLogger(__name__)

cord/exceptions.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,21 @@ class AnswerDictionaryError(CordException):
7070

7171

7272
class CorruptedLabelError(CordException):
73-
""" Exception thrown when a label is corrupted (e.g. the frame labels have more frames than the video. """
73+
""" Exception thrown when a label is corrupted (e.g. the frame labels have more frames than the video.) """
74+
pass
75+
76+
77+
class FileTypeNotSupportedError(CordException):
78+
""" Exception thrown when a file type is not supported
79+
Supported file types are: image/jpeg, image/png, video/webm, video/mp4. """
80+
pass
81+
82+
83+
class MustSetDetectionRangeError(CordException):
84+
""" Exception thrown when a detection range is not set for video """
85+
pass
86+
87+
88+
class DetectionRangeInvalidError(CordException):
89+
""" Exception thrown when a detection range is invalid (e.g. negative or higher than num frames in video) """
7490
pass

cord/http/error_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
OPERATION_NOT_ALLOWED_ERROR = ['OPERATION_NOT_ALLOWED']
99
ANSWER_DICTIONARY_ERROR = ['ANSWER_DICTIONARY_ERROR']
1010
CORRUPTED_LABEL_ERROR = ['CORRUPTED_LABEL_ERROR']
11+
FILE_TYPE_NOT_SUPPORTED_ERROR = ['FILE_TYPE_NOT_SUPPORTED_ERROR']
12+
MUST_SET_DETECTION_RANGE_ERROR = ['MUST_SET_DETECTION_RANGE_ERROR']
13+
DETECTION_RANGE_INVALID_ERROR = ['DETECTION_RANGE_INVALID_ERROR']
1114

1215

1316
def check_error_response(response):
@@ -36,4 +39,14 @@ def check_error_response(response):
3639
raise CorruptedLabelError("The label blurb is corrupted. This could be due to the number of "
3740
"frame labels exceeding the number of frames in the labelled video.")
3841

42+
if response == FILE_TYPE_NOT_SUPPORTED_ERROR:
43+
raise FileTypeNotSupportedError("Supported file types are: image/jpeg, image/png, video/webm, video/mp4.")
44+
45+
if response == MUST_SET_DETECTION_RANGE_ERROR:
46+
raise MustSetDetectionRangeError("You must set a detection range for video inference")
47+
48+
if response == DETECTION_RANGE_INVALID_ERROR:
49+
raise DetectionRangeInvalidError("The detection range is invalid (e.g. less than 0, or"
50+
" higher than num frames in the video)")
51+
3952
pass

cord/http/querier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def basic_setter(self, db_object_type, uid, payload):
6868
res = self.execute(request)
6969

7070
if res:
71-
return True
71+
return res
7272
else:
7373
raise RequestException("Setting %s with uid %s failed." % (db_object_type, uid))
7474

cord/orm/model.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#
2+
# Copyright (c) 2020 Cord Technologies Limited
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
5+
# not use this file except in compliance with the License. You may obtain
6+
# a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
# License for the specific language governing permissions and limitations
14+
# under the License.
15+
16+
from collections import OrderedDict
17+
18+
from cord.orm import base_orm
19+
20+
21+
class Model(base_orm.BaseORM):
22+
"""
23+
Model base ORM.
24+
25+
ORM:
26+
27+
"""
28+
29+
DB_FIELDS = OrderedDict([])
30+
31+
32+
class ModelInferenceParams(base_orm.BaseORM):
33+
"""
34+
Model inference parameters for running models trained via the platform.
35+
36+
ORM:
37+
38+
local_file_path,
39+
conf_thresh,
40+
iou_thresh,
41+
device
42+
detection_frame_range (optional)
43+
44+
"""
45+
46+
DB_FIELDS = OrderedDict([
47+
("file", str),
48+
("conf_thresh", float), # Confidence threshold
49+
("iou_thresh", float), # Intersection over union threshold
50+
("device", str),
51+
("detection_frame_range", list)
52+
])

0 commit comments

Comments
 (0)