Skip to content

Commit ad7de6a

Browse files
jianwensong authored and fracape committed
[feat] kmac calculation for yolox_darknet53
1 parent 2023f67 commit ad7de6a

File tree

7 files changed

+172
-16
lines changed

7 files changed

+172
-16
lines changed

compressai_vision/model_wrappers/base_wrapper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ def forward(self, x, input_map_function):
6565
"""Complete the downstream task with end-to-end manner all the way from the input"""
6666
raise NotImplementedError
6767

68+
def calc_complexity(self, mode, input, data):
69+
"""Computes the MACs Complexity of the model"""
70+
raise NotImplementedError
71+
6872
@property
6973
def cfg(self):
7074
return None

compressai_vision/model_wrappers/detectron2.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636

3737
from compressai_vision.registry import register_vision_model
3838

39+
from ..utils.measure_complexity import (
40+
calc_complexity_nn_part1_plyr,
41+
calc_complexity_nn_part2_plyr,
42+
)
3943
from .base_wrapper import BaseWrapper
4044
from .intconv2d import IntConv2d, IntTransposedConv2d
4145

@@ -559,6 +563,17 @@ def forward(self, x):
559563
def cfg(self):
560564
return self._cfg
561565

566+
def calc_complexity(self, mode, input, data=None):
567+
"""Computes the MACs Complexity of the model"""
568+
if mode == "nn_part_1":
569+
return calc_complexity_nn_part1_plyr(self, input)
570+
elif mode == "nn_part_2":
571+
return calc_complexity_nn_part2_plyr(self, input, data)
572+
else:
573+
raise NotImplementedError(
574+
f"Complexity calculation for {mode} not implemented for Detectron2"
575+
)
576+
562577

563578
@register_vision_model("faster_rcnn_X_101_32x8d_FPN_3x")
564579
class faster_rcnn_X_101_32x8d_FPN_3x(Rcnn_R_50_X_101_FPN):

compressai_vision/model_wrappers/jde.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636

3737
from compressai_vision.registry import register_vision_model
3838

39+
from ..utils.measure_complexity import (
40+
calc_complexity_nn_part1_dn53,
41+
calc_complexity_nn_part2_dn53,
42+
)
3943
from .base_wrapper import BaseWrapper
4044

4145
__all__ = [
@@ -489,3 +493,14 @@ def forward(self, x):
489493
online_ids.append(tid)
490494

491495
return {"tlwhs": online_tlwhs, "ids": online_ids}
496+
497+
def calc_complexity(self, mode, input, data=None):
498+
"""Computes the MACs Complexity of the model"""
499+
if mode == "nn_part_1":
500+
return calc_complexity_nn_part1_dn53(self, input)
501+
elif mode == "nn_part_2":
502+
return calc_complexity_nn_part2_dn53(self, input)
503+
else:
504+
raise NotImplementedError(
505+
f"Complexity calculation for {mode} not implemented for JDE"
506+
)

compressai_vision/model_wrappers/yolox.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636

3737
from compressai_vision.registry import register_vision_model
3838

39+
from ..utils.measure_complexity import (
40+
calc_complexity_nn_part1_yolox,
41+
calc_complexity_nn_part2_yolox,
42+
)
3943
from .base_wrapper import BaseWrapper
4044
from .split_squeezes import squeeze_yolox
4145

@@ -326,3 +330,14 @@ def forward(self, x):
326330
)
327331

328332
return pred
333+
334+
def calc_complexity(self, mode, input, data=None):
335+
"""Computes the MACs Complexity of the model"""
336+
if mode == "nn_part_1":
337+
return calc_complexity_nn_part1_yolox(self, input)
338+
elif mode == "nn_part_2":
339+
return calc_complexity_nn_part2_yolox(self, input)
340+
else:
341+
raise NotImplementedError(
342+
f"Complexity calculation for {mode} not implemented for YOLOX-Darknet53"
343+
)

compressai_vision/pipelines/split_inference/image_split_inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __call__(
111111
break
112112

113113
if self.is_mac_calculation:
114-
macs, pixels = calc_complexity_nn_part1_plyr(vision_model, d)
114+
macs, pixels = vision_model.calc_complexity("nn_part_1", d)
115115
self.acc_kmac_and_pixels_info("nn_part_1", macs, pixels)
116116

117117
start = time_measure()
@@ -200,8 +200,8 @@ def __call__(
200200
dec_features["file_name"] = d[0]["file_name"]
201201
dec_features["file_origin"] = d[0]["file_name"]
202202
if self.is_mac_calculation:
203-
macs, pixels = calc_complexity_nn_part2_plyr(
204-
vision_model, dec_features["data"], dec_features
203+
macs, pixels = vision_model.calc_complexity(
204+
"nn_part_2", dec_features, dec_features["data"]
205205
)
206206
self.acc_kmac_and_pixels_info("nn_part_2", macs, pixels)
207207

compressai_vision/pipelines/split_inference/video_split_inference.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,7 @@ def __call__(
140140
break
141141

142142
if self.is_mac_calculation and e == self._codec_skip_n_frames:
143-
if hasattr(vision_model, "darknet"): # for jde
144-
kmacs, pixels = calc_complexity_nn_part1_dn53(vision_model, d)
145-
else: # for detectron2
146-
kmacs, pixels = calc_complexity_nn_part1_plyr(vision_model, d)
143+
kmacs, pixels = vision_model.calc_complexity("nn_part_1", d)
147144
self.add_kmac_and_pixels_info("nn_part_1", kmacs, pixels)
148145

149146
start = time_measure()
@@ -299,14 +296,9 @@ def __call__(
299296
) # Assuming one qp will be used
300297

301298
if self.is_mac_calculation and e == 0:
302-
if hasattr(vision_model, "darknet"): # for jde
303-
kmacs, pixels = calc_complexity_nn_part2_dn53(
304-
vision_model, dec_features
305-
)
306-
else: # for detectron2
307-
kmacs, pixels = calc_complexity_nn_part2_plyr(
308-
vision_model, data, dec_features
309-
)
299+
kmacs, pixels = vision_model.calc_complexity(
300+
"nn_part_2", dec_features, data
301+
)
310302
self.add_kmac_and_pixels_info("nn_part_2", kmacs, pixels)
311303

312304
start = time_measure()

compressai_vision/utils/measure_complexity.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import reduce
44

55
import torch
6+
import torch.nn as nn
67

78
from ptflops import get_model_complexity_info
89

@@ -72,7 +73,7 @@ def calc_complexity_nn_part1_plyr(vision_model, img):
7273
return kmacs, pixels
7374

7475

75-
def calc_complexity_nn_part2_plyr(vision_model, data, dec_features):
76+
def calc_complexity_nn_part2_plyr(vision_model, dec_features, data):
7677
if isinstance(data[0], list): # image task
7778
data = {k: v[0] for k, v in data.items()}
7879

@@ -147,6 +148,120 @@ def get_downsampled_shape(h, w, ratio):
147148
return h, w
148149

149150

151+
class YoloxPart1(nn.Module):
152+
def __init__(self, vision_model, split_id):
153+
super().__init__()
154+
self.backbone = vision_model.backbone
155+
self.split_id = split_id
156+
self.squeeze_at_split_enabled = vision_model.squeeze_at_split_enabled
157+
if self.squeeze_at_split_enabled:
158+
self.squeeze_model = vision_model.squeeze_model
159+
160+
def forward(self, x):
161+
if self.split_id == "l13":
162+
y = self.backbone.stem(x)
163+
y = self.backbone.dark2(y)
164+
y = self.backbone.dark3[0](y)
165+
if self.squeeze_at_split_enabled:
166+
y = self.squeeze_model.squeeze_(y)
167+
elif self.split_id == "l37":
168+
y = self.backbone.stem(x)
169+
y = self.backbone.dark2(y)
170+
y = self.backbone.dark3(y)
171+
return y
172+
173+
174+
class YoloxPart2(nn.Module):
175+
def __init__(self, vision_model, split_id):
176+
super().__init__()
177+
self.backbone = vision_model.backbone
178+
self.out1_cbl = vision_model.yolo_fpn.out1_cbl
179+
self.out1 = vision_model.yolo_fpn.out1
180+
self.out2_cbl = vision_model.yolo_fpn.out2_cbl
181+
self.out2 = vision_model.yolo_fpn.out2
182+
self.upsample = vision_model.yolo_fpn.upsample
183+
self.head = vision_model.head
184+
self.split_id = split_id
185+
self.squeeze_at_split_enabled = vision_model.squeeze_at_split_enabled
186+
if self.squeeze_at_split_enabled:
187+
self.squeeze_model = vision_model.squeeze_model
188+
# self.postprocess = vision_model.postprocess # Not needed for MAC calc
189+
190+
def forward(self, x):
191+
y = x
192+
if self.split_id == "l13":
193+
if self.squeeze_at_split_enabled:
194+
y = self.squeeze_model.expand_(y)
195+
for proc_module in self.backbone.dark3[1:]:
196+
y = proc_module(y)
197+
198+
fp_lvl2 = y
199+
fp_lvl1 = self.backbone.dark4(fp_lvl2)
200+
fp_lvl0 = self.backbone.dark5(fp_lvl1)
201+
202+
# yolo branch 1
203+
b1_in = self.out1_cbl(fp_lvl0)
204+
b1_in = self.upsample(b1_in)
205+
b1_in = torch.cat([b1_in, fp_lvl1], 1)
206+
fp_lvl1 = self.out1(b1_in)
207+
208+
# yolo branch 2
209+
b2_in = self.out2_cbl(fp_lvl1)
210+
b2_in = self.upsample(b2_in)
211+
b2_in = torch.cat([b2_in, fp_lvl2], 1)
212+
fp_lvl2 = self.out2(b2_in)
213+
214+
outputs = self.head((fp_lvl2, fp_lvl1, fp_lvl0))
215+
return outputs
216+
217+
218+
def calc_complexity_nn_part1_yolox(vision_model, img):
219+
device = torch.device(vision_model.device)
220+
img = img[0]["image"].unsqueeze(0).to(device)
221+
222+
partial_model = YoloxPart1(vision_model, vision_model.split_id)
223+
224+
C, H, W = img.shape[1:]
225+
226+
kmacs, _ = measure_mac(
227+
partial_model=partial_model,
228+
input_res=(C, H, W),
229+
input_constructor=None,
230+
)
231+
232+
pixels = reduce(operator.mul, [p_size for p_size in img.shape])
233+
return kmacs, pixels
234+
235+
236+
def calc_complexity_nn_part2_yolox(vision_model, dec_features):
237+
assert "data" in dec_features
238+
239+
x_data = dec_features["data"]
240+
241+
x_data = {
242+
k: (v[0] if isinstance(x_data[0], list) else v).to(vision_model.device)
243+
for k, v in zip(vision_model.split_layer_list, x_data.values())
244+
}
245+
246+
input_tensor = x_data[vision_model.split_id]
247+
248+
if input_tensor.dim() == 3:
249+
input_tensor = input_tensor.unsqueeze(0)
250+
251+
C, H, W = input_tensor.shape[1:]
252+
partial_model = YoloxPart2(vision_model, vision_model.split_id)
253+
254+
kmacs, _ = measure_mac(
255+
partial_model=partial_model,
256+
input_res=(C, H, W),
257+
input_constructor=None,
258+
)
259+
260+
pixels = reduce(operator.mul, input_tensor.shape)
261+
262+
return kmacs, pixels
263+
264+
150265
def prepare_proposal_input_fpn(resolutions):
151266
b, c, h, w = resolutions[1]
152267
resized_img = resolutions[0]

0 commit comments

Comments (0)