Added conv constraint that infers layouts (#89031)

The core problem with contiguous/channels-last layouts and convolutions is that Inductor often doesn't do a great job of "preserving" the eager-mode layouts.

For example, we'll often have something like:
```
a: channels-last
b = foo(a)
c = convolution(b)
```

In eager mode, `a` would stay channels-last (and `b` would typically inherit that layout), so we would avoid two transpose copies (one into NHWC and one back into NCHW) within the convolution kernel.
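
As a quick standalone illustration (not part of the PR), the eager-mode propagation can be checked directly; `torch.relu` stands in for `foo` here:

```
import torch

# Standalone illustration (not from the PR): pointwise ops preserve
# channels-last in eager mode, so the convolution consumes NHWC directly.
a = torch.randn(8, 64, 56, 56).to(memory_format=torch.channels_last)
conv = torch.nn.Conv2d(64, 64, 3, padding=1).to(memory_format=torch.channels_last)

b = torch.relu(a)  # stand-in for `foo`; output stays channels-last
c = conv(b)        # consumed as NHWC, no NCHW <-> NHWC transposes needed

print(b.is_contiguous(memory_format=torch.channels_last))  # True
print(c.is_contiguous(memory_format=torch.channels_last))  # True on typical builds
```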

However, Inductor currently sometimes loses the "correct" layout of `b` (not in this simple example, but in others). Then, not only do we do a transpose within `foo`, but we immediately transpose the result back for the convolution (and then again once the convolution is done).

This is particularly egregious in `convnext_base`, where channels-last and non-channels-last tensors are mixed heavily.

The solution in this PR is to constrain the inputs to `aten.convolution`/`aten.convolution_backward` to match their layouts from eager mode. This ensures that we never do extra transposes *within* `aten.convolution`, which are particularly bad (since Inductor can't fuse them).
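
Concretely (see `constrain_to_fx_strides` and `add_layout_constraint` in the diff below), the constraint reads the strides recorded on the FX node's fake-tensor metadata and forces each IR input into the matching stride order before lowering. A simplified sketch of the idea (illustrative name, kwargs omitted):

```
from torch._inductor import ir

# Rough sketch (simplified, not the exact PR code): for each IR input, look
# up the strides that eager mode produced (recorded on the FX node's
# fake-tensor `meta["val"]`) and re-stride the Inductor buffer to match.
def constrain_to_eager_strides(fx_node, *args):
    constrained = []
    for arg, fx_arg in zip(args, fx_node.args):
        if isinstance(arg, ir.IRNode):
            order = ir.get_stride_order(fx_arg.meta["val"].stride())
            arg = ir.ExternKernel.require_stride_order(arg, order)
        constrained.append(arg)
    return constrained
```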

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89031
Approved by: https://github.com/ngimel, https://github.com/jansel
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 1265ca3..651ef9e 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -65,7 +65,6 @@
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
 
 aten = torch.ops.aten
-
 requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda")
 
 torch._inductor.config.triton.autotune = False  # too slow
@@ -5088,6 +5087,8 @@
             return kernels
 
         def test_divisibile_by_16_covers_numel_args(self):
+            torch._dynamo.reset()
+
             def fn(a: torch.Tensor) -> torch.Tensor:
                 return torch.sum(a)
 
@@ -5107,6 +5108,7 @@
                 kernels[1].meta["configs"][0].divisible_by_16
             )
             self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1))
+            torch._dynamo.reset()
 
 
 if __name__ == "__main__":
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index e0e41fd..5114ffa 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -20,7 +20,12 @@
     MissingOperatorWithoutDecomp,
 )
 from .ir import Constant, FixedLayout, InputBuffer, Pointwise, Reduction, TensorBox
-from .lowering import lowerings, make_fallback, needs_realized_inputs
+from .lowering import (
+    layout_constraints,
+    lowerings,
+    make_fallback,
+    needs_realized_inputs,
+)
 from .sizevars import SizeVarAllocator
 from .utils import dynamo_utils, gather_origins
 from .virtualized import V
@@ -301,7 +306,12 @@
 
     def run_node(self, n: torch.fx.Node):
         with ir.IRNode.current_origins({n}):
-            result = super().run_node(n)
+            if n.op == "call_function" and n.target in layout_constraints:
+                args, kwargs = self.fetch_args_kwargs_from_env(n)
+                args, kwargs = layout_constraints[n.target](n, *args, **kwargs)
+                result = self.call_function(n.target, args, kwargs)
+            else:
+                result = super().run_node(n)
 
             # Realize if (1) any user need inputs realized, or (2) there is
             # already too many reads and rematerializing can be bad.
@@ -310,7 +320,20 @@
                 for user in n.users:
                     if user.target in needs_realized_inputs:
                         result.realize_hint()
-                    elif user.op == "output":
+                        # This inclusion is somewhat controversial (from
+                        # discussion between Horace, Natalia, and Elias).
+                        # Currently, it's not very clear why this is helpful.
+                        # The general idea here is that even though a node may
+                        # have FlexibleLayout, we still often *treat* it as if
+                        # it was contiguous. This appears to sometimes result in
+                        # suboptimal behavior.
+                        #
+                        # When we do a better job selecting layout, we should
+                        # revisit this.
+                        result = ir.ExternKernel.require_stride_order(
+                            result, ir.get_stride_order(n.meta["val"].stride())
+                        )
+                    if user.op == "output":
                         if isinstance(result.data.data, (Pointwise, Reduction)):
                             result.realize()
 
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 8327fe0..d547246 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -2478,6 +2478,9 @@
 
     @classmethod
     def require_stride_order(cls, x, order):
+        if x.get_numel() == 0:  # Layout doesn't matter
+            return x
+
         # require x to have the layout as strided_ordered as order
         if is_storage_and_layout(x):
             if isinstance(
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 75d4e47..5168f37 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -23,7 +23,6 @@
 from .decomposition import decompositions, get_decompositions
 from .ir import (
     ExpandView,
-    get_stride_order,
     IndexingConstant,
     IndexingDiv,
     PermuteView,
@@ -38,6 +37,7 @@
 
 log = logging.getLogger(__name__)
 lowerings = {}
+layout_constraints = {}
 fallbacks = set()
 aten = torch.ops.aten
 prims = torch.ops.prims
@@ -53,6 +53,14 @@
             needs_realized_inputs.add(getattr(fn, overload))
 
 
+def add_layout_constraint(fn, constraint):
+    if isinstance(fn, torch._ops.OpOverloadPacket):
+        for overload in fn.overloads():
+            layout_constraints[getattr(fn, overload)] = constraint
+    else:
+        layout_constraints[fn] = constraint
+
+
 add_needs_realized_inputs(
     [
         aten.as_strided,
@@ -1013,12 +1021,10 @@
 register_onednn_fusion_ops()
 
 
-def fallback_handler(kernel, inps_hook=None):
+def fallback_handler(kernel):
     fallbacks.add(kernel)
 
     def handler(*args, **kwargs):
-        if inps_hook is not None:
-            args, kwargs = inps_hook(*args, **kwargs)
         return pytree.tree_map(
             TensorBox.create, ir.FallbackKernel.create(kernel, *args, **kwargs)
         )
@@ -1026,7 +1032,7 @@
     return handler
 
 
-def make_fallback(kernel, inps_hook=None):
+def make_fallback(kernel, layout_constraint=None):
     assert (
         kernel not in decompositions
     ), f"both a fallback and a decomp for same kernel: {kernel}"
@@ -1036,9 +1042,9 @@
         )
 
     add_needs_realized_inputs(kernel)
-    return register_lowering(kernel, type_promotion_kind=None)(
-        fallback_handler(kernel, inps_hook)
-    )
+    if layout_constraint is not None:
+        add_layout_constraint(kernel, layout_constraint)
+    return register_lowering(kernel, type_promotion_kind=None)(fallback_handler(kernel))
 
 
 @register_lowering(aten.native_dropout, type_promotion_kind=None)
@@ -1189,72 +1195,14 @@
     )
 
 
-def conv_backward(*args, **kwargs):
-    # output striding complex and has a lot of build dependent options,
-    # take the output strides to determine what to set the inputs
-    with torch._subclasses.FakeTensorMode():
-        args_fake, kwargs_fake = pytree.tree_map_only(
-            ir.IRNode,
-            lambda t: ir.ir_node_to_tensor(t, guard_shape=False),
-            (args, kwargs),
-        )
-        output = aten.convolution_backward(*args_fake, **kwargs_fake)
-
-    def constraints(
-        grad_output,
-        input,
-        weight,
-        bias_sizes,
-        stride,
-        padding,
-        dilation,
-        transposed,
-        output_padding,
-        groups,
-        output_mask,
-    ):
-        out = (
-            output[0]
-            if output[0] is not None
-            else output[1]
-            if output[1] is not None
-            else output[2]
-        )
-        if out is not None:
-            stride_order = get_stride_order(out.stride())
-            grad_output = ir.ExternKernel.require_stride_order(
-                grad_output, stride_order
-            )
-            weight = ir.ExternKernel.require_stride_order(weight, stride_order)
-            # Only make input contiguous when it is necessary for the backwards computation
-            if output_mask[1]:
-                input = ir.ExternKernel.require_stride_order(input, stride_order)
-
-        return (
-            grad_output,
-            input,
-            weight,
-            bias_sizes,
-            stride,
-            padding,
-            dilation,
-            transposed,
-            output_padding,
-            groups,
-            output_mask,
-        ), {}
-
-    return constraints(*args, **kwargs)
-
-
-def require_dense(*args, **kwargs):
+def require_dense(_, *args, **kwargs):
     args, kwargs = pytree.tree_map_only(
         ir.IRNode, lambda t: ir.ExternKernel.require_stride1(t), (args, kwargs)
     )
     return args, kwargs
 
 
-def require_contiguous(*args, **kwargs):
+def require_contiguous(_, *args, **kwargs):
     args, kwargs = pytree.tree_map_only(
         ir.IRNode, lambda t: ir.ExternKernel.require_contiguous(t), (args, kwargs)
     )
@@ -1264,26 +1212,42 @@
 if has_torchvision_roi_align():
     make_fallback(torch.ops.torchvision.roi_align)
 
+
+def constrain_to_fx_strides(fx_node, *args, **kwargs):
+    def apply_constraint(arg, fx_arg):
+        if isinstance(arg, ir.IRNode):
+            stride_order = ir.get_stride_order(fx_arg.meta["val"].stride())
+            return ir.ExternKernel.require_stride_order(arg, stride_order)
+        return arg
+
+    args = [apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)]
+    kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
+    return args, kwargs
+
+
 # TODO(jansel): we should implement decomps or lowerings for these
 # https://github.com/pytorch/torchdynamo/issues/327
 make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
-make_fallback(aten.convolution_backward, inps_hook=conv_backward)
+make_fallback(aten.convolution_backward, constrain_to_fx_strides)
 make_fallback(aten._cudnn_rnn, require_dense)
-make_fallback(aten._cudnn_rnn_backward, inps_hook=require_contiguous)
-make_fallback(aten.cumsum, inps_hook=require_dense)
-make_fallback(aten._embedding_bag, inps_hook=require_contiguous)
-make_fallback(aten._embedding_bag_forward_only, inps_hook=require_contiguous)
+make_fallback(aten._cudnn_rnn_backward, require_contiguous)
+make_fallback(aten.cumsum, require_dense)
+make_fallback(aten._embedding_bag, require_contiguous)
+make_fallback(aten._embedding_bag_forward_only, require_contiguous)
 make_fallback(aten._fused_moving_avg_obs_fq_helper)
 make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
-make_fallback(aten.grid_sampler_2d_backward, inps_hook=require_dense)
+make_fallback(aten.grid_sampler_2d_backward, require_dense)
 make_fallback(aten.randperm)
 make_fallback(aten.sort)
 make_fallback(aten.sort.stable)
 make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
-make_fallback(aten._thnn_fused_lstm_cell, inps_hook=require_dense)
+make_fallback(aten._thnn_fused_lstm_cell, require_dense)
 make_fallback(aten.topk)
-make_fallback(aten.upsample_bicubic2d_backward, inps_hook=require_contiguous)
-make_fallback(aten.upsample_bilinear2d_backward, inps_hook=require_dense)
+make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
+make_fallback(aten.upsample_bilinear2d_backward, require_dense)
+
+
+add_layout_constraint(aten.convolution, constrain_to_fx_strides)
 
 
 @register_lowering(aten.convolution)
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index c835607..8a51294 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -118,6 +118,11 @@
     elif isinstance(val, torch.Tensor):
         if not val.is_sparse:
             proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val)
+            # NB: Kinda hacky, but we should try to get val as the metadata
+            # everywhere
+            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
+            with fake_tensor_mode:
+                proxy.node.meta['val'] = torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype)
     return proxy
 
 def thunkify(f, *args, **kwargs):