functionalization: fix x.is_contiguous(channels_last) (#94195)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94195
Approved by: https://github.com/ezyang
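
Before this change, FunctionalTensorWrapper::is_contiguous_custom dropped the memory_format argument and always answered the default torch.contiguous_format query, so a tensor that was already in channels_last layout looked non-contiguous under functionalization and x.contiguous(memory_format=torch.channels_last) introduced an unnecessary clone. A minimal sketch of the intended behavior after the fix (assuming torch.func.functionalize is available in this build; the test added below exercises the same path through the internal assert_functionalization/get_logs helpers):

    import torch
    from torch.func import functionalize

    def f(x):
        # With the fix, the wrapper forwards memory_format to the inner tensor,
        # so this reports True for an already channels-last input and the
        # .contiguous(...) call below is a no-op (no clone in the graph).
        assert x.is_contiguous(memory_format=torch.channels_last)
        return x.contiguous(memory_format=torch.channels_last)

    x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)  # NCHW view, channels_last layout
    out = functionalize(f)(x)

Forwarding memory_format to the wrapped tensor keeps the wrapper's layout metadata in agreement with the tensor it shadows, which is what lets the no-op case above avoid a copy.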
diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp
index b7a939c..0b71d43 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.cpp
+++ b/aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -343,7 +343,7 @@
   return value_.unsafeGetTensorImpl()->numel();
 }
 bool FunctionalTensorWrapper::is_contiguous_custom(at::MemoryFormat memory_format) const {
-  return value_.unsafeGetTensorImpl()->is_contiguous();
+  return value_.unsafeGetTensorImpl()->is_contiguous(memory_format);
 }
 c10::SymIntArrayRef FunctionalTensorWrapper::sym_sizes_custom() const {
   return value_.unsafeGetTensorImpl()->sym_sizes();
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index cc9e9de..4c9865f 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -583,6 +583,17 @@
     return diagonal_scatter
     """)
 
+    def test_channels_last_contiguous(self):
+        def f(x):
+            return x.contiguous(memory_format=torch.channels_last)
+        x = torch.randn(4, 8, 8, 3).permute(0, 3, 1, 2)
+        self.assert_functionalization(f, x)
+        logs = self.get_logs(f, x).strip()
+        # There should be no clone in the graph
+        self.assertExpectedInline(logs, """\
+def forward(self, arg0_1):
+    return arg0_1""")
+
     def test_split(self):
         def f(x):
             # test: view ops that return multiple tensors (split)