
Commit 09e8e26

Add new python scripts.
1 parent 3e117e9 commit 09e8e26

3 files changed: +1250 -39 lines

app/asl_classify_files.py  (new file, +215 lines)

@@ -0,0 +1,215 @@
'''
Copyright 2023 Avnet Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''

from ctypes import *
from typing import List
import cv2
import numpy as np
import vart
import os
import pathlib
import xir
import threading
import time
import sys
import argparse

divider = '------------------------------------'


def preprocess_fn(image_path, fix_scale):
    '''
    Image pre-processing.
    Rearranges from BGR to RGB, then scales by the input quantization
    scaling factor and casts to int8.
    input arg: path of image file
    return: numpy array
    '''
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image * fix_scale
    image = image.astype(np.int8)
    return image


def get_child_subgraph_dpu(graph: "Graph") -> List["Subgraph"]:
    assert graph is not None, "'graph' should not be None."
    root_subgraph = graph.get_root_subgraph()
    assert (root_subgraph is not None), "Failed to get root subgraph of input Graph object."
    if root_subgraph.is_leaf:
        return []
    child_subgraphs = root_subgraph.toposort_child_subgraph()
    assert child_subgraphs is not None and len(child_subgraphs) > 0
    return [
        cs
        for cs in child_subgraphs
        if cs.has_attr("device") and cs.get_attr("device").upper() == "DPU"
    ]


def runDPU(id, start, dpu, img):
    '''get tensor'''
    inputTensors = dpu.get_input_tensors()
    outputTensors = dpu.get_output_tensors()
    input_ndim = tuple(inputTensors[0].dims)
    output_ndim = tuple(outputTensors[0].dims)

    # we can avoid output scaling if we use argmax instead of softmax
    #output_fixpos = outputTensors[0].get_attr("fix_point")
    #output_scale = 1 / (2**output_fixpos)

    batchSize = input_ndim[0]
    n_of_images = len(img)
    count = 0
    write_index = start
    ids = []
    ids_max = 50        # queue up to ids_max asynchronous DPU jobs before draining them
    outputData = []
    for i in range(ids_max):
        outputData.append([np.empty(output_ndim, dtype=np.int8, order="C")])
    while count < n_of_images:
        if (count + batchSize <= n_of_images):
            runSize = batchSize
        else:
            runSize = n_of_images - count

        '''prepare batch input/output '''
        inputData = []
        inputData = [np.empty(input_ndim, dtype=np.int8, order="C")]

        '''init input image to input buffer '''
        for j in range(runSize):
            imageRun = inputData[0]
            imageRun[j, ...] = img[(count + j) % n_of_images].reshape(input_ndim[1:])

        '''run with batch '''
        job_id = dpu.execute_async(inputData, outputData[len(ids)])
        ids.append((job_id, runSize, start + count))
        count = count + runSize
        if count < n_of_images:
            if len(ids) < ids_max - 1:
                continue
        for index in range(len(ids)):
            dpu.wait(ids[index][0])
            write_index = ids[index][2]
            '''store output vectors '''
            for j in range(ids[index][1]):
                # we can avoid output scaling if we use argmax instead of softmax
                # out_q[write_index] = np.argmax(outputData[0][j] * output_scale)
                out_q[write_index] = np.argmax(outputData[index][0][j])
                write_index += 1
        ids = []


def app(image_dir, threads, model):

    listimage = os.listdir(image_dir)
    runTotal = len(listimage)

    global out_q
    out_q = [None] * runTotal        # results list, shared by all runDPU threads
    g = xir.Graph.deserialize(model)
    subgraphs = get_child_subgraph_dpu(g)
    all_dpu_runners = []
    for i in range(threads):
        all_dpu_runners.append(vart.Runner.create_runner(subgraphs[0], "run"))

    # input scaling
    input_fixpos = all_dpu_runners[0].get_input_tensors()[0].get_attr("fix_point")
    input_scale = 2**input_fixpos
    print('[INFO] fix_point   = ', input_fixpos)
    print('[INFO] input_scale = ', input_scale)

    ''' preprocess images '''
    print(divider)
    print('Pre-processing', runTotal, 'images...')
    img = []
    for i in range(runTotal):
        path = os.path.join(image_dir, listimage[i])
        img.append(preprocess_fn(path, input_scale))

    '''run threads '''
    print('Starting', threads, 'threads...')
    threadAll = []
    start = 0
    for i in range(threads):
        if (i == threads - 1):
            end = len(img)
        else:
            end = start + (len(img) // threads)
        in_q = img[start:end]
        t1 = threading.Thread(target=runDPU, args=(i, start, all_dpu_runners[i], in_q))
        threadAll.append(t1)
        start = end

    time1 = time.time()
    for x in threadAll:
        x.start()
    for x in threadAll:
        x.join()
    time2 = time.time()
    timetotal = time2 - time1

    fps = float(runTotal / timetotal)
    print(divider)
    print("Throughput=%.2f fps, total frames = %.0f, time=%.4f seconds" % (fps, runTotal, timetotal))

    ''' post-processing '''
    classes = ['A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z','del','nothing','space']
    correct = 0
    wrong = 0
    print('Post-processing', len(out_q), 'test images..')
    for i in range(len(out_q)):
        #prediction = classes[out_q[i]]
        #ground_truth, _ = listimage[i].split('.',1)
        prediction = out_q[i]
        path_split = listimage[i].split('_')           # e.g. ['test0174', '2', 'C.png']
        ground_truth = path_split[len(path_split)-2]   # '2'
        ground_truth = int(ground_truth)               # 2
        #print(listimage[i],classes[ground_truth],'=>',classes[prediction])
        if (ground_truth == prediction):
            correct += 1
        else:
            wrong += 1
            print(listimage[i], classes[ground_truth], '=>', classes[prediction])
    accuracy = correct / len(out_q)
    print('Correct:%d, Wrong:%d, Accuracy:%.4f' % (correct, wrong, accuracy))
    print(divider)

    return


# only used if script is run as 'main' from command line
def main():

    # construct the argument parser and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument('-d', '--image_dir', type=str, default='test-images', help='Path to folder of test images. Default is test-images')
    ap.add_argument('-t', '--threads', type=int, default=1, help='Number of threads. Default is 1')
    ap.add_argument('-m', '--model', type=str, default='asl_classifier.xmodel', help='Path of xmodel. Default is asl_classifier.xmodel')

    args = ap.parse_args()

    print('Command line options:')
    print(' --image_dir : ', args.image_dir)
    print(' --threads   : ', args.threads)
    print(' --model     : ', args.model)

    app(args.image_dir, args.threads, args.model)


if __name__ == '__main__':
    main()
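Note on the quantization scheme used above: the DPU consumes int8 tensors, so preprocess_fn multiplies pixel values by 2**fix_point before casting, and output de-quantization is skipped because argmax is unaffected by a positive scale factor. A minimal standalone sketch of that idea (assumptions: NumPy only, no DPU hardware; quantize_input and the fix_point value of 6 are illustrative and not part of the commit):

import numpy as np

def quantize_input(image_rgb, fix_point):
    # Mirror preprocess_fn: scale by the input quantization factor, then cast to int8.
    input_scale = 2 ** fix_point
    return (image_rgb * input_scale).astype(np.int8)

dummy = np.ones((224, 224, 3), dtype=np.uint8)       # stand-in for a decoded RGB frame
print(quantize_input(dummy, fix_point=6).dtype)      # int8, ready for the DPU input buffer

# argmax is invariant to a positive output scale, which is why runDPU can skip
# multiplying by output_scale before picking the predicted class.
logits_int8 = np.array([12, -3, 47, 5], dtype=np.int8)
output_scale = 1 / (2 ** 7)                          # example output fix_point of 7
assert np.argmax(logits_int8) == np.argmax(logits_int8 * output_scale)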

app/asl_detect_live.py → app/asl_classify_live.py  (+64 -23 lines)

@@ -1,3 +1,15 @@
+'''
+Copyright 2023 Avnet Inc.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+'''
 #
 # ASL Classification (live with USB camera)
 #
@@ -47,16 +59,16 @@ def get_video_dev_by_name(src):
         if src in line:
             return dev

-def detect_dpu_architecture():
-    proc = subprocess.run(['xdputil','query'], capture_output=True, encoding='utf8')
-    for line in proc.stdout.splitlines():
-        if 'DPU Arch' in line:
-            # "DPU Arch":"DPUCZDX8G_ISA0_B128_01000020E2012208",
-            #dpu_arch = re.search('DPUCZDX8G_ISA0_(.+?)_', line).group(1)
-            # "DPU Arch":"DPUCZDX8G_ISA1_B2304",
-            #dpu_arch = re.search('DPUCZDX8G_ISA1_(.+?)', line).group(1)
-            dpu_arch = "B2304"
-            return dpu_arch
+# ... work in progress ...
+#def detect_dpu_architecture():
+#    proc = subprocess.run(['xdputil','query'], capture_output=True, encoding='utf8')
+#    for line in proc.stdout.splitlines():
+#        if 'DPU Arch' in line:
+#            # "DPU Arch":"DPUCZDX8G_ISA0_B128_01000020E2012208",
+#            #dpu_arch = re.search('DPUCZDX8G_ISA0_(.+?)_', line).group(1)
+#            # "DPU Arch":"DPUCZDX8G_ISA1_B2304",
+#            #dpu_arch = re.search('DPUCZDX8G_ISA1_(.+?)', line).group(1)
+#            return dpu_arch

 # Parameters (tweaked for video)
 scale = 1.0
@@ -76,7 +88,10 @@ def detect_dpu_architecture():
 input_video = dev_video
 print("[INFO] Input Video : ",input_video)

-displayReference = True
+output_dir = './captured-images'
+
+if not os.path.exists(output_dir):
+    os.mkdir(output_dir)  # create the output directory if it doesn't already exist

 cv2.namedWindow('ASL Classification')

@@ -158,11 +173,22 @@ def TopK(datain, size, filePath):
         print("Top[%d] %d %s" % (i, idx, (line.strip)("\n")))
         idx = idx + 1

-dpu_arch = detect_dpu_architecture()
-print('[INFO] Detected DPU architecture : ',dpu_arch)
+# construct the argument parser and parse the arguments
+ap = argparse.ArgumentParser()
+ap.add_argument('-m', '--model', type=str, default='asl_classifier.xmodel', help='Path of xmodel. Default is asl_classifier.xmodel')
+
+args = ap.parse_args()
+
+print('Command line options:')
+print(' --model : ', args.model)
+
+#dpu_arch = detect_dpu_architecture()
+#print('[INFO] Detected DPU architecture : ',dpu_arch)
+#
+#model_path = './model_1/'+dpu_arch+'/asl_classifier_1.xmodel'
+#print('[INFO] ASL model : ',model_path)
+model_path = args.model

-model_path = './model_1/'+dpu_arch+'/asl_classifier_1.xmodel'
-print('[INFO] ASL model : ',model_path)

 # Create DPU runner
 g = xir.Graph.deserialize(model_path)
@@ -255,11 +281,26 @@ def TopK(datain, size, filePath):
     output = image.copy()

     asl_id = -1
-    #try:
-    if True:
+    try:
+        # 448x448 ROI for classification
+        #y1 = (16)
+        #y2 = (16+448)
+        #x1 = (96)
+        #x2 = (96+448)
+        #roi_img = output[ y1:y2, x1:x2, : ]
+        #roi_img = cv2.resize(asl_img,(224,224),interpolation=cv2.INTER_CUBIC)
+
+        # 224x224 ROI for classification
+        y1 = (128)
+        y2 = (128+224)
+        x1 = (208)
+        x2 = (208+224)
+        roi_img = output[ y1:y2, x1:x2, : ]
+
+        cv2.rectangle(output, (x1,y1), (x2,y2), (0, 255, 0), 2)
+
         # ASL pre-processing
-        asl_img = cv2.resize(image,(224,224),interpolation=cv2.INTER_CUBIC)
-        asl_img = cv2.cvtColor(asl_img, cv2.COLOR_BGR2RGB)
+        asl_img = cv2.cvtColor(roi_img, cv2.COLOR_BGR2RGB)
         asl_img = asl_img*input_scale
         asl_img = asl_img.astype(np.int8)
         #cv2.imshow('asl_img',asl_img)
@@ -298,8 +339,8 @@ def TopK(datain, size, filePath):
         asl_text = '['+str(asl_id)+']='+asl_sign
         cv2.putText(output,asl_text,(10,30),text_fontType,text_fontSize,text_color,text_lineSize,text_lineType)

-    #except:
-    #    print("ERROR : Exception occurred during ASL classification ...")
+    except:
+        print("ERROR : Exception occurred during ASL classification ...")


     matching_text = ("[%04d] [%02d]=%s"%(frame_count,asl_id,asl_sign))
@@ -325,7 +366,7 @@ def TopK(datain, size, filePath):
         filename = ("frame%04d_asl%02d.tif"%(frame_count,asl_id))

         print("Capturing ",filename," ...")
-        cv2.imwrite(os.path.join(output_dir,filename),asl_img)
+        cv2.imwrite(os.path.join(output_dir,filename),roi_img)

     if key == 115:  # 's'
         step = True
@@ -337,7 +378,7 @@ def TopK(datain, size, filePath):
         step = False
         pause = False

-    if key == 27:
+    if key == 27 or key == 113:  # ESC or 'q'
         break

     # Update the real-time FPS counter
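One observation on the new ROI: the hard-coded offsets (x1=208, y1=128) center a 224x224 crop in a 640x480 frame, which is presumably the USB capture resolution (the resolution itself is not visible in this diff). A quick sketch of that arithmetic, with the frame size stated as an assumption:

# Assumption: the camera is captured at 640x480; this is not shown in the diff above.
frame_w, frame_h = 640, 480
roi = 224

x1 = (frame_w - roi) // 2        # (640 - 224) // 2 = 208
y1 = (frame_h - roi) // 2        # (480 - 224) // 2 = 128
x2, y2 = x1 + roi, y1 + roi      # 432, 352 -- i.e. 208+224 and 128+224

print(x1, y1, x2, y2)            # matches the hard-coded ROI in asl_classify_live.py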
