@@ -812,6 +812,33 @@ def forward(self, x):
812
812
813
813
class TestFxModelReportClass (QuantizationTestCase ):
814
814
815
+ # example model to use for tests
816
class ThreeOps(nn.Module):
    """Small example model: a linear layer, then 2d batch norm, then ReLU."""

    def __init__(self):
        super().__init__()
        # Submodules are created in the same order that forward applies them.
        self.linear = nn.Linear(3, 3)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through linear, batch norm and ReLU, in that order."""
        normalized = self.bn(self.linear(x))
        return self.relu(normalized)
828
+
829
class TwoThreeOps(nn.Module):
    """Example model: two ThreeOps blocks in sequence, joined by a skip-style add."""

    def __init__(self):
        super().__init__()
        self.block1 = TestFxModelReportClass.ThreeOps()
        self.block2 = TestFxModelReportClass.ThreeOps()

    def forward(self, x):
        """Feed ``x`` through block1, feed that result through block2, and
        return the ReLU of the sum of the two block outputs."""
        first = self.block1(x)
        second = self.block2(first)
        return F.relu(first + second)
841
+
815
842
@skipIfNoFBGEMM
816
843
def test_constructor (self ):
817
844
"""
@@ -945,3 +972,106 @@ def forward(self, x):
945
972
# ensure that we can prepare for calibration only once
946
973
with self .assertRaises (ValueError ):
947
974
prepared_for_callibrate_model = model_report .prepare_detailed_calibration (model_prep )
975
+
976
+
977
def get_module_and_graph_cnts(self, callibrated_fx_module):
    r"""
    Counts the ModelReportObserver modules in the model and the model_report nodes in its graph.

    Args:
        callibrated_fx_module: the module to inspect; it is queried through
            ``named_modules()`` and ``graph.nodes``

    Returns a tuple of two elements:
        int: The number of ModelReportObservers found in the model
        int: The number of model_report nodes found in the graph
    """
    # observers that are registered as (sub)modules of the model
    modules_observer_cnt = sum(
        1
        for _, submodule in callibrated_fx_module.named_modules()
        if isinstance(submodule, ModelReportObserver)
    )

    # graph nodes whose target mentions model_report; node targets are not
    # always strings, so non-string targets are skipped before the substring test
    marker = "model_report"
    graph_observer_cnt = sum(
        1
        for node in callibrated_fx_module.graph.nodes
        if isinstance(node.target, str) and marker in node.target
    )

    return (modules_observer_cnt, graph_observer_cnt)
1002
+
1003
@skipIfNoFBGEMM
def test_generate_report(self):
    """
    Tests model_report.generate_model_report to ensure report generation

    Specifically looks at:
    - Whether correct number of reports are being generated
    - Whether observers are being properly removed if specified
    - Whether correct blocking from generating report twice if obs removed
    """

    with override_quantized_engine('fbgemm'):
        # set the backend for this test
        torch.backends.quantized.engine = "fbgemm"
        # read the engine once and reuse it everywhere below
        current_backend = torch.backends.quantized.engine

        # check whether the correct number of reports are being generated
        filled_detector_set = {DynamicStaticDetector(), PerChannelDetector(current_backend)}
        single_detector_set = {DynamicStaticDetector()}

        # initialize one with filled detector
        model_report_full = ModelReport(filled_detector_set)
        # initialize another with a single detector set
        model_report_single = ModelReport(single_detector_set)

        # prepare and calibrate two different instances of same model
        model_full = TestFxModelReportClass.TwoThreeOps()
        model_single = TestFxModelReportClass.TwoThreeOps()
        example_input = torch.randn(1, 3, 3, 3)
        q_config_mapping = QConfigMapping()
        q_config_mapping.set_global(torch.ao.quantization.get_default_qconfig(current_backend))

        model_prep_full = quantize_fx.prepare_fx(model_full, q_config_mapping, example_input)
        model_prep_single = quantize_fx.prepare_fx(model_single, q_config_mapping, example_input)

        # prepare the models for calibration
        prepared_for_callibrate_model_full = model_report_full.prepare_detailed_calibration(model_prep_full)
        prepared_for_callibrate_model_single = model_report_single.prepare_detailed_calibration(model_prep_single)

        # now calibrate the two models
        num_iterations = 10
        for _ in range(num_iterations):
            # convert the integer samples directly; wrapping an existing tensor
            # in torch.tensor(...) copy-constructs it and emits a UserWarning
            example_input = torch.randint(100, (1, 3, 3, 3)).to(torch.float)
            prepared_for_callibrate_model_full(example_input)
            prepared_for_callibrate_model_single(example_input)

        # now generate the reports; the second argument is the flag for
        # removing the observers (set for the full report only)
        model_full_report = model_report_full.generate_model_report(
            prepared_for_callibrate_model_full, True
        )
        model_single_report = model_report_single.generate_model_report(prepared_for_callibrate_model_single, False)

        # check that sizes are appropriate
        self.assertEqual(len(model_full_report), len(filled_detector_set))
        self.assertEqual(len(model_single_report), len(single_detector_set))

        # make sure observers are being properly removed for full report since we put flag in
        modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_full)
        self.assertEqual(modules_observer_cnt, 0)  # assert no more observer modules
        self.assertEqual(graph_observer_cnt, 0)  # assert no more observer nodes in graph

        # make sure observers aren't being removed for single report since not specified
        modules_observer_cnt, graph_observer_cnt = self.get_module_and_graph_cnts(prepared_for_callibrate_model_single)
        self.assertNotEqual(modules_observer_cnt, 0)
        self.assertNotEqual(graph_observer_cnt, 0)

        # make sure error when try to rerun report generation for full report but not single report
        with self.assertRaises(Exception):
            model_report_full.generate_model_report(
                prepared_for_callibrate_model_full, False
            )

        # make sure we don't run into error for single report, and that the
        # regenerated report still has one entry per detector
        model_single_report = model_report_single.generate_model_report(prepared_for_callibrate_model_single, False)
        self.assertEqual(len(model_single_report), len(single_detector_set))
0 commit comments