Skip to content

Commit fa4624f

Browse files
committed
isort then black
1 parent a9c5962 commit fa4624f

33 files changed

+2231
-1143
lines changed

deepstack/init.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
import sqlite3
2-
from sqlite3 import Cursor,Error
31
import os
2+
import sqlite3
3+
from sqlite3 import Cursor, Error
44

55
DATA_DIR = "/datastore"
66

77
CREATE_TABLE = "CREATE TABLE IF NOT EXISTS TB_EMBEDDINGS(userid TEXT PRIMARY KEY, embedding TEXT NOT NULL)"
88
CREATE_TABLE2 = "CREATE TABLE IF NOT EXISTS TB_EMBEDDINGS2(userid TEXT PRIMARY KEY, embedding TEXT NOT NULL)"
9-
conn = sqlite3.connect(DATA_DIR+"/faceembedding.db")
9+
conn = sqlite3.connect(DATA_DIR + "/faceembedding.db")
1010
cursor = conn.cursor()
1111
cursor.execute(CREATE_TABLE)
1212
cursor.execute(CREATE_TABLE2)
1313
conn.commit()
1414
conn.close()
15-

deepstack/intelligencelayer/shared/commons/utils.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,82 +2,85 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5-
def load_model(model,path):
5+
6+
def load_model(model, path):
67
checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
78

89
try:
910
model.load_state_dict(checkpoint)
10-
11+
1112
except:
1213
copy = dict()
1314
for x, y in zip(model.state_dict(), checkpoint):
14-
new_name = y[y.index(x):]
15+
new_name = y[y.index(x) :]
1516
copy[new_name] = checkpoint[y]
1617

17-
def l2_norm(input,axis=1):
18-
norm = torch.norm(input,2,axis,True)
18+
19+
def l2_norm(input, axis=1):
20+
norm = torch.norm(input, 2, axis, True)
1921
output = torch.div(input, norm)
2022
return output
2123

24+
2225
def compute_distance(embeddings, embeddings2):
2326

24-
diff = embeddings.unsqueeze(-1) - embeddings2.transpose(1,0).unsqueeze(0)
25-
distance = torch.sum(torch.pow(diff,2),dim=1)
27+
diff = embeddings.unsqueeze(-1) - embeddings2.transpose(1, 0).unsqueeze(0)
28+
distance = torch.sum(torch.pow(diff, 2), dim=1)
2629

2730
return distance
2831

2932

3033
class _GlobalPoolNd(nn.Module):
31-
def __init__(self,flatten=True):
34+
def __init__(self, flatten=True):
3235
"""
3336
3437
:param flatten:
3538
"""
36-
super(_GlobalPoolNd,self).__init__()
39+
super(_GlobalPoolNd, self).__init__()
3740
self.flatten = flatten
3841

39-
def pool(self,input):
42+
def pool(self, input):
4043
"""
4144
4245
:param input:
4346
:return:
4447
"""
4548
raise NotImplementedError()
4649

47-
def forward(self,input):
50+
def forward(self, input):
4851
"""
4952
5053
:param input:
5154
:return:
5255
"""
5356
input = self.pool(input)
5457
size_0 = input.size(1)
55-
return input.view(-1,size_0) if self.flatten else input
58+
return input.view(-1, size_0) if self.flatten else input
59+
5660

5761
class GlobalAvgPool2d(_GlobalPoolNd):
5862
def __init__(self, flatten=True):
5963
"""
6064
6165
:param flatten:
6266
"""
63-
super(GlobalAvgPool2d,self).__init__(flatten)
67+
super(GlobalAvgPool2d, self).__init__(flatten)
6468

6569
def pool(self, input):
66-
return F.adaptive_avg_pool2d(input,1)
70+
return F.adaptive_avg_pool2d(input, 1)
71+
6772

68-
6973
class Flatten(nn.Module):
7074
def forward(self, input):
7175
return input.view(input.size(0), -1)
7276

77+
7378
class UpSampleInterpolate(nn.Module):
74-
def __init__(self,scale_factor):
75-
super(UpSampleInterpolate,self).__init__()
79+
def __init__(self, scale_factor):
80+
super(UpSampleInterpolate, self).__init__()
7681

7782
self.scale_factor = scale_factor
7883

79-
def forward(self,x):
80-
81-
return F.interpolate(x,scale_factor=self.scale_factor,mode="nearest")
82-
84+
def forward(self, x):
8385

86+
return F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
Lines changed: 65 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,36 @@
1-
2-
import torch
3-
import time
4-
import json
5-
import io
61
import _thread as thread
7-
from multiprocessing import Process
8-
from PIL import Image,UnidentifiedImageError
9-
import torch.nn.functional as F
102
import ast
3+
import io
4+
import json
5+
import os
116
import sqlite3
12-
import numpy as np
13-
import warnings
147
import sys
15-
import os
8+
import time
9+
import warnings
10+
from multiprocessing import Process
11+
12+
import numpy as np
13+
import torch
14+
import torch.nn.functional as F
15+
from PIL import Image, UnidentifiedImageError
16+
1617
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../"))
1718

18-
from process import YOLODetector
19-
from shared import SharedOptions
19+
import argparse
20+
import traceback
2021

2122
import torchvision.transforms as transforms
22-
import traceback
2323
from PIL import UnidentifiedImageError
24-
import argparse
24+
from process import YOLODetector
25+
from shared import SharedOptions
2526

2627
parser = argparse.ArgumentParser()
27-
parser.add_argument("--model",type=str,default=None)
28-
parser.add_argument("--name",type=str,default=None)
28+
parser.add_argument("--model", type=str, default=None)
29+
parser.add_argument("--name", type=str, default=None)
2930

3031
opt = parser.parse_args()
3132

33+
3234
def objectdetection(thread_name: str, delay: float):
3335

3436
MODE = SharedOptions.MODE
@@ -40,49 +42,50 @@ def objectdetection(thread_name: str, delay: float):
4042
if opt.name == None:
4143
IMAGE_QUEUE = "detection_queue"
4244
else:
43-
IMAGE_QUEUE = opt.name+"_queue"
44-
45+
IMAGE_QUEUE = opt.name + "_queue"
46+
4547
if opt.model == None:
46-
model_path = os.path.join(SHARED_APP_DIR,SharedOptions.SETTINGS.DETECTION_MODEL)
48+
model_path = os.path.join(
49+
SHARED_APP_DIR, SharedOptions.SETTINGS.DETECTION_MODEL
50+
)
4751
else:
4852
model_path = opt.model
49-
53+
5054
if MODE == "High":
5155

5256
reso = SharedOptions.SETTINGS.DETECTION_HIGH
53-
57+
5458
elif MODE == "Medium":
55-
59+
5660
reso = SharedOptions.SETTINGS.DETECTION_MEDIUM
57-
61+
5862
elif MODE == "Low":
5963

6064
reso = SharedOptions.SETTINGS.DETECTION_LOW
6165

62-
detector = YOLODetector(model_path,reso,cuda=CUDA_MODE)
66+
detector = YOLODetector(model_path, reso, cuda=CUDA_MODE)
6367
while True:
64-
queue = db.lrange(IMAGE_QUEUE,0,0)
68+
queue = db.lrange(IMAGE_QUEUE, 0, 0)
69+
70+
db.ltrim(IMAGE_QUEUE, len(queue), -1)
6571

66-
db.ltrim(IMAGE_QUEUE,len(queue), -1)
67-
6872
if len(queue) > 0:
6973

7074
for req_data in queue:
71-
72-
req_data = json.JSONDecoder().decode(req_data)
7375

76+
req_data = json.JSONDecoder().decode(req_data)
7477

7578
img_id = req_data["imgid"]
7679
req_id = req_data["reqid"]
7780
req_type = req_data["reqtype"]
7881
threshold = float(req_data["minconfidence"])
79-
82+
8083
try:
8184

82-
img = os.path.join(TEMP_PATH,img_id)
83-
84-
det = detector.predict(img,threshold)
85-
85+
img = os.path.join(TEMP_PATH, img_id)
86+
87+
det = detector.predict(img, threshold)
88+
8689
outputs = []
8790

8891
for *xyxy, conf, cls in reversed(det):
@@ -94,32 +97,47 @@ def objectdetection(thread_name: str, delay: float):
9497

9598
label = detector.names[int(cls.item())]
9699

97-
detection = {"confidence":score,"label":label, "x_min":int(x_min), "y_min":int(y_min),"x_max":int(x_max), "y_max":int(y_max)}
100+
detection = {
101+
"confidence": score,
102+
"label": label,
103+
"x_min": int(x_min),
104+
"y_min": int(y_min),
105+
"x_max": int(x_max),
106+
"y_max": int(y_max),
107+
}
98108

99109
outputs.append(detection)
100110

101-
output = {"success":True,"predictions":outputs}
111+
output = {"success": True, "predictions": outputs}
102112

103113
except UnidentifiedImageError:
104114
err_trace = traceback.format_exc()
105-
print(err_trace,file=sys.stderr,flush=True)
115+
print(err_trace, file=sys.stderr, flush=True)
116+
117+
output = {
118+
"success": False,
119+
"error": "invalid image file",
120+
"code": 400,
121+
}
106122

107-
output = {"success":False, "error":"invalid image file","code":400}
108-
109123
except Exception:
110124

111125
err_trace = traceback.format_exc()
112-
print(err_trace,file=sys.stderr,flush=True)
113-
114-
output = {"success":False, "error":"error occured on the server","code":500}
115-
126+
print(err_trace, file=sys.stderr, flush=True)
127+
128+
output = {
129+
"success": False,
130+
"error": "error occured on the server",
131+
"code": 500,
132+
}
133+
116134
finally:
117-
db.set(req_id,json.dumps(output))
135+
db.set(req_id, json.dumps(output))
118136
if os.path.exists(TEMP_PATH + img_id):
119137
os.remove(img)
120138

121139
time.sleep(delay)
122140

123-
p = Process(target=objectdetection,args=("",SharedOptions.SLEEP_TIME))
124-
p.start()
125141

142+
p = Process(target=objectdetection, args=("", SharedOptions.SLEEP_TIME))
143+
p.start()

0 commit comments

Comments
 (0)