[Reland][Inductor] Disallow OpOverloadPacket in ir.FallbackKernel (#110567) (#111396)

This is a reland of #110567 with additional fbcode fixes.

Summary:
In ABI-compatible mode, we always need op_overload.schema for FallbackKernel, and a schema is only available on a resolved OpOverload, not on an OpOverloadPacket.
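
For illustration only (not part of the change): a minimal sketch of the distinction this relies on, using `aten.cat` as an arbitrary example op.

```python
import torch

packet = torch.ops.aten.cat            # OpOverloadPacket: a bundle of overloads, no single schema
overload = torch.ops.aten.cat.default  # OpOverload: carries a concrete schema

assert isinstance(packet, torch._ops.OpOverloadPacket)
assert isinstance(overload, torch._ops.OpOverload)

# ABI-compatible codegen reads argument types from the schema,
# which only an OpOverload provides.
print(overload._schema)

# A packet can still be expanded into its overloads, which is what
# make_fallback() now does when handed an OpOverloadPacket:
for name in packet.overloads():
    print(getattr(packet, name)._schema)
```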

Approved by: https://github.com/jansel

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/37a02659921490d85b2b0712ad52b924e0c431cd

Differential Revision: D50339346

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111396
Approved by: https://github.com/chenyang78
diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py
index d52f7ef..37f659e 100644
--- a/torch/_inductor/fx_passes/group_batch_fusion.py
+++ b/torch/_inductor/fx_passes/group_batch_fusion.py
@@ -129,7 +129,7 @@
 
         with graph.inserting_before(subset[0]):
             fused_mm = graph.call_function(
-                torch.ops.fbgemm.gmm,
+                torch.ops.fbgemm.gmm.default,
                 args=(group_inputs, group_weights, group_biases),
             )
 
diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py
index cc120c0..a46b8db 100644
--- a/torch/_inductor/fx_passes/quantization.py
+++ b/torch/_inductor/fx_passes/quantization.py
@@ -533,7 +533,7 @@
         )
         _register_quantized_maxpool2d_lowering(
             generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
-            quantized.max_pool2d,
+            quantized.max_pool2d.default,
         )
 
 
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index dff6f32..a8eadca 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -605,6 +605,9 @@
             return target(*args, **kwargs)
 
         if target not in lowerings:
+            assert isinstance(
+                target, torch._ops.OpOverload
+            ), f"{target} is not an OpOverload"
             base_name = target.name().split(".")[0]
             if base_name in FALLBACK_ALLOW_LIST:
                 make_fallback(target)
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 5aed1bb..0a21626 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3841,12 +3841,10 @@
 
         self.op_overload = kernel
 
-        # TODO: Need to revisit schema matching to find the correct OpOverload from OpOverloadPacket
         assert isinstance(
             kernel,
             (
                 torch._ops.OpOverload,
-                torch._ops.OpOverloadPacket,
                 torch._ops.HigherOrderOperator,
             ),
         ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
@@ -3858,19 +3856,17 @@
                 else kernel.__name__
             )
             if V.graph.cpp_wrapper:
-                if isinstance(kernel, torch._ops.OpOverload):
-                    # Calling with the default kernel name can lead to ambiguous behavior like the following example.
-                    # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
-                    # repeat_interleave(const at::Tensor & self, int64_t repeats,
-                    #       c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
-                    self.kernel = (
-                        f"at::_ops::{kernel.__name__.replace('.default', '')}::call"
-                        if kernel._overloadname == "default"
-                        else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
-                    )
-                    schema = kernel._schema
-                else:
-                    self.kernel = f"at::{op_base_name}"
+                assert isinstance(kernel, torch._ops.OpOverload)
+                # Calling with the default kernel name can lead to ambiguous behavior like the following example.
+                # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
+                # repeat_interleave(const at::Tensor & self, int64_t repeats,
+                #       c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
+                self.kernel = (
+                    f"at::{op_base_name}"
+                    if kernel._overloadname == "default"
+                    else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
+                )
+                schema = kernel._schema
             else:
                 self.kernel = f"aten.{op_base_name}"
 
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 3690c44..2fb4bf3 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -959,7 +959,7 @@
             input.realize()
         if all(len(input.layout.size) == 4 for input in inputs):
             inputs, _ = require_channels_last(aten.cat, *inputs)
-        return fallback_handler(aten.cat)(inputs, dim)
+        return fallback_handler(aten.cat.default)(inputs, dim)
 
     if len(inputs) == 1:
         return clone(inputs[0])
@@ -1550,11 +1550,9 @@
     return check_skip_condition(node, is_output=True)
 
 
-def make_fallback(kernel, layout_constraint=None, warn=True):
-    assert (
-        kernel not in decompositions
-    ), f"both a fallback and a decomp for same kernel: {kernel}"
-    if get_decompositions([kernel]) and warn and bool(os.getenv("CI")):
+def make_fallback(op, layout_constraint=None, warn=True):
+    assert op not in decompositions, f"both a fallback and a decomp for same op: {op}"
+    if get_decompositions([op]) and warn and bool(os.getenv("CI")):
         # Note: 'warn' is holdover from when this was a warning, but for ops that previously
         # set warn=False we do not want a CI error.
         # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not
@@ -1566,16 +1564,28 @@
                 " and suppress_errors is being disabled to surface it."
             )
         raise AssertionError(
-            f"make_fallback({kernel}): a decomposition exists, we should switch to it."
+            f"make_fallback({op}): a decomposition exists, we should switch to it."
             " To fix this error, either add a decomposition to core_aten_decompositions (preferred)"
             " or inductor_decompositions, and delete the corresponding `make_fallback` line."
             " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.",
         )
 
-    add_needs_realized_inputs(kernel)
-    if layout_constraint is not None:
-        add_layout_constraint(kernel, layout_constraint)
-    return register_lowering(kernel, type_promotion_kind=None)(fallback_handler(kernel))
+    def register_fallback(op_overload):
+        add_needs_realized_inputs(op_overload)
+        if layout_constraint is not None:
+            add_layout_constraint(op_overload, layout_constraint)
+        return register_lowering(op_overload, type_promotion_kind=None)(
+            fallback_handler(op_overload)
+        )
+
+    if isinstance(op, torch._ops.OpOverloadPacket):
+        for ol in op.overloads():
+            op_overload = getattr(op, ol)
+            register_fallback(op_overload)
+    elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
+        register_fallback(op)
+    else:
+        raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}")
 
 
 def philox_rand_offset(shape):
@@ -1633,7 +1643,8 @@
 def native_dropout(x, p, train):
     if config.fallback_random:
         return pytree.tree_map(
-            TensorBox.create, ir.FallbackKernel.create(aten.native_dropout, x, p, train)
+            TensorBox.create,
+            ir.FallbackKernel.create(aten.native_dropout.default, x, p, train),
         )
     else:
         raise AssertionError("should be handled in replace_random.py")
@@ -1673,22 +1684,28 @@
     _warn_triton_random(V.graph.creation_time)
 
 
-fallback_rand = fallback_handler(aten.rand)
-fallback_randn = fallback_handler(aten.randn)
+fallback_rand_default = fallback_handler(aten.rand.default)
+fallback_rand_generator = fallback_handler(aten.rand.generator)
+fallback_randn_default = fallback_handler(aten.randn.default)
+fallback_randn_generator = fallback_handler(aten.randn.generator)
 make_fallback(aten.randint)
 
 
 @register_lowering(aten.rand)
 def rand(*args, **kwargs):
-    if config.fallback_random or kwargs.get("generator", None) is not None:
-        return fallback_rand(*args, **kwargs)
+    if kwargs.get("generator", None) is not None:
+        return fallback_rand_generator(*args, **kwargs)
+    elif config.fallback_random:
+        return fallback_rand_default(*args, **kwargs)
     raise AssertionError("should have been handled in replace_random.py")
 
 
 @register_lowering(aten.randn)
 def randn(*args, **kwargs):
-    if config.fallback_random or kwargs.get("generator", None) is not None:
-        return fallback_randn(*args, **kwargs)
+    if kwargs.get("generator", None) is not None:
+        return fallback_randn_generator(*args, **kwargs)
+    elif config.fallback_random:
+        return fallback_randn_default(*args, **kwargs)
     raise AssertionError("should have been handled in replace_random.py")
 
 
@@ -1790,7 +1807,7 @@
     assert len(boundaries.get_size()) == 1
 
     if not (is_triton(input) and is_triton(boundaries)):
-        return fallback_handler(aten.bucketize, add_to_fallback_set=False)(
+        return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)(
             input, boundaries, out_int32=out_int32, right=right
         )
 
@@ -2674,7 +2691,7 @@
     except NotImplementedError:
         # Fallback to ATen for boolean indexing
         x.realize()
-        return fallback_handler(aten.index)(x, indices)
+        return fallback_handler(aten.index.Tensor)(x, indices)
 
 
 @register_lowering(aten._unsafe_index, type_promotion_kind=None)
@@ -3464,7 +3481,9 @@
     return x_out, ceil_mode
 
 
-fallback_max_pool2d_with_indices = fallback_handler(aten.max_pool2d_with_indices)
+fallback_max_pool2d_with_indices = fallback_handler(
+    aten.max_pool2d_with_indices.default
+)
 
 
 @register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
@@ -3549,7 +3568,7 @@
 
 
 fallback_max_pool2d_with_indices_backward = fallback_handler(
-    aten.max_pool2d_with_indices_backward
+    aten.max_pool2d_with_indices_backward.default
 )
 
 
@@ -3770,7 +3789,7 @@
     return fn_sum
 
 
-fallback_adaptive_avg_pool2d = fallback_handler(aten._adaptive_avg_pool2d)
+fallback_adaptive_avg_pool2d = fallback_handler(aten._adaptive_avg_pool2d.default)
 
 
 @register_lowering(aten._adaptive_avg_pool2d)
@@ -3890,7 +3909,7 @@
     return rv
 
 
-fallback_avg_pool2d = fallback_handler(aten.avg_pool2d)
+fallback_avg_pool2d = fallback_handler(aten.avg_pool2d.default)
 
 
 @register_lowering(aten.avg_pool2d, type_promotion_kind=None)
@@ -3987,7 +4006,7 @@
     return rv
 
 
-fallback_avg_pool2d_backward = fallback_handler(aten.avg_pool2d_backward)
+fallback_avg_pool2d_backward = fallback_handler(aten.avg_pool2d_backward.default)
 
 
 @register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None)
@@ -4390,7 +4409,9 @@
     return ops.pow(a, b)
 
 
-fallback_pow = fallback_handler(aten.pow)
+fallback_pow_tensor_tensor = fallback_handler(aten.pow.Tensor_Tensor)
+fallback_pow_scalar = fallback_handler(aten.pow.Scalar)
+fallback_pow_tensor_scalar = fallback_handler(aten.pow.Tensor_Scalar)
 
 
 @register_lowering(aten.pow, broadcast=True)
@@ -4431,7 +4452,12 @@
 
     if is_integer_pow:
         # ops.pow doesn't work for integers
-        return fallback_pow(a, b)
+        if isinstance(a, Number):
+            return fallback_pow_scalar(a, b)
+        elif isinstance(b, Number):
+            return fallback_pow_tensor_scalar(a, b)
+        else:
+            return fallback_pow_tensor_tensor(a, b)
 
     return pow_native(a, b)