[quant] Input-Weight Equalization - allow logical evaluation (#61603)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61603

Test Plan: Imported from OSS

Reviewed By: supriyar

Differential Revision: D29686878

fbshipit-source-id: 67ca4cab98b3d592ff2bb8db86499789b85bd582
diff --git a/test/quantization/fx/test_equalize_fx.py b/test/quantization/fx/test_equalize_fx.py
index 1be3a99..7c17d12 100644
--- a/test/quantization/fx/test_equalize_fx.py
+++ b/test/quantization/fx/test_equalize_fx.py
@@ -728,35 +728,28 @@
             ns.call_method('dequantize')
         ]
 
-        tests = [(SingleLayerLinearModel, 2, linear_node_list),
-                 (LinearAddModel, 2, linearAdd_node_list),
-                 (TwoLayerLinearModel, 2, linear2_node_list),
-                 (SingleLayerFunctionalLinearModel, 2, functionalLinear_node_list),
-                 (FunctionalLinearAddModel, 2, functionalLinearAdd_node_list),
-                 (TwoLayerFunctionalLinearModel, 2, functionalLinear2_node_list),
-                 (LinearReluModel, 2, linearRelu_node_list),
-                 (LinearReluLinearModel, 2, linearReluLinear_node_list),
-                 (FunctionalLinearReluModel, 2, functionalLinearRelu_node_list),
-                 (FunctionalLinearReluLinearModel, 2, functionalLinearReluLinear_node_list),
-                 (ConvModel, 4, conv_node_list),
-                 (TwoLayerConvModel, 4, conv2_node_list),
-                 (SingleLayerFunctionalConvModel, 4, functionalConv_node_list),
-                 (TwoLayerFunctionalConvModel, 4, functionalConv2_node_list),
-                 (ConvReluModel, 4, convRelu_node_list),
-                 (ConvReluConvModel, 4, convReluConv_node_list),
-                 (FunctionalConvReluModel, 4, functionalConvRelu_node_list),
-                 (FunctionalConvReluConvModel, 4, functionalConvReluConv_node_list)]
+        tests = [(SingleLayerLinearModel, linear_node_list),
+                 (LinearAddModel, linearAdd_node_list),
+                 (TwoLayerLinearModel, linear2_node_list),
+                 (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
+                 (FunctionalLinearAddModel, functionalLinearAdd_node_list),
+                 (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
+                 (LinearReluModel, linearRelu_node_list),
+                 (LinearReluLinearModel, linearReluLinear_node_list),
+                 (FunctionalLinearReluModel, functionalLinearRelu_node_list),
+                 (FunctionalLinearReluLinearModel, functionalLinearReluLinear_node_list),
+                 (ConvModel, conv_node_list),
+                 (TwoLayerConvModel, conv2_node_list),
+                 (SingleLayerFunctionalConvModel, functionalConv_node_list),
+                 (TwoLayerFunctionalConvModel, functionalConv2_node_list),
+                 (ConvReluModel, convRelu_node_list),
+                 (ConvReluConvModel, convReluConv_node_list),
+                 (FunctionalConvReluModel, functionalConvRelu_node_list),
+                 (FunctionalConvReluConvModel, functionalConvReluConv_node_list)]
 
-        for (M, ndim, node_list) in tests:
+        for (M, node_list) in tests:
             m = M().eval()
-
-            if ndim == 2:
-                x = torch.rand((5, 5))
-            elif ndim == 4:
-                x = torch.rand((16, 3, 224, 224))
-
             prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
-            prepared(x)
             equalized_quantized_model = convert_fx(prepared)
 
             # Check the order of nodes in the graph
diff --git a/torch/quantization/fx/_equalize.py b/torch/quantization/fx/_equalize.py
index b2db5e2..231f8df 100644
--- a/torch/quantization/fx/_equalize.py
+++ b/torch/quantization/fx/_equalize.py
@@ -66,7 +66,7 @@
                                                   quant_max=quant_max,
                                                   factory_kwargs=factory_kwargs)
 
-        self.equalization_scale = torch.empty(0)
+        self.equalization_scale = torch.tensor(1)
         self.equalization_shape: List[int] = []
 
     def forward(self, x_orig):
@@ -85,17 +85,19 @@
     def set_equalization_scale(self, equalization_scale):
         # Reshape the equalization scale along axis=1 so that it can be
         # multiplied with the input along axis=1
+        if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
+            return
         self.equalization_scale = torch.reshape(equalization_scale, self.equalization_shape)
 
     def calculate_scaled_minmax(self):
         r""" Returns the scaled min/max inputs
         """
-        if self.equalization_scale.nelement() == 0:
+        if self.equalization_scale.nelement() == 1 and self.equalization_scale == torch.tensor(1):
             warnings.warn(
-                "Must call calculate_scale before calling calculate_qparams.\
-                Returning default min and max input."
+                "Must call calculate_equalization_scale before calling calculate_scaled_minmax. " +
+                "Will not scale the next quantization observer."
             )
-            return torch.tensor([0]), torch.tensor([0])
+            return None, None
 
         # Calculate qparams for the scaled min/max inputs
         # Scale the input by the equalization scale located at the same column
@@ -145,7 +147,7 @@
                                                        quant_max=quant_max,
                                                        factory_kwargs=factory_kwargs)
 
-        self.equalization_scale = torch.empty(0)
+        self.equalization_scale = torch.tensor(1)
 
     def forward(self, w_orig):
         if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
@@ -176,12 +178,16 @@
     (min_weights, max_weights) = weight_obs.get_weight_col_minmax()
 
     if not (check_min_max_valid(min_inputs, max_inputs) and check_min_max_valid(min_weights, max_weights)):
+        warnings.warn(
+            "Must run observer before calling calculate_equalization_scale. " +
+            "Returning default equalization scale torch.tensor(1)."
+        )
         return torch.tensor(1)
 
     if not (min_inputs.shape == min_weights.shape):
         raise ValueError(
             "Input and Weight must have the same column dimension. " +
-            f"Found {min_inputs.shape} and {max_inputs.shape} instead."
+            f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
         )
 
     equalization_scale = torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))
@@ -356,6 +362,9 @@
     """
     next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
     if next_inp_eq_obs:
+        if next_inp_eq_obs.equalization_scale.nelement() == 1 and \
+           next_inp_eq_obs.equalization_scale == torch.tensor(1):
+            return None
         return next_inp_eq_obs.equalization_scale
     return None
 
@@ -375,6 +384,8 @@
         return
 
     min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
+    if min_input_scaled is None and max_input_scaled is None:
+        return
     input_quant_obs.min_val = min_input_scaled
     input_quant_obs.max_val = max_input_scaled
 
@@ -393,6 +404,9 @@
         next_equalization_scale: Next node's calculated equalization scale if
            the following node needs to be equalized, 1 otherwise
     """
+    if equalization_scale is None:
+        return
+
     if fused_module_supports_equalization(modules[str(node.target)]):
         op_module = modules[str(node.target)][0]    # type: ignore[index]
     else:
@@ -440,6 +454,8 @@
 ) -> None:
     """ Scales the weight value for functional layers
     """
+    if equalization_scale is None:
+        return
 
     # From the given op_node, the path looks like:
     #   get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
@@ -667,6 +683,9 @@
             weight_eq_obs = weight_eq_obs_dict.get(node.name)
             assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
             equalization_scale = weight_eq_obs.equalization_scale
+
+            if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
+                equalization_scale = None  # type: ignore[assignment]
             maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules)
 
             # Scale the weight nodes