[Reland][Inductor] Disallow OpOverloadPacket in ir.FallbackKernel (#110567) (#111396)
This is a reland of #110567 with additional fbcode fixes.
Summary:
In ABI-compatible mode, we always need op_overload.schema for FallbackKernel, so FallbackKernel (and the lowerings that feed it) must receive a resolved OpOverload rather than an OpOverloadPacket.
Approved by: https://github.com/jansel
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/37a02659921490d85b2b0712ad52b924e0c431cd
Differential Revision: D50339346
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111396
Approved by: https://github.com/chenyang78
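For context: an OpOverloadPacket (e.g. torch.ops.aten.add) is a bundle of overloads with no schema of its own, while an OpOverload (e.g. torch.ops.aten.add.Tensor) carries the schema that ABI-compatible codegen needs. A minimal sketch, using a stock aten op that is not part of this patch:

    import torch

    packet = torch.ops.aten.add           # OpOverloadPacket: no single schema
    overload = torch.ops.aten.add.Tensor  # OpOverload: resolved, has a schema

    assert isinstance(packet, torch._ops.OpOverloadPacket)
    assert isinstance(overload, torch._ops.OpOverload)

    # Only the resolved overload exposes the schema FallbackKernel relies on:
    # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    print(overload._schema)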
diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py
index d52f7ef..37f659e 100644
--- a/torch/_inductor/fx_passes/group_batch_fusion.py
+++ b/torch/_inductor/fx_passes/group_batch_fusion.py
@@ -129,7 +129,7 @@
         with graph.inserting_before(subset[0]):
             fused_mm = graph.call_function(
-                torch.ops.fbgemm.gmm,
+                torch.ops.fbgemm.gmm.default,
                 args=(group_inputs, group_weights, group_biases),
             )
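The same applies to any FX call_function target. A rough sketch with a stock aten op (torch.ops.fbgemm is only registered in fbcode builds, so aten.mm stands in for fbgemm.gmm here):

    import torch

    graph = torch.fx.Graph()
    x = graph.placeholder("x")
    # Pass the resolved OpOverload, not the OpOverloadPacket, so downstream
    # passes never have to guess which overload was meant.
    node = graph.call_function(torch.ops.aten.mm.default, args=(x, x))
    graph.output(node)
    assert isinstance(node.target, torch._ops.OpOverload)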
diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py
index cc120c0..a46b8db 100644
--- a/torch/_inductor/fx_passes/quantization.py
+++ b/torch/_inductor/fx_passes/quantization.py
@@ -533,7 +533,7 @@
     )
     _register_quantized_maxpool2d_lowering(
         generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
-        quantized.max_pool2d,
+        quantized.max_pool2d.default,
     )
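The quantized namespace has the same packet/overload structure; a quick sketch (illustrative only):

    import torch

    packet = torch.ops.quantized.max_pool2d
    assert isinstance(packet, torch._ops.OpOverloadPacket)
    # The pattern above now registers against the resolved overload:
    assert isinstance(packet.default, torch._ops.OpOverload)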
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index dff6f32..a8eadca 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -605,6 +605,9 @@
             return target(*args, **kwargs)

         if target not in lowerings:
+            assert isinstance(
+                target, torch._ops.OpOverload
+            ), f"{target} is not an OpOverload"
             base_name = target.name().split(".")[0]
             if base_name in FALLBACK_ALLOW_LIST:
                 make_fallback(target)
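The new assert encodes the invariant that anything reaching the fallback path is already a resolved overload, which the allow-list check below it relies on via OpOverload.name(). Roughly (the op below is an arbitrary example):

    import torch

    target = torch.ops.aten.add.Tensor
    # OpOverload.name() is the qualified name plus the overload suffix,
    # e.g. "aten::add.Tensor"; splitting on "." recovers the base name
    # that is matched against FALLBACK_ALLOW_LIST.
    base_name = target.name().split(".")[0]
    print(base_name)  # "aten::add"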
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 5aed1bb..0a21626 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3841,12 +3841,10 @@
         self.op_overload = kernel
-        # TODO: Need to revisit schema matching to find the correct OpOverload from OpOverloadPacket
         assert isinstance(
             kernel,
             (
                 torch._ops.OpOverload,
-                torch._ops.OpOverloadPacket,
                 torch._ops.HigherOrderOperator,
             ),
         ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
@@ -3858,19 +3856,17 @@
             else kernel.__name__
         )
         if V.graph.cpp_wrapper:
-            if isinstance(kernel, torch._ops.OpOverload):
-                # Calling with the default kernel name can lead to ambiguous behavior like the following example.
-                # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
-                # repeat_interleave(const at::Tensor & self, int64_t repeats,
-                #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
-                self.kernel = (
-                    f"at::_ops::{kernel.__name__.replace('.default', '')}::call"
-                    if kernel._overloadname == "default"
-                    else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
-                )
-                schema = kernel._schema
-            else:
-                self.kernel = f"at::{op_base_name}"
+            assert isinstance(kernel, torch._ops.OpOverload)
+            # Calling with the default kernel name can lead to ambiguous behavior like the following example.
+            # repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
+            # repeat_interleave(const at::Tensor & self, int64_t repeats,
+            #     c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
+            self.kernel = (
+                f"at::{op_base_name}"
+                if kernel._overloadname == "default"
+                else f"at::_ops::{kernel.__name__.replace('.', '_')}::call"
+            )
+            schema = kernel._schema
         else:
             self.kernel = f"aten.{op_base_name}"
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 3690c44..2fb4bf3 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -959,7 +959,7 @@
             input.realize()
         if all(len(input.layout.size) == 4 for input in inputs):
             inputs, _ = require_channels_last(aten.cat, *inputs)
-        return fallback_handler(aten.cat)(inputs, dim)
+        return fallback_handler(aten.cat.default)(inputs, dim)

     if len(inputs) == 1:
         return clone(inputs[0])
@@ -1550,11 +1550,9 @@
     return check_skip_condition(node, is_output=True)


-def make_fallback(kernel, layout_constraint=None, warn=True):
-    assert (
-        kernel not in decompositions
-    ), f"both a fallback and a decomp for same kernel: {kernel}"
-    if get_decompositions([kernel]) and warn and bool(os.getenv("CI")):
+def make_fallback(op, layout_constraint=None, warn=True):
+    assert op not in decompositions, f"both a fallback and a decomp for same op: {op}"
+    if get_decompositions([op]) and warn and bool(os.getenv("CI")):
         # Note: 'warn' is holdover from when this was a warning, but for ops that previously
         # set warn=False we do not want a CI error.
         # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not
@@ -1566,16 +1564,28 @@
" and suppress_errors is being disabled to surface it."
)
raise AssertionError(
- f"make_fallback({kernel}): a decomposition exists, we should switch to it."
+ f"make_fallback({op}): a decomposition exists, we should switch to it."
" To fix this error, either add a decomposition to core_aten_decompositions (preferred)"
" or inductor_decompositions, and delete the corresponding `make_fallback` line."
" Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.",
)
- add_needs_realized_inputs(kernel)
- if layout_constraint is not None:
- add_layout_constraint(kernel, layout_constraint)
- return register_lowering(kernel, type_promotion_kind=None)(fallback_handler(kernel))
+ def register_fallback(op_overload):
+ add_needs_realized_inputs(op_overload)
+ if layout_constraint is not None:
+ add_layout_constraint(op_overload, layout_constraint)
+ return register_lowering(op_overload, type_promotion_kind=None)(
+ fallback_handler(op_overload)
+ )
+
+ if isinstance(op, torch._ops.OpOverloadPacket):
+ for ol in op.overloads():
+ op_overload = getattr(op, ol)
+ register_fallback(op_overload)
+ elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
+ register_fallback(op)
+ else:
+ raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}")
def philox_rand_offset(shape):
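make_fallback now accepts a packet but expands it eagerly into its resolved overloads. A rough illustration of that expansion, using aten.pow as an arbitrary example:

    import torch

    packet = torch.ops.aten.pow
    # overloads() yields overload names such as "Tensor_Tensor", "Scalar",
    # "Tensor_Scalar"; getattr resolves each one to an OpOverload.
    for name in packet.overloads():
        overload = getattr(packet, name)
        assert isinstance(overload, torch._ops.OpOverload)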
@@ -1633,7 +1643,8 @@
 def native_dropout(x, p, train):
     if config.fallback_random:
         return pytree.tree_map(
-            TensorBox.create, ir.FallbackKernel.create(aten.native_dropout, x, p, train)
+            TensorBox.create,
+            ir.FallbackKernel.create(aten.native_dropout.default, x, p, train),
         )
     else:
         raise AssertionError("should be handled in replace_random.py")
@@ -1673,22 +1684,28 @@
     _warn_triton_random(V.graph.creation_time)


-fallback_rand = fallback_handler(aten.rand)
-fallback_randn = fallback_handler(aten.randn)
+fallback_rand_default = fallback_handler(aten.rand.default)
+fallback_rand_generator = fallback_handler(aten.rand.generator)
+fallback_randn_default = fallback_handler(aten.randn.default)
+fallback_randn_generator = fallback_handler(aten.randn.generator)
 make_fallback(aten.randint)


 @register_lowering(aten.rand)
 def rand(*args, **kwargs):
-    if config.fallback_random or kwargs.get("generator", None) is not None:
-        return fallback_rand(*args, **kwargs)
+    if kwargs.get("generator", None) is not None:
+        return fallback_rand_generator(*args, **kwargs)
+    elif config.fallback_random:
+        return fallback_rand_default(*args, **kwargs)
     raise AssertionError("should have been handled in replace_random.py")


 @register_lowering(aten.randn)
 def randn(*args, **kwargs):
-    if config.fallback_random or kwargs.get("generator", None) is not None:
-        return fallback_randn(*args, **kwargs)
+    if kwargs.get("generator", None) is not None:
+        return fallback_randn_generator(*args, **kwargs)
+    elif config.fallback_random:
+        return fallback_randn_default(*args, **kwargs)
     raise AssertionError("should have been handled in replace_random.py")
@@ -1790,7 +1807,7 @@
     assert len(boundaries.get_size()) == 1

     if not (is_triton(input) and is_triton(boundaries)):
-        return fallback_handler(aten.bucketize, add_to_fallback_set=False)(
+        return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)(
             input, boundaries, out_int32=out_int32, right=right
         )
@@ -2674,7 +2691,7 @@
     except NotImplementedError:
         # Fallback to ATen for boolean indexing
         x.realize()
-        return fallback_handler(aten.index)(x, indices)
+        return fallback_handler(aten.index.Tensor)(x, indices)


 @register_lowering(aten._unsafe_index, type_promotion_kind=None)
@@ -3464,7 +3481,9 @@
     return x_out, ceil_mode


-fallback_max_pool2d_with_indices = fallback_handler(aten.max_pool2d_with_indices)
+fallback_max_pool2d_with_indices = fallback_handler(
+    aten.max_pool2d_with_indices.default
+)


 @register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
@@ -3549,7 +3568,7 @@
 fallback_max_pool2d_with_indices_backward = fallback_handler(
-    aten.max_pool2d_with_indices_backward
+    aten.max_pool2d_with_indices_backward.default
 )
@@ -3770,7 +3789,7 @@
     return fn_sum


-fallback_adaptive_avg_pool2d = fallback_handler(aten._adaptive_avg_pool2d)
+fallback_adaptive_avg_pool2d = fallback_handler(aten._adaptive_avg_pool2d.default)


 @register_lowering(aten._adaptive_avg_pool2d)
@@ -3890,7 +3909,7 @@
     return rv


-fallback_avg_pool2d = fallback_handler(aten.avg_pool2d)
+fallback_avg_pool2d = fallback_handler(aten.avg_pool2d.default)


 @register_lowering(aten.avg_pool2d, type_promotion_kind=None)
@@ -3987,7 +4006,7 @@
     return rv


-fallback_avg_pool2d_backward = fallback_handler(aten.avg_pool2d_backward)
+fallback_avg_pool2d_backward = fallback_handler(aten.avg_pool2d_backward.default)


 @register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None)
@@ -4390,7 +4409,9 @@
     return ops.pow(a, b)


-fallback_pow = fallback_handler(aten.pow)
+fallback_pow_tensor_tensor = fallback_handler(aten.pow.Tensor_Tensor)
+fallback_pow_scalar = fallback_handler(aten.pow.Scalar)
+fallback_pow_tensor_scalar = fallback_handler(aten.pow.Tensor_Scalar)


 @register_lowering(aten.pow, broadcast=True)
@@ -4431,7 +4452,12 @@
     if is_integer_pow:
         # ops.pow doesn't work for integers
-        return fallback_pow(a, b)
+        if isinstance(a, Number):
+            return fallback_pow_scalar(a, b)
+        elif isinstance(b, Number):
+            return fallback_pow_tensor_scalar(a, b)
+        else:
+            return fallback_pow_tensor_tensor(a, b)

     return pow_native(a, b)
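With per-overload pow fallbacks, the lowering has to pick the overload that matches the argument kinds itself. Schematically (tensor values are arbitrary):

    import torch

    a = torch.tensor([2, 3], dtype=torch.int64)

    # Which aten::pow overload applies depends on which side is a Python number:
    print(torch.ops.aten.pow.Tensor_Scalar(a, 3))  # tensor ** scalar
    print(torch.ops.aten.pow.Scalar(2, a))         # scalar ** tensor
    print(torch.ops.aten.pow.Tensor_Tensor(a, a))  # tensor ** tensor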