Add efficient zero tensors (#64837)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64837
Test Plan: Imported from OSS
Reviewed By: gchanan
Differential Revision: D32834987
Pulled By: anjali411
fbshipit-source-id: 20ea08ade0db0044ca633d9c1a117a6a2e65d1fd
diff --git a/BUILD.bazel b/BUILD.bazel
index 703abac..34a00fb 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -136,6 +136,7 @@
"aten/src/ATen/RegisterQuantizedCPU.cpp",
"aten/src/ATen/RegisterSparseCPU.cpp",
"aten/src/ATen/RegisterSparseCsrCPU.cpp",
+ "aten/src/ATen/RegisterZeroTensor.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
"aten/src/ATen/RegisterMeta.cpp",
diff --git a/aten/src/ATen/ConjugateFallback.cpp b/aten/src/ATen/ConjugateFallback.cpp
index 937da7a..8fc6258 100644
--- a/aten/src/ATen/ConjugateFallback.cpp
+++ b/aten/src/ATen/ConjugateFallback.cpp
@@ -51,6 +51,7 @@
TORCH_VIEW_FNS(m)
TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
+ TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m)
}
}
diff --git a/aten/src/ATen/ZeroTensorFallback.cpp b/aten/src/ATen/ZeroTensorFallback.cpp
new file mode 100644
index 0000000..6d7fee0
--- /dev/null
+++ b/aten/src/ATen/ZeroTensorFallback.cpp
@@ -0,0 +1,104 @@
+#include <ATen/ATen.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/native/UnaryOps.h>
+#include <ATen/NativeFunctions.h>
+#include <c10/util/irange.h>
+#include <torch/library.h>
+#include <ATen/native/MathBitFallThroughLists.h>
+
+namespace at {
+
+ // TODO: add a note explaining the design decisions
+ // ZeroTensors are designed to be immutable. Thus, we error out when an in-place operation is performed on ZeroTensors
+ void zeroTensorFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+ const auto& arguments = op.schema().arguments();
+ const auto num_arguments = arguments.size();
+ const auto stack_start = stack->size() - num_arguments;
+
+ c10::optional<bool> is_write;
+ for (const auto i : c10::irange(num_arguments)) {
+ const auto& alias_info = arguments[i].alias_info();
+ if (alias_info != nullptr) {
+ if (is_write.has_value()) {
+ TORCH_CHECK(*is_write == alias_info->isWrite(),
+ "Unsupported operator for ", "ZeroTensorFallback: ", op.schema().name(),
+ "ZeroTensor fallback doesn't work for operators with a mix "
+ "mutable and non-mutable inputs that alias with outputs, "
+ "this must be implemented manually. "
+ "If you got this error on a core op, please report a bug to PyTorch.");
+ } else {
+ is_write = alias_info->isWrite();
+ }
+ }
+ }
+
+ if (is_write.has_value() && !*is_write) {
+ // We assume that view operators automatically handle the ZeroTensor bit
+ // correctly by propagating the dispatch key in key_set.
+ // This is not necessarily always right, so you should test these cases.
+ op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
+ return;
+ }
+
+ for (const auto i : c10::irange(num_arguments)) {
+ auto& ivalue = (*stack)[stack_start + i];
+ if (!(ivalue.isTensor() || ivalue.isTensorList())) {
+ continue;
+ }
+ const auto& argument = arguments[i];
+ bool mut_arg = false;
+
+ if (argument.alias_info()) {
+ // Was already tested by is_write loop above
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
+ mut_arg = true;
+ }
+
+ if (ivalue.isTensor()) {
+ auto tensor = std::move(ivalue).toTensor();
+ if (tensor._is_zerotensor()) {
+ TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
+ "obtained using .clone() if you want a mutable tensor.");
+ tensor = at::zeros({}, tensor.options()).expand(tensor.sizes());
+ }
+ (*stack)[stack_start + i] = std::move(tensor);
+ } else if (ivalue.isTensorList()) {
+ auto tensors = std::move(ivalue).toTensorList();
+ for(const auto j : c10::irange(tensors.size())) {
+ const Tensor& tensor = tensors[j];
+ if (tensor._is_zerotensor()) {
+ // TODO: assert requires_grad=False
+ //_like should not propagate zerotensor dispatch key
+ TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
+ "obtained using .clone() if you want a mutable tensor.");
+ tensors[j] = at::zeros({}, tensor.options()).expand(tensor.sizes());
+ }
+ }
+ (*stack)[stack_start + i] = std::move(tensors);
+ }
+ }
+
+ op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
+ }
+
+
+ TORCH_LIBRARY_IMPL(_, ZeroTensor, m) {
+ m.fallback(torch::CppFunction::makeFromBoxedFunction<&zeroTensorFallback>());
+ }
+
+ TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) {
+ m.impl("zeros_like", torch::CppFunction::makeFallthrough());
+ m.impl("mul.Scalar", torch::CppFunction::makeFallthrough());
+ m.impl("add.Scalar", torch::CppFunction::makeFallthrough());
+ m.impl("copy_", torch::CppFunction::makeFallthrough());
+ m.impl("clone", torch::CppFunction::makeFallthrough());
+ // The functions in the list below have a specific registeration in native_functions.yaml and
+ // do not use the fallback.
+ // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough());
+ // m.impl("add.Tensor", torch::CppFunction::makeFallthrough());
+
+ TORCH_VIEW_FNS(m)
+ TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
+ }
+} // namespace at
diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h
index d55aa5d..15117af 100644
--- a/aten/src/ATen/core/TensorBase.h
+++ b/aten/src/ATen/core/TensorBase.h
@@ -304,6 +304,14 @@
return impl_->storage().is_alias_of(other.storage());
}
+ inline bool _is_zerotensor() const {
+ return impl_->_is_zerotensor();
+ }
+
+ inline void _set_zero(bool zero) const {
+ impl_->_set_zero(zero);
+ }
+
inline bool is_conj() const {
return impl_->is_conj();
}
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 8a7ac46..2d86f06 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -7,7 +7,8 @@
#include <ATen/MemoryOverlap.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
-
+#include <ATen/ExpandUtils.h>
+#include <ATen/RedispatchFunctions.h>
#include <torch/library.h>
namespace at {
@@ -625,6 +626,45 @@
return at::mul_out(self, wrapped_scalar_tensor(other), self); // redispatch!
}
+Device correct_out_device(const Tensor& self, const Tensor& other) {
+ if (self.device() == at::kCPU){
+ return other.device();
+ } else {
+ return self.device();
+ }
+}
+
+Tensor mul_zerotensor(const Tensor& self, const Tensor& other) {
+ auto out_device = correct_out_device(self, other);
+ // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
+ auto device_ = Device(DeviceType::Meta);
+ auto meta_out = at::redispatch::mul(c10::DispatchKeySet(at::DispatchKey::Meta), self.to(device_), other.to(device_));
+ return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
+}
+
+Tensor add_zerotensor(const Tensor& self, const Tensor& other, const Scalar& alpha) {
+ auto out_device = correct_out_device(self, other);
+ // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
+ auto device_ = Device(DeviceType::Meta);
+ auto meta_out = at::redispatch::add(c10::DispatchKeySet(at::DispatchKey::Meta), self.to(device_), other.to(device_));
+
+ auto get_out_like = [&] (const Tensor& tensor)
+ {
+ auto sizes = meta_out.sizes();
+ return at::_to_copy(tensor.expand(sizes), meta_out.options().device(out_device));
+ };
+
+ if (self._is_zerotensor()) {
+ if (other._is_zerotensor()) {
+ return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
+ }
+ auto res = get_out_like(other);
+ return alpha.equal(1) ? res : res.mul(alpha);
+ } else {
+ return get_out_like(self);
+ }
+}
+
// multiply, alias for mul
Tensor& multiply_out(const Tensor& self, const Tensor& other, Tensor& result) {
return at::mul_out(result, self, other);
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index c28ca2b..cb8e0bb 100644
--- a/aten/src/ATen/native/Copy.cpp
+++ b/aten/src/ATen/native/Copy.cpp
@@ -244,6 +244,12 @@
auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
{
NoNamesGuard guard;
+ if (self._is_zerotensor()) {
+ TORCH_CHECK(false, "ZeroTensors are immutable. Please materialize the tensor using `.clone()`, if you want a mutable zero tensor.");
+ }
+ if (src._is_zerotensor()) {
+ return self.zero_();
+ }
copy_impl(self, src, non_blocking);
}
namedinference::propagate_names_if_nonempty(self, maybe_outnames);
diff --git a/aten/src/ATen/native/MathBitFallThroughLists.h b/aten/src/ATen/native/MathBitFallThroughLists.h
index ea3fed7..59cf6a3 100644
--- a/aten/src/ATen/native/MathBitFallThroughLists.h
+++ b/aten/src/ATen/native/MathBitFallThroughLists.h
@@ -3,7 +3,6 @@
namespace at {
// views and their in-place version ops
#define TORCH_VIEW_FNS(m) \
- m.impl("as_strided", torch::CppFunction::makeFallthrough()); \
m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \
m.impl("detach", torch::CppFunction::makeFallthrough()); \
m.impl("detach_", torch::CppFunction::makeFallthrough()); \
@@ -31,7 +30,6 @@
m.impl("unfold", torch::CppFunction::makeFallthrough()); \
m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \
m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \
- m.impl("view", torch::CppFunction::makeFallthrough()); \
m.impl("view_as", torch::CppFunction::makeFallthrough()); \
m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \
m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \
@@ -67,3 +65,7 @@
m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \
m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
}
+
+#define TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) \
+ m.impl("as_strided", torch::CppFunction::makeFallthrough()); \
+ m.impl("view", torch::CppFunction::makeFallthrough()); \
diff --git a/aten/src/ATen/native/NegateFallback.cpp b/aten/src/ATen/native/NegateFallback.cpp
index ff23b70..3f88b02 100644
--- a/aten/src/ATen/native/NegateFallback.cpp
+++ b/aten/src/ATen/native/NegateFallback.cpp
@@ -35,6 +35,7 @@
TORCH_VIEW_FNS(m)
TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
+ TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m)
}
}
diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp
index 1937a8b..67e0904 100644
--- a/aten/src/ATen/native/Resize.cpp
+++ b/aten/src/ATen/native/Resize.cpp
@@ -27,8 +27,7 @@
static auto kFunctorchWrappedTensors = DispatchKeySet({
DispatchKey::FuncTorchGradWrapper,
- DispatchKey::FuncTorchBatched,
- DispatchKey::FuncTorchPython});
+ DispatchKey::FuncTorchBatched});
static bool is_functorch_wrapped_tensor(const Tensor& tensor) {
auto key_set = tensor.unsafeGetTensorImpl()->key_set();
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 0dc86e4..2afb822 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -349,9 +349,10 @@
namedinference::propagate_names(result, self.names());
}
- // never propagate Conjugate and Negative dispatch key
+ // never propagate Conjugate, Negative, and ZeroTensor dispatch key
result._set_conj(false);
result._set_neg(false);
+ result._set_zero(false);
return result;
}
@@ -435,6 +436,27 @@
namespace {
+// The ZeroTensor allocator ignores whatever allocation is requested and always
+// gives you nullptr
+struct ZeroTensorAllocator final : public at::Allocator {
+ ZeroTensorAllocator(at::Device device) : device_(device) {};
+ ~ZeroTensorAllocator() override = default;
+ static void deleter(void* const pointer) {
+ TORCH_INTERNAL_ASSERT(!pointer);
+ }
+ DataPtr allocate(const size_t nbytes) const override {
+ return {nullptr, nullptr, &deleter, device_};
+ }
+ DeleterFnPtr raw_deleter() const override {
+ return deleter;
+ }
+ at::Device device_;
+};
+
+at::Allocator* GetZeroTensorAllocator(ZeroTensorAllocator& zt) {
+ return &zt;
+}
+
// Performs dtype inference for full
TensorOptions infer_full_options(
const Scalar& fill_value,
@@ -1057,6 +1079,18 @@
return result.zero_();
}
+Tensor _efficientzerotensor(IntArrayRef size,
+ c10::optional<ScalarType> dtype,
+ c10::optional<Layout> layout,
+ c10::optional<Device> device,
+ c10::optional<bool> pin_memory) {
+ auto device_ = device_or_default(device);
+ auto allocator = ZeroTensorAllocator(device_);
+ auto dtype_ = dtype_or_default(dtype);
+ auto r = at::detail::empty_generic(size, GetZeroTensorAllocator(allocator), at::DispatchKey::ZeroTensor, dtype_, device_, c10::nullopt);
+ return r;
+}
+
Tensor& zeros_out(IntArrayRef size, Tensor& result) {
if (result.is_sparse()) {
result.sparse_resize_and_clear_(size, size.size(), 0.);
@@ -1427,7 +1461,11 @@
self = at::empty_like(src, src.options(), memory_format);
}
- self.copy_(src);
+ if (src._is_zerotensor()) {
+ self.zero_();
+ } else {
+ self.copy_(src);
+ }
return self;
}
diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp
index a84f48f..a49e2a5 100644
--- a/aten/src/ATen/native/TypeProperties.cpp
+++ b/aten/src/ATen/native/TypeProperties.cpp
@@ -30,6 +30,10 @@
return self.is_signed();
}
+bool _is_zerotensor(const Tensor& self) {
+ return self._is_zerotensor();
+}
+
bool is_conj(const Tensor& self) {
return self.is_conj();
}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index fc6246b..258a766 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -439,6 +439,7 @@
SparseCPU, SparseCUDA: add_sparse
SparseCsrCPU, SparseCsrCUDA: add_sparse_csr
MkldnnCPU: mkldnn_add
+ ZeroTensor: add_zerotensor
- func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -711,7 +712,7 @@
- func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
variants: function, method
dispatch:
- CPU, CUDA, Meta: as_strided_tensorimpl
+ ZeroTensor, CPU, CUDA, Meta: as_strided_tensorimpl
QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl
device_check: NoCheck
device_guard: False
@@ -2463,6 +2464,11 @@
device_guard: False
manual_cpp_binding: True
+- func: _is_zerotensor(Tensor self) -> bool
+ variants: function, method
+ device_guard: False
+ manual_cpp_binding: True
+
- func: is_neg(Tensor self) -> bool
variants: function, method
device_guard: False
@@ -3199,6 +3205,7 @@
dispatch:
SparseCPU, SparseCUDA: mul_sparse
MkldnnCPU: mkldnn_mul
+ ZeroTensor: mul_zerotensor
- func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@@ -3667,7 +3674,7 @@
device_check: NoCheck
device_guard: False
dispatch:
- CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: _reshape_alias
+ CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor: _reshape_alias
# We don't need to support mkldnn since this is handled explicitly by the reshape operator.
- func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
@@ -4827,6 +4834,10 @@
device_check: NoCheck
device_guard: False
+- func: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+ dispatch:
+ CompositeExplicitAutograd: _efficientzerotensor
+
- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
@@ -5886,7 +5897,7 @@
device_check: NoCheck
device_guard: False
dispatch:
- CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: view
+ ZeroTensor, CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: view
MkldnnCPU: mkldnn_view
# Warning: If you want to change the name or overload name of this
diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h
index a7563a1..1a1c1d7 100644
--- a/aten/src/ATen/templates/Functions.h
+++ b/aten/src/ATen/templates/Functions.h
@@ -221,6 +221,10 @@
return tensor.is_inference();
}
+inline bool _is_zerotensor(const Tensor& tensor) {
+ return tensor._is_zerotensor();
+}
+
inline bool is_conj(const Tensor& tensor) {
return tensor.is_conj();
}
diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp
index fc40019..96b714a 100644
--- a/c10/core/DispatchKey.cpp
+++ b/c10/core/DispatchKey.cpp
@@ -111,6 +111,9 @@
return "AutogradPrivateUse3";
case DispatchKey::AutogradOther:
return "AutogradOther";
+
+ case DispatchKey::ZeroTensor:
+ return "ZeroTensor";
case DispatchKey::BackendSelect:
return "BackendSelect";
case DispatchKey::Named:
@@ -149,8 +152,6 @@
// https://github.com/zou3519/functorch
// We plan on eventually upstreaming the prototype into core, at which
// point it will have a different design that should use fewer keys.
- case DispatchKey::FuncTorchPython:
- return "FuncTorchPython";
case DispatchKey::FuncTorchDynamicLayerBackMode:
return "FuncTorchDynamicLayerBackMode";
case DispatchKey::FuncTorchDynamicLayerFrontMode:
@@ -242,10 +243,10 @@
{"PrivateUse3", c10::DispatchKey::PrivateUse3},
{"BackendSelect", c10::DispatchKey::BackendSelect},
{"Python", c10::DispatchKey::Python},
- {"FuncTorchPython", c10::DispatchKey::FuncTorchPython},
{"Named", c10::DispatchKey::Named},
{"Conjugate", c10::DispatchKey::Conjugate},
{"Negative", c10::DispatchKey::Negative},
+ {"ZeroTensor", c10::DispatchKey::ZeroTensor},
{"FuncTorchDynamicLayerBackMode",
c10::DispatchKey::FuncTorchDynamicLayerBackMode},
{"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
index 2cf41af..1bb8268 100644
--- a/c10/core/DispatchKey.h
+++ b/c10/core/DispatchKey.h
@@ -139,7 +139,6 @@
BackendSelect,
Python,
- FuncTorchPython, // See Note [Out-of-tree vmap+grad prototype]
// The named dispatch key is set for any tensors with named dimensions.
// Although we have a dispatch key for named tensors, for historical reasons,
@@ -165,6 +164,8 @@
// This is implemented at a dispatch level right before any backends run
Negative,
+ ZeroTensor, // registered at build/aten/src/ATen/RegisterZeroTensor.cpp
+
// See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
// is to insert code after the "autograd subsystem" runs, so this key should
// be directly after ADInplaceOrView and all of the autograd keys.
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 857e28f..bee4da9 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -1057,6 +1057,24 @@
}
/**
+ * Whether or not the tensor is a zerotensor
+ */
+ inline bool _is_zerotensor() const {
+ return key_set_.has(DispatchKey::ZeroTensor);
+ }
+
+ /**
+ Set whether or not the tensor is a zero tensor
+ */
+ void _set_zero(bool value) {
+ if (value) {
+ TORCH_INTERNAL_ASSERT(false, "Please call `torch._efficientzerotensor` if you want to create a tensor with no storage.");
+ } else {
+ key_set_ = key_set_.remove(DispatchKey::ZeroTensor);
+ }
+ }
+
+ /**
* Whether or not the tensor should be negated
*/
inline bool is_neg() const {
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 3daf6d7..16521d7 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2433,3 +2433,6 @@
- name: _test_warn_in_autograd(Tensor self) -> Tensor
self: warn_backwards(grad)
+
+- name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+ output_differentiability: [False]
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 0952a33..803097f 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -319,7 +319,8 @@
FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate("""\
auto ${inp}_t_raw = toNonOptFwGrad(${inp});
-auto ${inp}_t = ${inp}_t_raw.defined() ? ${inp}_t_raw : at::zeros_like(toNonOptTensor(${inp}));
+auto ${inp}_tensor = toNonOptTensor(${inp});
+auto ${inp}_t = ${inp}_t_raw.defined() ? ${inp}_t_raw : at::_efficientzerotensor(${inp}_tensor.sizes(), ${inp}_tensor.options());
""")
FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate("""\
diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py
index 072b83c..77a1ef6 100644
--- a/tools/codegen/gen.py
+++ b/tools/codegen/gen.py
@@ -1307,6 +1307,7 @@
# Meta is a magic key: it is automatically generated for structured
# kernels
DispatchKey.Meta,
+ DispatchKey.ZeroTensor,
]
# Only a limited set of dispatch keys get CPUFunctions.h headers generated
# for them; this is the set
diff --git a/tools/codegen/model.py b/tools/codegen/model.py
index 1643089..edcaf95 100644
--- a/tools/codegen/model.py
+++ b/tools/codegen/model.py
@@ -79,6 +79,7 @@
PrivateUse3 = auto()
EndOfBackendKeys = PrivateUse3
+ ZeroTensor = auto()
Meta = auto()
BackendSelect = auto()
Named = auto()
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
index 7996f58..bc9ad2d 100644
--- a/torch/_tensor_str.py
+++ b/torch/_tensor_str.py
@@ -234,6 +234,9 @@
summarize = self.numel() > PRINT_OPTS.threshold
+ if self._is_zerotensor():
+ self = self.clone()
+
# handle the negative bit
if self.is_neg():
self = self.resolve_neg()
diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index 5c09b06..3d1c7ec 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -871,8 +871,59 @@
raise GradcheckError('grad is incorrect size')
return True
+def _test_undefined_forward_mode(func, outputs, inputs):
+ fwAD = torch.autograd.forward_ad
-def _test_undefined_grad(func, outputs, inputs) -> bool:
+ inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
+ all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=True)
+
+ tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)
+
+ with fwAD.dual_level():
+ fw_grads = []
+ dual_inputs = []
+ for i, inp in enumerate(inputs):
+ if is_tensor_like(inp) and inp.requires_grad:
+ if inp.layout == torch._mkldnn: # type: ignore[attr-defined]
+ raise ValueError("MKLDNN inputs are not support for forward AD gradcheck.")
+
+ inp = fwAD.make_dual(inp, torch.zeros_like(inp))
+ # If inp is a differentiable view, the dual might not be the tangent given to
+ # make_dual, so read it explicitly from the dual tensor
+ fw_grads.append(fwAD.unpack_dual(inp)[1])
+ dual_inputs.append(inp)
+
+ for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
+ fw_grad.copy_(u.view_as(fw_grad))
+
+ for idx, inp in enumerate(tensor_inputs):
+ dual_inp_obj = dual_inputs[idx]
+
+ # case 1 (Materialized Zero Tensor Tangent)
+ dual_inputs[idx] = fwAD.make_dual(inp, torch.zeros_like(inp))
+ raw_outputs = _as_tuple(func(*dual_inputs))
+ dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)
+
+ # case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor)
+ dual_inputs[idx] = inp
+ raw_outputs = _as_tuple(func(*dual_inputs))
+ dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)
+
+ # reset
+ dual_inputs[idx] = dual_inp_obj
+
+ for index_o, (d_o1, d_o2) in enumerate(zip(dual_outputs1, dual_outputs2)):
+ val1, res1 = fwAD.unpack_dual(d_o1)
+ val2, res2 = fwAD.unpack_dual(d_o2)
+
+ if not (res1 is None or res2 is None):
+ if not torch.equal(res1, res2):
+ raise GradcheckError("Mismatch in tangent values for output with index: ", index_o,
+ " when input: ", inp, " has an undefined tangent value. ",
+ " Got: ", res1, " but expected: ", res2)
+ return True
+
+def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
if not diff_input_list:
raise GradcheckError("no Tensors requiring grad found in input")
@@ -984,7 +1035,8 @@
def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps, rtol,
- atol, check_grad_dtypes, check_forward_ad, check_backward_ad, nondet_tol):
+ atol, check_grad_dtypes, check_forward_ad, check_backward_ad, nondet_tol,
+ check_undefined_grad):
complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()]
has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out))
if check_backward_ad:
@@ -1023,10 +1075,14 @@
gradcheck_fn(real_fn, real_func_out, real_inputs, diff_real_func_out, eps,
rtol, atol, check_grad_dtypes, nondet_tol, complex_indices=complex_inp_indices,
use_forward_ad=True)
+ if check_undefined_grad:
+ _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs)
+ _test_undefined_forward_mode(real_fn, real_func_out, real_inputs)
else:
gradcheck_fn(func, func_out, tupled_inputs, outputs, eps,
rtol, atol, check_grad_dtypes, nondet_tol, use_forward_ad=True)
-
+ if check_undefined_grad:
+ _test_undefined_forward_mode(func, outputs, tupled_inputs)
def _slow_gradcheck(func, func_out, tupled_inputs, outputs, eps, rtol, atol, check_grad_dtypes,
nondet_tol, *, use_forward_ad=False, complex_indices=None, test_imag=False):
@@ -1210,6 +1266,13 @@
atol, check_grad_dtypes, nondet_tol, *, use_forward_ad=False, complex_indices=None, test_imag=False):
# See https://github.com/pytorch/pytorch/issues/53876 for details
inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
+ # Backward mode computes v^T * J (VJP)
+ # Since we computed J * u (JVP) through finite difference method, we perform an equality check
+ # between VJP * u, v * JVP
+ # ----
+ # Forward mode computes J * u (JVP)
+ # Since we already compute JVP through finite difference method,
+ # we don't need v for correctness check here as asserted below
all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=use_forward_ad)
numerical_vJu = _get_numerical_vJu(func, inputs, inp_tensors_idx, func_out, all_u, all_v, eps, is_forward_ad=use_forward_ad)
@@ -1348,7 +1411,8 @@
gradcheck_fn = _fast_gradcheck if fast_mode else _slow_gradcheck
_gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps,
rtol, atol, check_grad_dtypes, check_forward_ad=check_forward_ad,
- check_backward_ad=check_backward_ad, nondet_tol=nondet_tol)
+ check_backward_ad=check_backward_ad, nondet_tol=nondet_tol,
+ check_undefined_grad=check_undefined_grad)
if check_batched_forward_grad:
_test_batched_grad_forward_ad(func, tupled_inputs)
@@ -1364,7 +1428,7 @@
_test_backward_mul_by_grad_output(outputs, tupled_inputs, check_sparse_nnz)
if check_undefined_grad:
- _test_undefined_grad(func, outputs, tupled_inputs)
+ _test_undefined_backward_mode(func, outputs, tupled_inputs)
return True
diff --git a/torch/overrides.py b/torch/overrides.py
index fcdcaff..31b8c3c 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -243,6 +243,7 @@
Tensor._conj,
Tensor._conj_physical,
Tensor._neg_view,
+ Tensor._is_zerotensor,
}
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 3f8ba49..0d325ea 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9218,7 +9218,13 @@
supports_forward_ad=True,
promotes_int_to_float=True,
assert_autodiffed=True,
- rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+ rhs_make_tensor_kwargs=dict(exclude_zero=True),
+ skips=(
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
+ device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
+ device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+ ),),
BinaryUfuncInfo('div',
aliases=('divide',),
variant_test_name='trunc_rounding',
@@ -9227,7 +9233,13 @@
supports_forward_ad=True,
promotes_int_to_float=True,
assert_autodiffed=True,
- rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+ rhs_make_tensor_kwargs=dict(exclude_zero=True),
+ skips=(
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
+ device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
+ device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+ ),),
BinaryUfuncInfo('div',
aliases=('divide',),
variant_test_name='floor_rounding',
@@ -9236,7 +9248,13 @@
supports_forward_ad=True,
promotes_int_to_float=True,
assert_autodiffed=True,
- rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+ rhs_make_tensor_kwargs=dict(exclude_zero=True),
+ skips=(
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
+ device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
+ device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+ ),),
BinaryUfuncInfo('true_divide',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_forward_ad=True,