Skip to content

Commit 7751ed4

Browse files
vspenubarthipytorchmergebot
authored andcommitted
[ao] Added generate report capability to ModelReport class
Summary: The ModelReport class in model_report.py combines the functionality of the detectors and the ModelReportObserver. It creates an end-to-end system where a user can pass in a prepared Graph Model to insert the ModelReportObservers, then after the user callibrates their model, the callibrated model can then be used by the ModelReport class to generate reports based on what the user wished to gather information on. This contains the implementation and the tests for the generate_report method which is used on a callibrated fx model to generate reports based on data collected by the inserted observers during the callibration phase and also potentially remove those observers if desired. Test Plan: python test/test_quantization.py TestFxModelReportClass Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: pytorch#79792 Approved by: https://github.com/HDCharles
1 parent 1cff414 commit 7751ed4

File tree

3 files changed

+194
-12
lines changed

3 files changed

+194
-12
lines changed

test/quantization/fx/test_model_report_fx.py

+130
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,33 @@ def forward(self, x):
812812

813813
class TestFxModelReportClass(QuantizationTestCase):
814814

815+
# example model to use for tests
816+
class ThreeOps(nn.Module):
817+
def __init__(self):
818+
super().__init__()
819+
self.linear = nn.Linear(3, 3)
820+
self.bn = nn.BatchNorm2d(3)
821+
self.relu = nn.ReLU()
822+
823+
def forward(self, x):
824+
x = self.linear(x)
825+
x = self.bn(x)
826+
x = self.relu(x)
827+
return x
828+
829+
class TwoThreeOps(nn.Module):
830+
def __init__(self):
831+
super().__init__()
832+
self.block1 = TestFxModelReportClass.ThreeOps()
833+
self.block2 = TestFxModelReportClass.ThreeOps()
834+
835+
def forward(self, x):
836+
x = self.block1(x)
837+
y = self.block2(x)
838+
z = x + y
839+
z = F.relu(z)
840+
return z
841+
815842
@skipIfNoFBGEMM
816843
def test_constructor(self):
817844
"""
@@ -945,3 +972,106 @@ def forward(self, x):
945972
# ensure that we can prepare for callibration only once
946973
with self.assertRaises(ValueError):
947974
prepared_for_callibrate_model = model_report.prepare_detailed_calibration(model_prep)
975+
976+
977+
def get_module_and_graph_cnts(self, callibrated_fx_module):
978+
r"""
979+
Calculates number of ModelReportObserver modules in the model as well as the graph structure.
980+
981+
Returns a tuple of two elements:
982+
int: The number of ModelReportObservers found in the model
983+
int: The number of model_report nodes found in the graph
984+
"""
985+
# get the number of observers stored as modules
986+
modules_observer_cnt = 0
987+
for fqn, module in callibrated_fx_module.named_modules():
988+
if isinstance(module, ModelReportObserver):
989+
modules_observer_cnt += 1
990+
991+
# get number of observers in the graph
992+
model_report_str_check = "model_report"
993+
graph_observer_cnt = 0
994+
# also make sure arguments for observers in the graph are proper
995+
for node in callibrated_fx_module.graph.nodes:
996+
# not all node targets are strings, so check
997+
if isinstance(node.target, str) and model_report_str_check in node.target:
998+
# increment if we found a graph observer
999+
graph_observer_cnt += 1
1000+
1001+
return (modules_observer_cnt, graph_observer_cnt)
1002+
1003+
@skipIfNoFBGEMM
1004+
def test_generate_report(self):
1005+
"""
1006+
Tests model_report.generate_model_report to ensure report generation
1007+
1008+
Specifically looks at:
1009+
- Whether correct number of reports are being generated
1010+
- Whether observers are being properly removed if specified
1011+
- Whether correct blocking from generating report twice if obs removed
1012+
"""
1013+
1014+
with override_quantized_engine('fbgemm'):
1015+
# set the backend for this test
1016+
torch.backends.quantized.engine = "fbgemm"
1017+
1018+
# check whether the correct number of reports are being generated
1019+
filled_detector_set = set([DynamicStaticDetector(), PerChannelDetector(torch.backends.quantized.engine)])
1020+
single_detector_set = set([DynamicStaticDetector()])
1021+
1022+
# initialize one with filled detector
1023+
model_report_full = ModelReport(filled_detector_set)
1024+
# initialize another with a single detector set
1025+
model_report_single = ModelReport(single_detector_set)
1026+
1027+
# prepare and callibrate two different instances of same model
1028+
# prepare the model
1029+
model_full = TestFxModelReportClass.TwoThreeOps()
1030+
model_single = TestFxModelReportClass.TwoThreeOps()
1031+
example_input = torch.randn(1, 3, 3, 3)
1032+
current_backend = torch.backends.quantized.engine
1033+
q_config_mapping = QConfigMapping()
1034+
q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(torch.backends.quantized.engine))
1035+
1036+
model_prep_full = quantize_fx.prepare_fx(model_full, q_config_mapping, example_input)
1037+
model_prep_single = quantize_fx.prepare_fx(model_single, q_config_mapping, example_input)
1038+
1039+
# prepare the models for callibration
1040+
prepared_for_callibrate_model_full = model_report_full.prepare_detailed_calibration(model_prep_full)
1041+
prepared_for_callibrate_model_single = model_report_single.prepare_detailed_calibration(model_prep_single)
1042+
1043+
# now callibrate the two models
1044+
num_iterations = 10
1045+
for i in range(num_iterations):
1046+
example_input = torch.tensor(torch.randint(100, (1, 3, 3, 3)), dtype=torch.float)
1047+
prepared_for_callibrate_model_full(example_input)
1048+
prepared_for_callibrate_model_single(example_input)
1049+
1050+
# now generate the reports
1051+
model_full_report = model_report_full.generate_model_report(
1052+
prepared_for_callibrate_model_full, True
1053+
)
1054+
model_single_report = model_report_single.generate_model_report(prepared_for_callibrate_model_single, False)
1055+
1056+
# check that sizes are appropriate
1057+
self.assertEqual(len(model_full_report), len(filled_detector_set))
1058+
self.assertEqual(len(model_single_report), len(single_detector_set))
1059+
1060+
# make sure observers are being properly removed for full report since we put flag in
1061+
modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_full)
1062+
self.assertEqual(modules_observer_cnt, 0) # assert no more observer modules
1063+
self.assertEqual(graph_observer_cnt, 0) # assert no more observer nodes in graph
1064+
1065+
# make sure observers aren't being removed for single report since not specified
1066+
modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_single)
1067+
self.assertNotEqual(modules_observer_cnt, 0)
1068+
self.assertNotEqual(graph_observer_cnt, 0)
1069+
1070+
# make sure error when try to rerun report generation for full report but not single report
1071+
with self.assertRaises(Exception):
1072+
model_full_report = model_report_full.generate_model_report(
1073+
prepared_for_callibrate_model_full, False
1074+
)
1075+
1076+
# make sure we don't run into error for single report
1077+
model_single_report = model_report_single.generate_model_report(prepared_for_callibrate_model_single, False)

torch/ao/quantization/fx/_model_report/detector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class PerChannelDetector(DetectorBase):
7373
"onednn": set([nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nnqat.Linear, nnqat.Conv1d, nnqat.Conv2d, nnqat.Conv3d]),
7474
}
7575

76-
def __init__(self, backend=torch.backends.quantized.engine):
76+
def __init__(self, backend: str = torch.backends.quantized.engine):
7777
super().__init__()
7878

7979
# store the backend information

torch/ao/quantization/fx/_model_report/model_report.py

+63-11
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,28 @@ def __init__(self, desired_report_detectors: Set[DetectorBase]):
3232

3333
# keep the reports private so they can't be modified
3434
self._desired_report_detectors = desired_report_detectors
35-
self._desired_reports = set([detector.get_detector_name() for detector in desired_report_detectors])
35+
self._desired_detector_names = set([detector.get_detector_name() for detector in desired_report_detectors])
3636

3737
# keep a mapping of desired reports to observers of interest
3838
# this is to get the readings, and to remove them, can create a large set
3939
# this set can then be used to traverse the graph and remove added observers
40-
self._report_name_to_observer_fqns: Dict[str, Set[str]] = {}
40+
self._detector_name_to_observer_fqns: Dict[str, Set[str]] = {}
4141

4242
# initialize each report to have empty set of observers of interest
43-
for desired_report in self._desired_reports:
44-
self._report_name_to_observer_fqns[desired_report] = set([])
43+
for desired_report in self._desired_detector_names:
44+
self._detector_name_to_observer_fqns[desired_report] = set([])
4545

46-
# flags to ensure that we can only prepare once
46+
# flags to ensure that we can only prepare and generate report once
4747
self._prepared_flag = False
48+
self._removed_observers = False
4849

4950
def get_desired_reports_names(self) -> Set[str]:
5051
""" Returns a copy of the desired reports for viewing """
51-
return self._desired_reports.copy()
52+
return self._desired_detector_names.copy()
5253

5354
def get_observers_of_interest(self) -> Dict[str, Set[str]]:
5455
""" Returns a copy of the observers of interest for viewing """
55-
return self._report_name_to_observer_fqns.copy()
56+
return self._detector_name_to_observer_fqns.copy()
5657

5758
def prepare_detailed_calibration(self, prepared_fx_model: GraphModule) -> GraphModule:
5859
r"""
@@ -61,7 +62,7 @@ def prepare_detailed_calibration(self, prepared_fx_model: GraphModule) -> GraphM
6162
6263
Each observer is inserted based on the desired_reports into the relavent locations
6364
64-
Right now, each report in self._desired_reports has independent insertions
65+
Right now, each report in self._desired_detector_names has independent insertions
6566
However, if a module already has a Observer of the same type, the insertion will not occur
6667
This is because all of the same type of Observer collect same information, so redundant
6768
@@ -84,7 +85,7 @@ def prepare_detailed_calibration(self, prepared_fx_model: GraphModule) -> GraphM
8485
# map each insert point to the observer to use
8586
insert_observers_fqns.update(obs_fqn_to_info)
8687
# update the set of observers this report cares about
87-
self._report_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys())
88+
self._detector_name_to_observer_fqns[detector.get_detector_name()] = set(obs_fqn_to_info.keys())
8889

8990
# now insert all the observers at their desired locations
9091
for observer_fqn in insert_observers_fqns:
@@ -142,7 +143,20 @@ def _get_node_from_fqn(self, fx_model: GraphModule, node_fqn: str) -> torch.fx.n
142143
143144
Returns the Node object of the given node_fqn otherwise returns None
144145
"""
145-
pass
146+
node_to_return = None
147+
for node in fx_model.graph.nodes:
148+
# if the target matches the fqn, it's the node we are looking for
149+
if node.target == node_fqn:
150+
node_to_return = node
151+
break
152+
153+
if node_to_return is None:
154+
raise ValueError("The node_fqn is was not found within the module.")
155+
156+
# assert for MyPy
157+
assert isinstance(node_to_return, torch.fx.node.Node)
158+
159+
return node_to_return
146160

147161
def generate_model_report(
148162
self, calibrated_fx_model: GraphModule, remove_inserted_observers: bool
@@ -161,4 +175,42 @@ def generate_model_report(
161175
The textual summary of that report information
162176
A dictionary containing relavent statistics or information for that report
163177
"""
164-
pass
178+
# if we already removed the observers, we cannot generate report
179+
if self._removed_observers:
180+
raise Exception("Cannot generate report on model you already removed observers from")
181+
182+
# keep track of all the reports of interest and their outputs
183+
reports_of_interest = {}
184+
185+
for detector in self._desired_report_detectors:
186+
# generate the individual report for the detector
187+
report_output = detector.generate_detector_report(calibrated_fx_model)
188+
reports_of_interest[detector.get_detector_name()] = report_output
189+
190+
# if user wishes to remove inserted observers, go ahead and remove
191+
if remove_inserted_observers:
192+
self._removed_observers = True
193+
# get the set of all Observers inserted by this instance of ModelReport
194+
all_observers_of_interest: Set[str] = set([])
195+
for desired_report in self._detector_name_to_observer_fqns:
196+
observers_of_interest = self._detector_name_to_observer_fqns[desired_report]
197+
all_observers_of_interest.update(observers_of_interest)
198+
199+
# go through all_observers_of_interest and remove them from the graph and model
200+
for observer_fqn in all_observers_of_interest:
201+
# remove the observer from the model
202+
calibrated_fx_model.delete_submodule(observer_fqn)
203+
204+
# remove the observer from the graph structure
205+
node_obj = self._get_node_from_fqn(calibrated_fx_model, observer_fqn)
206+
207+
if node_obj:
208+
calibrated_fx_model.graph.erase_node(node_obj)
209+
else:
210+
raise ValueError("Node no longer exists in GraphModule structure")
211+
212+
# remember to recompile the model
213+
calibrated_fx_model.recompile()
214+
215+
# return the reports of interest
216+
return reports_of_interest

0 commit comments

Comments
 (0)