[quant] Input-Weight Equalization - allow logical evaluation (#61603)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61603
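
Before this change, an uncalibrated `_InputEqualizationObserver` / `_WeightEqualizationObserver` kept an empty `equalization_scale` (`torch.empty(0)`), which broke `convert_fx` on a prepared model that had never seen data. The scale now defaults to the sentinel `torch.tensor(1)`: `calculate_equalization_scale` warns and returns the sentinel when its observers hold no valid min/max values, `calculate_scaled_minmax` warns and returns `None, None`, and the convert pass skips input/weight scaling whenever the scale is the sentinel or `None`. A scale of 1 leaves both inputs and weights unchanged, so it is a safe default. The change also corrects the shape reported in the column-dimension mismatch error and updates the stale warning text in `calculate_scaled_minmax`; the test drops the calibration call (`prepared(x)`) accordingly.

A minimal usage sketch, not taken from this diff; it assumes the `torch.quantization` FX API at this revision and that `default_equalization_qconfig` is importable from `torch.quantization.fx._equalize`:

```python
import torch
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.quantization.fx._equalize import default_equalization_qconfig


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 5)

    def forward(self, x):
        return self.linear(x)


m = M().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
equalization_qconfig_dict = {"": default_equalization_qconfig}

prepared = prepare_fx(m, qconfig_dict,
                      equalization_qconfig_dict=equalization_qconfig_dict)

# No calibration step (e.g. `prepared(torch.rand(5, 5))`) here: with this
# change the equalization scale stays at the torch.tensor(1) sentinel, a
# warning is emitted, and convert_fx still produces a quantized model.
quantized = convert_fx(prepared)
```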
Test Plan: Imported from OSS
Reviewed By: supriyar
Differential Revision: D29686878
fbshipit-source-id: 67ca4cab98b3d592ff2bb8db86499789b85bd582
diff --git a/test/quantization/fx/test_equalize_fx.py b/test/quantization/fx/test_equalize_fx.py
index 1be3a99..7c17d12 100644
--- a/test/quantization/fx/test_equalize_fx.py
+++ b/test/quantization/fx/test_equalize_fx.py
@@ -728,35 +728,28 @@
ns.call_method('dequantize')
]
- tests = [(SingleLayerLinearModel, 2, linear_node_list),
- (LinearAddModel, 2, linearAdd_node_list),
- (TwoLayerLinearModel, 2, linear2_node_list),
- (SingleLayerFunctionalLinearModel, 2, functionalLinear_node_list),
- (FunctionalLinearAddModel, 2, functionalLinearAdd_node_list),
- (TwoLayerFunctionalLinearModel, 2, functionalLinear2_node_list),
- (LinearReluModel, 2, linearRelu_node_list),
- (LinearReluLinearModel, 2, linearReluLinear_node_list),
- (FunctionalLinearReluModel, 2, functionalLinearRelu_node_list),
- (FunctionalLinearReluLinearModel, 2, functionalLinearReluLinear_node_list),
- (ConvModel, 4, conv_node_list),
- (TwoLayerConvModel, 4, conv2_node_list),
- (SingleLayerFunctionalConvModel, 4, functionalConv_node_list),
- (TwoLayerFunctionalConvModel, 4, functionalConv2_node_list),
- (ConvReluModel, 4, convRelu_node_list),
- (ConvReluConvModel, 4, convReluConv_node_list),
- (FunctionalConvReluModel, 4, functionalConvRelu_node_list),
- (FunctionalConvReluConvModel, 4, functionalConvReluConv_node_list)]
+ tests = [(SingleLayerLinearModel, linear_node_list),
+ (LinearAddModel, linearAdd_node_list),
+ (TwoLayerLinearModel, linear2_node_list),
+ (SingleLayerFunctionalLinearModel, functionalLinear_node_list),
+ (FunctionalLinearAddModel, functionalLinearAdd_node_list),
+ (TwoLayerFunctionalLinearModel, functionalLinear2_node_list),
+ (LinearReluModel, linearRelu_node_list),
+ (LinearReluLinearModel, linearReluLinear_node_list),
+ (FunctionalLinearReluModel, functionalLinearRelu_node_list),
+ (FunctionalLinearReluLinearModel, functionalLinearReluLinear_node_list),
+ (ConvModel, conv_node_list),
+ (TwoLayerConvModel, conv2_node_list),
+ (SingleLayerFunctionalConvModel, functionalConv_node_list),
+ (TwoLayerFunctionalConvModel, functionalConv2_node_list),
+ (ConvReluModel, convRelu_node_list),
+ (ConvReluConvModel, convReluConv_node_list),
+ (FunctionalConvReluModel, functionalConvRelu_node_list),
+ (FunctionalConvReluConvModel, functionalConvReluConv_node_list)]
- for (M, ndim, node_list) in tests:
+ for (M, node_list) in tests:
m = M().eval()
-
- if ndim == 2:
- x = torch.rand((5, 5))
- elif ndim == 4:
- x = torch.rand((16, 3, 224, 224))
-
prepared = prepare_fx(m, specific_qconfig_dict, equalization_qconfig_dict=default_equalization_qconfig_dict)
- prepared(x)
equalized_quantized_model = convert_fx(prepared)
# Check the order of nodes in the graph
diff --git a/torch/quantization/fx/_equalize.py b/torch/quantization/fx/_equalize.py
index b2db5e2..231f8df 100644
--- a/torch/quantization/fx/_equalize.py
+++ b/torch/quantization/fx/_equalize.py
@@ -66,7 +66,7 @@
quant_max=quant_max,
factory_kwargs=factory_kwargs)
- self.equalization_scale = torch.empty(0)
+ self.equalization_scale = torch.tensor(1)
self.equalization_shape: List[int] = []
def forward(self, x_orig):
@@ -85,17 +85,19 @@
def set_equalization_scale(self, equalization_scale):
# Reshape the equalization scale along axis=1 so that it can be
# multiplied with the input along axis=1
+ if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
+ return
self.equalization_scale = torch.reshape(equalization_scale, self.equalization_shape)
def calculate_scaled_minmax(self):
r""" Returns the scaled min/max inputs
"""
- if self.equalization_scale.nelement() == 0:
+ if self.equalization_scale.nelement() == 1 and self.equalization_scale == torch.tensor(1):
warnings.warn(
- "Must call calculate_scale before calling calculate_qparams.\
- Returning default min and max input."
+ "Must call calculate_equalization_scale before calling calculate_scaled_minmax. " +
+ "Will not scale the next quantization observer."
)
- return torch.tensor([0]), torch.tensor([0])
+ return None, None
# Calculate qparams for the scaled min/max inputs
# Scale the input by the equalization scale located at the same column
@@ -145,7 +147,7 @@
quant_max=quant_max,
factory_kwargs=factory_kwargs)
- self.equalization_scale = torch.empty(0)
+ self.equalization_scale = torch.tensor(1)
def forward(self, w_orig):
if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
@@ -176,12 +178,16 @@
(min_weights, max_weights) = weight_obs.get_weight_col_minmax()
if not (check_min_max_valid(min_inputs, max_inputs) and check_min_max_valid(min_weights, max_weights)):
+ warnings.warn(
+ "Must run observer before calling calculate_equalization_scale. " +
+ "Returning default equalization scale torch.tensor(1)."
+ )
return torch.tensor(1)
if not (min_inputs.shape == min_weights.shape):
raise ValueError(
"Input and Weight must have the same column dimension. " +
- f"Found {min_inputs.shape} and {max_inputs.shape} instead."
+ f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
)
equalization_scale = torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))
@@ -356,6 +362,9 @@
"""
next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
if next_inp_eq_obs:
+ if next_inp_eq_obs.equalization_scale.nelement() == 1 and \
+ next_inp_eq_obs.equalization_scale == torch.tensor(1):
+ return None
return next_inp_eq_obs.equalization_scale
return None
@@ -375,6 +384,8 @@
return
min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
+ if min_input_scaled is None and max_input_scaled is None:
+ return
input_quant_obs.min_val = min_input_scaled
input_quant_obs.max_val = max_input_scaled
@@ -393,6 +404,9 @@
next_equalization_scale: Next node's calculated equalization scale if
the following node needs to be equalized, 1 otherwise
"""
+ if equalization_scale is None:
+ return
+
if fused_module_supports_equalization(modules[str(node.target)]):
op_module = modules[str(node.target)][0] # type: ignore[index]
else:
@@ -440,6 +454,8 @@
) -> None:
""" Scales the weight value for functional layers
"""
+ if equalization_scale is None:
+ return
# From the given op_node, the path looks like:
# get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
@@ -667,6 +683,9 @@
weight_eq_obs = weight_eq_obs_dict.get(node.name)
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
equalization_scale = weight_eq_obs.equalization_scale
+
+ if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
+ equalization_scale = None # type: ignore[assignment]
maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules)
# Scale the weight nodes