[quant][be] Simplify fake_quant_per_channel (#123186)

Summary: We probably don't need
`torch._C._AutoDispatchBelowAutograd()`, whose purpose is to prevent
infinite recursion when an op's implementation re-dispatches to
itself. Let's remove it and see if anything breaks. The other major
change is registering the op under the more general Autograd dispatch
key (instead of AutogradCPU) so it can be used on CUDA as well.
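
For context, a minimal sketch of the two ideas at play, using a
hypothetical `mylib::my_round` op (the library, op name, and kernels
below are illustrative, not part of this PR): the autograd rule is
registered under the backend-agnostic `Autograd` key, and its forward
only calls other ops, so the re-dispatch guard isn't needed.

```python
import torch
from torch.library import Library, impl

# Hypothetical op for illustration only; the PR itself touches
# quantized_decomposed_lib's "fake_quant_per_channel".
mylib = Library("mylib", "DEF")
mylib.define("my_round(Tensor input) -> Tensor")

@impl(mylib, "my_round", "CompositeExplicitAutograd")
def my_round_backend(input):
    # Backend kernel; runs for CPU, CUDA, etc. below the Autograd layer.
    return torch.round(input)

class MyRound(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Calls a *different* op (torch.round), so there is no route back
        # into mylib::my_round and hence no infinite recursion -- the
        # scenario _AutoDispatchBelowAutograd() guards against.
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient unchanged.
        return grad_output

# "Autograd" covers every backend's autograd key; the old "AutogradCPU"
# registration only applied to CPU tensors.
@impl(mylib, "my_round", "Autograd")
def my_round_autograd(input):
    return MyRound.apply(input)
```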

Test Plan:
python test/inductor/test_cpu_repro.py -k test_decomposed_fake_quant_per_channel

Reviewers: zou3519, bdhirsh

Subscribers: zou3519, bdhirsh, jerryzh168, leslie-fang-intel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123186
Approved by: https://github.com/zou3519, https://github.com/leslie-fang-intel
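
For reference, a standalone sketch of the per-channel fake-quant math
that the forward in the diff below computes (the helper name, the
explicit reshape standing in for `_unsqueeze_multiple`, and the example
shapes/values are all illustrative):

```python
import torch

def fake_quant_per_channel_ref(input, scales, zero_points, axis,
                               quant_min, quant_max):
    # Broadcast the per-channel scales/zero_points along `axis`.
    shape = [1] * input.ndim
    shape[axis] = -1
    s = scales.reshape(shape)
    zp = zero_points.reshape(shape)
    q = torch.round(input * (1.0 / s)) + zp
    # Gradient mask for the straight-through estimator: gradients flow
    # only where the quantized value was not clamped.
    mask = torch.logical_and(q >= quant_min, q <= quant_max)
    out = (torch.clamp(q, quant_min, quant_max) - zp) * s
    return out, mask

# Illustrative usage.
x = torch.randn(2, 3, 4)
scales = torch.tensor([0.1, 0.2, 0.3])
zero_points = torch.zeros(3, dtype=torch.int32)
out, mask = fake_quant_per_channel_ref(x, scales, zero_points,
                                       axis=1, quant_min=-128, quant_max=127)
```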
diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py
index 0a46e2c..18dd61c 100644
--- a/torch/ao/quantization/fx/_decomposed.py
+++ b/torch/ao/quantization/fx/_decomposed.py
@@ -972,21 +972,18 @@
 class FakeQuantPerChannel(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
-        with torch._C._AutoDispatchBelowAutograd():
-            if input.dtype in [torch.float16, torch.bfloat16]:
-                input = input.to(torch.float32)
-            if scales.dtype != torch.float32:
-                scales = scales.to(torch.float32)
-            if zero_points.dtype != torch.int32:
-                zero_points = zero_points.to(torch.int32)
-            assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-            assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
-            broadcast_dims = list(range(0, axis)) + list(range(axis + 1, input.ndim))
-            unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims)
-            unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims)
-            temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points
-            out = (torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points) * unsqueeze_scales
-            mask = torch.logical_and((temp >= quant_min), (temp <= quant_max))
+        if scales.dtype != torch.float32:
+            scales = scales.to(torch.float32)
+        if zero_points.dtype != torch.int32:
+            zero_points = zero_points.to(torch.int32)
+        assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+        broadcast_dims = list(range(0, axis)) + list(range(axis + 1, input.ndim))
+        unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims)
+        unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims)
+        temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points
+        out = (torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points) * unsqueeze_scales
+        mask = torch.logical_and((temp >= quant_min), (temp <= quant_max))
 
         ctx.save_for_backward(mask)
         return out
@@ -996,7 +993,7 @@
         mask, = ctx.saved_tensors
         return gy * mask, None, None, None, None, None
 
-@impl(quantized_decomposed_lib, "fake_quant_per_channel", "AutogradCPU")
+@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Autograd")
 def fake_quant_per_channel(
         input: torch.Tensor,
         scales: torch.Tensor,