Expose Channel Last 3d enum
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32947
Test Plan: Imported from OSS
Differential Revision: D19707716
Pulled By: glaringlee
fbshipit-source-id: 03824769376043bc6151a4580aba27654de5077f
diff --git a/c10/core/MemoryFormat.h b/c10/core/MemoryFormat.h
index 58834f3..8c6bca6 100644
--- a/c10/core/MemoryFormat.h
+++ b/c10/core/MemoryFormat.h
@@ -24,7 +24,7 @@
namespace c10 {
-enum class MemoryFormat : int8_t { Contiguous, Preserve, ChannelsLast };
+enum class MemoryFormat : int8_t { Contiguous, Preserve, ChannelsLast, ChannelsLast3d };
// If you are seeing this, it means that this call site was not checked if
// the memory format could be preserved, and it was switched to old default
@@ -45,6 +45,8 @@
return stream << "Contiguous";
case MemoryFormat::ChannelsLast:
return stream << "ChannelsLast";
+ case MemoryFormat::ChannelsLast3d:
+ return stream << "ChannelsLast3d";
default:
AT_ERROR("Unknown memory format");
}
@@ -62,12 +64,12 @@
// Note [Ambiguous is_channels_last_strides]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// The flaw of carrying memory_format implicitly through strides is very hard
+// The flaw of carrying memory_format implicitly through strides is very hard
// to WAR properly. issue #24090
// Without the history of permutation, we can't infer the memory_format of a
// tensor from the snapshot of its size & stride
// e.g.
-//
+//
// 1. We can NOT specify the memory_format of N111 tensor through strides in a
// meaningful way;
//
@@ -79,7 +81,7 @@
//
// Due to the limitations, our temporary WAR `is_channels_last_strides` does the
// best effort to infer whether the original memory_format of a tensor is
-// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered
+// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered
// by their importance):
// 1. Ensure that normal shape manipulation does not accidentally change the
// MemoryFormat of an existing tensor.
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 9c804ba..94ebaa3 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -1356,6 +1356,14 @@
set_sizes_and_strides(sizes(), get_channels_last_strides(sizes()));
break;
}
+ case MemoryFormat::ChannelsLast3d: {
+ TORCH_CHECK(
+ dim() == 5,
+ "required rank 5 tensor to use channels_last_3d format");
+ TORCH_CHECK(false, "unsupported memory format ", memory_format);
+ //TODO Implement set_sizes_and_strides for channels last 3d
+ break;
+ }
case MemoryFormat::Preserve:
TORCH_CHECK(false, "unsupported memory format ", memory_format);
// Cleaning warning messages, no need to break as TORCH_CHECK(false)
diff --git a/torch/csrc/utils/tensor_memoryformats.cpp b/torch/csrc/utils/tensor_memoryformats.cpp
index 117e8e4..d8f1d23 100644
--- a/torch/csrc/utils/tensor_memoryformats.cpp
+++ b/torch/csrc/utils/tensor_memoryformats.cpp
@@ -31,6 +31,7 @@
_ADD_MEMORY_FORMAT(at::MemoryFormat::Preserve, "preserve_format");
_ADD_MEMORY_FORMAT(at::MemoryFormat::Contiguous, "contiguous_format");
_ADD_MEMORY_FORMAT(at::MemoryFormat::ChannelsLast, "channels_last");
+ _ADD_MEMORY_FORMAT(at::MemoryFormat::ChannelsLast3d, "_channels_last_3d");
}