Added conv constraint that infers layouts (#89031)
The core problem we run into with contiguous/channels-last layouts and convolutions is that Inductor often doesn't do a great job of "preserving" the eager-mode layouts.
So, for example, we'll often have something like
```
a: channels-last
b = foo(a)
c = convolution(b)
```
In eager mode, `b` would stay channels-last, and we would avoid two transpose copies (one into NHWC and one back into NCHW) within the convolution kernel.
However, Inductor currently sometimes loses the "correct" layout of `b` (not in this simple example, but in others). Then, not only will we do a transpose within `foo`, but we'll immediately transpose it back to do the convolution (and then again once the convolution is done).
This is particularly egregious in `convnext_base`, where there's a lot of mixing of non-channels-last and channels-last tensors.
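As a rough illustration of the eager-mode behavior we want to preserve (a standalone sketch, not code from this PR): a channels-last input stays channels-last through a pointwise op, so the convolution never has to transpose its input.

```
import torch

# Standalone eager-mode sketch (illustrative, not part of this PR):
# a channels-last input stays channels-last through a pointwise op,
# so the convolution kernel never has to transpose it.
conv = torch.nn.Conv2d(8, 8, 3, padding=1).to(memory_format=torch.channels_last)

a = torch.randn(2, 8, 32, 32).to(memory_format=torch.channels_last)
b = torch.relu(a)  # the pointwise "foo": preserves channels-last in eager mode
c = conv(b)

assert b.is_contiguous(memory_format=torch.channels_last)
assert c.is_contiguous(memory_format=torch.channels_last)
```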
The solution in this PR is to constrain the inputs to `aten.convolution`/`aten.convolution_backward` to match the layouts from eager mode. This ensures that we'll never do extra transposes *within* `aten.convolution`, which are particularly bad (since Inductor can't fuse them).
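The constraint is expressed in terms of stride order rather than exact strides: the lowering reads the strides recorded on the FX node's `meta["val"]` (a fake tensor carrying the eager layout) and re-strides each IR input to that order via `ir.ExternKernel.require_stride_order`. Here is a rough, self-contained sketch of what a "stride order" is (an illustrative helper, not Inductor's implementation; the real logic lives in `ir.get_stride_order` in the diff below):

```
import torch

# Illustrative helper (not Inductor's implementation): rank each dimension
# by its stride, so the fastest-moving dim gets order 0 and the slowest
# gets order ndim - 1. Two tensors with the same stride order share a
# layout for the purposes of this constraint.
def stride_order(strides):
    sorted_dims = sorted(range(len(strides)), key=lambda d: strides[d])
    order = [0] * len(strides)
    for rank, dim in enumerate(sorted_dims):
        order[dim] = rank
    return order

nchw = torch.empty(2, 8, 32, 32)                   # contiguous (NCHW)
nhwc = nchw.to(memory_format=torch.channels_last)  # channels-last (NHWC)

print(stride_order(nchw.stride()))  # [3, 2, 1, 0]
print(stride_order(nhwc.stride()))  # [3, 0, 2, 1]
```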
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89031
Approved by: https://github.com/ngimel, https://github.com/jansel
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 1265ca3..651ef9e 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -65,7 +65,6 @@
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
aten = torch.ops.aten
-
requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
torch._inductor.config.triton.autotune = False # too slow
@@ -5088,6 +5087,8 @@
return kernels
def test_divisibile_by_16_covers_numel_args(self):
+ torch._dynamo.reset()
+
def fn(a: torch.Tensor) -> torch.Tensor:
return torch.sum(a)
@@ -5107,6 +5108,7 @@
kernels[1].meta["configs"][0].divisible_by_16
)
self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1))
+ torch._dynamo.reset()
if __name__ == "__main__":
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index e0e41fd..5114ffa 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -20,7 +20,12 @@
MissingOperatorWithoutDecomp,
)
from .ir import Constant, FixedLayout, InputBuffer, Pointwise, Reduction, TensorBox
-from .lowering import lowerings, make_fallback, needs_realized_inputs
+from .lowering import (
+ layout_constraints,
+ lowerings,
+ make_fallback,
+ needs_realized_inputs,
+)
from .sizevars import SizeVarAllocator
from .utils import dynamo_utils, gather_origins
from .virtualized import V
@@ -301,7 +306,12 @@
def run_node(self, n: torch.fx.Node):
with ir.IRNode.current_origins({n}):
- result = super().run_node(n)
+ if n.op == "call_function" and n.target in layout_constraints:
+ args, kwargs = self.fetch_args_kwargs_from_env(n)
+ args, kwargs = layout_constraints[n.target](n, *args, **kwargs)
+ result = self.call_function(n.target, args, kwargs)
+ else:
+ result = super().run_node(n)
# Realize if (1) any user need inputs realized, or (2) there is
# already too many reads and rematerializing can be bad.
@@ -310,7 +320,20 @@
for user in n.users:
if user.target in needs_realized_inputs:
result.realize_hint()
- elif user.op == "output":
+ # This inclusion is somewhat controversial (from
+ # discussion between Horace, Natalia, and Elias).
+ # Currently, it's not very clear why this is helpful.
+ # The general idea here is that even though a node may
+ # have FlexibleLayout, we still often *treat* it as if
+ # it was contiguous. This appears to sometimes result in
+ # suboptimal behavior.
+ #
+ # When we do a better job selecting layout, we should
+ # revisit this.
+ result = ir.ExternKernel.require_stride_order(
+ result, ir.get_stride_order(n.meta["val"].stride())
+ )
+ if user.op == "output":
if isinstance(result.data.data, (Pointwise, Reduction)):
result.realize()
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 8327fe0..d547246 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -2478,6 +2478,9 @@
@classmethod
def require_stride_order(cls, x, order):
+ if x.get_numel() == 0: # Layout doesn't matter
+ return x
+
# require x to have the layout as strided_ordered as order
if is_storage_and_layout(x):
if isinstance(
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 75d4e47..5168f37 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -23,7 +23,6 @@
from .decomposition import decompositions, get_decompositions
from .ir import (
ExpandView,
- get_stride_order,
IndexingConstant,
IndexingDiv,
PermuteView,
@@ -38,6 +37,7 @@
log = logging.getLogger(__name__)
lowerings = {}
+layout_constraints = {}
fallbacks = set()
aten = torch.ops.aten
prims = torch.ops.prims
@@ -53,6 +53,14 @@
needs_realized_inputs.add(getattr(fn, overload))
+def add_layout_constraint(fn, constraint):
+ if isinstance(fn, torch._ops.OpOverloadPacket):
+ for overload in fn.overloads():
+ layout_constraints[getattr(fn, overload)] = constraint
+ else:
+ layout_constraints[fn] = constraint
+
+
add_needs_realized_inputs(
[
aten.as_strided,
@@ -1013,12 +1021,10 @@
register_onednn_fusion_ops()
-def fallback_handler(kernel, inps_hook=None):
+def fallback_handler(kernel):
fallbacks.add(kernel)
def handler(*args, **kwargs):
- if inps_hook is not None:
- args, kwargs = inps_hook(*args, **kwargs)
return pytree.tree_map(
TensorBox.create, ir.FallbackKernel.create(kernel, *args, **kwargs)
)
@@ -1026,7 +1032,7 @@
return handler
-def make_fallback(kernel, inps_hook=None):
+def make_fallback(kernel, layout_constraint=None):
assert (
kernel not in decompositions
), f"both a fallback and a decomp for same kernel: {kernel}"
@@ -1036,9 +1042,9 @@
)
add_needs_realized_inputs(kernel)
- return register_lowering(kernel, type_promotion_kind=None)(
- fallback_handler(kernel, inps_hook)
- )
+ if layout_constraint is not None:
+ add_layout_constraint(kernel, layout_constraint)
+ return register_lowering(kernel, type_promotion_kind=None)(fallback_handler(kernel))
@register_lowering(aten.native_dropout, type_promotion_kind=None)
@@ -1189,72 +1195,14 @@
)
-def conv_backward(*args, **kwargs):
- # output striding complex and has a lot of build dependent options,
- # take the output strides to determine what to set the inputs
- with torch._subclasses.FakeTensorMode():
- args_fake, kwargs_fake = pytree.tree_map_only(
- ir.IRNode,
- lambda t: ir.ir_node_to_tensor(t, guard_shape=False),
- (args, kwargs),
- )
- output = aten.convolution_backward(*args_fake, **kwargs_fake)
-
- def constraints(
- grad_output,
- input,
- weight,
- bias_sizes,
- stride,
- padding,
- dilation,
- transposed,
- output_padding,
- groups,
- output_mask,
- ):
- out = (
- output[0]
- if output[0] is not None
- else output[1]
- if output[1] is not None
- else output[2]
- )
- if out is not None:
- stride_order = get_stride_order(out.stride())
- grad_output = ir.ExternKernel.require_stride_order(
- grad_output, stride_order
- )
- weight = ir.ExternKernel.require_stride_order(weight, stride_order)
- # Only make input contiguous when it is necessary for the backwards computation
- if output_mask[1]:
- input = ir.ExternKernel.require_stride_order(input, stride_order)
-
- return (
- grad_output,
- input,
- weight,
- bias_sizes,
- stride,
- padding,
- dilation,
- transposed,
- output_padding,
- groups,
- output_mask,
- ), {}
-
- return constraints(*args, **kwargs)
-
-
-def require_dense(*args, **kwargs):
+def require_dense(_, *args, **kwargs):
args, kwargs = pytree.tree_map_only(
ir.IRNode, lambda t: ir.ExternKernel.require_stride1(t), (args, kwargs)
)
return args, kwargs
-def require_contiguous(*args, **kwargs):
+def require_contiguous(_, *args, **kwargs):
args, kwargs = pytree.tree_map_only(
ir.IRNode, lambda t: ir.ExternKernel.require_contiguous(t), (args, kwargs)
)
@@ -1264,26 +1212,42 @@
if has_torchvision_roi_align():
make_fallback(torch.ops.torchvision.roi_align)
+
+def constrain_to_fx_strides(fx_node, *args, **kwargs):
+ def apply_constraint(arg, fx_arg):
+ if isinstance(arg, ir.IRNode):
+ stride_order = ir.get_stride_order(fx_arg.meta["val"].stride())
+ return ir.ExternKernel.require_stride_order(arg, stride_order)
+ return arg
+
+ args = [apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)]
+ kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
+ return args, kwargs
+
+
# TODO(jansel): we should implement decomps or lowerings for these
# https://github.com/pytorch/torchdynamo/issues/327
make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
-make_fallback(aten.convolution_backward, inps_hook=conv_backward)
+make_fallback(aten.convolution_backward, constrain_to_fx_strides)
make_fallback(aten._cudnn_rnn, require_dense)
-make_fallback(aten._cudnn_rnn_backward, inps_hook=require_contiguous)
-make_fallback(aten.cumsum, inps_hook=require_dense)
-make_fallback(aten._embedding_bag, inps_hook=require_contiguous)
-make_fallback(aten._embedding_bag_forward_only, inps_hook=require_contiguous)
+make_fallback(aten._cudnn_rnn_backward, require_contiguous)
+make_fallback(aten.cumsum, require_dense)
+make_fallback(aten._embedding_bag, require_contiguous)
+make_fallback(aten._embedding_bag_forward_only, require_contiguous)
make_fallback(aten._fused_moving_avg_obs_fq_helper)
make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
-make_fallback(aten.grid_sampler_2d_backward, inps_hook=require_dense)
+make_fallback(aten.grid_sampler_2d_backward, require_dense)
make_fallback(aten.randperm)
make_fallback(aten.sort)
make_fallback(aten.sort.stable)
make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
-make_fallback(aten._thnn_fused_lstm_cell, inps_hook=require_dense)
+make_fallback(aten._thnn_fused_lstm_cell, require_dense)
make_fallback(aten.topk)
-make_fallback(aten.upsample_bicubic2d_backward, inps_hook=require_contiguous)
-make_fallback(aten.upsample_bilinear2d_backward, inps_hook=require_dense)
+make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
+make_fallback(aten.upsample_bilinear2d_backward, require_dense)
+
+
+add_layout_constraint(aten.convolution, constrain_to_fx_strides)
@register_lowering(aten.convolution)
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index c835607..8a51294 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -118,6 +118,11 @@
elif isinstance(val, torch.Tensor):
if not val.is_sparse:
proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
+ # NB: Kinda hacky, but we should try to get val as the metadata
+ # everywhere
+ fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
+ with fake_tensor_mode:
+ proxy.node.meta['val'] = torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype)
return proxy
def thunkify(f, *args, **kwargs):