Expose Channel Last 3d enum

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32947

Test Plan: Imported from OSS

Differential Revision: D19707716

Pulled By: glaringlee

fbshipit-source-id: 03824769376043bc6151a4580aba27654de5077f
diff --git a/c10/core/MemoryFormat.h b/c10/core/MemoryFormat.h
index 58834f3..8c6bca6 100644
--- a/c10/core/MemoryFormat.h
+++ b/c10/core/MemoryFormat.h
@@ -24,7 +24,7 @@
 
 
 namespace c10 {
-enum class MemoryFormat : int8_t { Contiguous, Preserve, ChannelsLast };
+enum class MemoryFormat : int8_t { Contiguous, Preserve, ChannelsLast, ChannelsLast3d };
 
 // If you are seeing this, it means that this call site was not checked if
 // the memory format could be preserved, and it was switched to old default
@@ -45,6 +45,8 @@
       return stream << "Contiguous";
     case MemoryFormat::ChannelsLast:
       return stream << "ChannelsLast";
+    case MemoryFormat::ChannelsLast3d:
+      return stream << "ChannelsLast3d";
     default:
       AT_ERROR("Unknown memory format");
   }
@@ -62,12 +64,12 @@
 
 // Note [Ambiguous is_channels_last_strides]
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// The flaw of carrying memory_format implicitly through strides is very hard 
+// The flaw of carrying memory_format implicitly through strides is very hard
 // to WAR properly. issue #24090
 // Without the history of permutation, we can't infer the memory_format of a
 // tensor from the snapshot of its size & stride
 // e.g.
-// 
+//
 // 1. We can NOT specify the memory_format of N111 tensor through strides in a
 //  meaningful way;
 //
@@ -79,7 +81,7 @@
 //
 // Due to the limitations, our temporary WAR `is_channels_last_strides` does the
 // best effort to infer whether the original memory_format of a tensor is
-// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered 
+// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered
 // by their importance):
 //   1. Ensure that normal shape manipulation does not accidentally change the
 //      MemoryFormat of an existing tensor.
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 9c804ba..94ebaa3 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -1356,6 +1356,14 @@
         set_sizes_and_strides(sizes(), get_channels_last_strides(sizes()));
         break;
       }
+      case MemoryFormat::ChannelsLast3d: {
+        TORCH_CHECK(
+            dim() == 5,
+            "required rank 5 tensor to use channels_last_3d format");
+        TORCH_CHECK(false, "unsupported memory format ", memory_format);
+        //TODO Implement set_sizes_and_strides for channels last 3d
+        break;
+      }
       case MemoryFormat::Preserve:
         TORCH_CHECK(false, "unsupported memory format ", memory_format);
         // Cleaning warning messages, no need to break as TORCH_CHECK(false)
diff --git a/torch/csrc/utils/tensor_memoryformats.cpp b/torch/csrc/utils/tensor_memoryformats.cpp
index 117e8e4..d8f1d23 100644
--- a/torch/csrc/utils/tensor_memoryformats.cpp
+++ b/torch/csrc/utils/tensor_memoryformats.cpp
@@ -31,6 +31,7 @@
   _ADD_MEMORY_FORMAT(at::MemoryFormat::Preserve, "preserve_format");
   _ADD_MEMORY_FORMAT(at::MemoryFormat::Contiguous, "contiguous_format");
   _ADD_MEMORY_FORMAT(at::MemoryFormat::ChannelsLast, "channels_last");
+  _ADD_MEMORY_FORMAT(at::MemoryFormat::ChannelsLast3d, "_channels_last_3d");
 
 }