Add efficient zero tensors (#64837)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64837

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D32834987

Pulled By: anjali411

fbshipit-source-id: 20ea08ade0db0044ca633d9c1a117a6a2e65d1fd
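
For context, a minimal usage sketch of what this change enables (the names used here, `torch._efficientzerotensor` and `Tensor._is_zerotensor`, are the private APIs added in this diff, so the behavior shown is the intended one rather than an established public contract):

    import torch

    x = torch.randn(3, 4)
    z = torch._efficientzerotensor((3, 4))      # zero tensor backed by no storage

    print(z._is_zerotensor())                   # True
    print((x * z)._is_zerotensor())             # True: mul_zerotensor returns another zero tensor
    print(torch.allclose(x + z, x))             # True: add_zerotensor copies the non-zero operand

    # ZeroTensors are immutable; materialize with .clone() before mutating
    z.clone().add_(1.0)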
diff --git a/BUILD.bazel b/BUILD.bazel
index 703abac..34a00fb 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -136,6 +136,7 @@
         "aten/src/ATen/RegisterQuantizedCPU.cpp",
         "aten/src/ATen/RegisterSparseCPU.cpp",
         "aten/src/ATen/RegisterSparseCsrCPU.cpp",
+        "aten/src/ATen/RegisterZeroTensor.cpp",
         "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
         "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
         "aten/src/ATen/RegisterMeta.cpp",
diff --git a/aten/src/ATen/ConjugateFallback.cpp b/aten/src/ATen/ConjugateFallback.cpp
index 937da7a..8fc6258 100644
--- a/aten/src/ATen/ConjugateFallback.cpp
+++ b/aten/src/ATen/ConjugateFallback.cpp
@@ -51,6 +51,7 @@
 
   TORCH_VIEW_FNS(m)
   TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
+  TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m)
 }
 
 }
diff --git a/aten/src/ATen/ZeroTensorFallback.cpp b/aten/src/ATen/ZeroTensorFallback.cpp
new file mode 100644
index 0000000..6d7fee0
--- /dev/null
+++ b/aten/src/ATen/ZeroTensorFallback.cpp
@@ -0,0 +1,104 @@
+#include <ATen/ATen.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <ATen/native/UnaryOps.h>
+#include <ATen/NativeFunctions.h>
+#include <c10/util/irange.h>
+#include <torch/library.h>
+#include <ATen/native/MathBitFallThroughLists.h>
+
+namespace at {
+
+  // TODO: add a note explaining the design decisions
+  // ZeroTensors are designed to be immutable. Thus, we error out when an in-place operation is performed on ZeroTensors
+  void zeroTensorFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+    const auto& arguments = op.schema().arguments();
+    const auto num_arguments = arguments.size();
+    const auto stack_start = stack->size() - num_arguments;
+
+    c10::optional<bool> is_write;
+    for (const auto i : c10::irange(num_arguments)) {
+      const auto& alias_info = arguments[i].alias_info();
+      if (alias_info != nullptr) {
+        if (is_write.has_value()) {
+          TORCH_CHECK(*is_write == alias_info->isWrite(),
+            "Unsupported operator for ", "ZeroTensorFallback: ", op.schema().name(),
+            "ZeroTensor fallback doesn't work for operators with a mix "
+            "mutable and non-mutable inputs that alias with outputs, "
+            "this must be implemented manually.  "
+            "If you got this error on a core op, please report a bug to PyTorch.");
+        } else {
+          is_write = alias_info->isWrite();
+        }
+      }
+    }
+
+    if (is_write.has_value() && !*is_write) {
+      // We assume that view operators automatically handle the ZeroTensor bit
+      // correctly by propagating the dispatch key in key_set.
+      // This is not necessarily always right, so you should test these cases.
+      op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
+      return;
+    }
+
+    for (const auto i : c10::irange(num_arguments)) {
+      auto& ivalue = (*stack)[stack_start + i];
+      if (!(ivalue.isTensor() || ivalue.isTensorList())) {
+        continue;
+      }
+      const auto& argument = arguments[i];
+      bool mut_arg = false;
+
+      if (argument.alias_info()) {
+        // Was already tested by is_write loop above
+        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
+        mut_arg = true;
+      }
+
+      if (ivalue.isTensor()) {
+        auto tensor = std::move(ivalue).toTensor();
+        if (tensor._is_zerotensor()) {
+          TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
+                    "obtained using .clone() if you want a mutable tensor.");
+          tensor = at::zeros({}, tensor.options()).expand(tensor.sizes());
+        }
+        (*stack)[stack_start + i] = std::move(tensor);
+      } else if (ivalue.isTensorList()) {
+        auto tensors = std::move(ivalue).toTensorList();
+        for(const auto j : c10::irange(tensors.size())) {
+          const Tensor& tensor = tensors[j];
+          if (tensor._is_zerotensor()) {
+            // TODO: assert requires_grad=False
+            // *_like functions should not propagate the ZeroTensor dispatch key
+            TORCH_CHECK(!mut_arg, "ZeroTensors are immutable. Please use the materialized zero tensor ",
+                    "obtained using .clone() if you want a mutable tensor.");
+            tensors[j] = at::zeros({}, tensor.options()).expand(tensor.sizes());
+          }
+        }
+        (*stack)[stack_start + i] = std::move(tensors);
+      }
+    }
+
+    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::ZeroTensor), stack);
+  }
+
+
+  TORCH_LIBRARY_IMPL(_, ZeroTensor, m) {
+    m.fallback(torch::CppFunction::makeFromBoxedFunction<&zeroTensorFallback>());
+  }
+
+  TORCH_LIBRARY_IMPL(aten, ZeroTensor, m) {
+    m.impl("zeros_like", torch::CppFunction::makeFallthrough());
+    m.impl("mul.Scalar", torch::CppFunction::makeFallthrough());
+    m.impl("add.Scalar", torch::CppFunction::makeFallthrough());
+    m.impl("copy_", torch::CppFunction::makeFallthrough());
+    m.impl("clone", torch::CppFunction::makeFallthrough());
+    // The functions in the list below have a specific registration in native_functions.yaml and
+    // do not use the fallback.
+    // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough());
+    // m.impl("add.Tensor", torch::CppFunction::makeFallthrough());
+
+    TORCH_VIEW_FNS(m)
+    TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
+  }
+} // namespace at
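
The boxed fallback above is what makes operators without a dedicated ZeroTensor kernel still work: every ZeroTensor argument is replaced by a materialized at::zeros({}, ...).expand(sizes) before redispatching, and any mutable ZeroTensor argument is rejected. A hedged Python-level illustration (exact error text may differ):

    import torch

    z = torch._efficientzerotensor((2, 3))

    # sin has no ZeroTensor kernel, so the fallback materializes z and redispatches
    print(torch.sin(z))        # a regular tensor of zeros with shape (2, 3)

    # in-place ops on a ZeroTensor are refused by the fallback
    try:
        z.add_(1.0)
    except RuntimeError as err:
        print(err)             # "ZeroTensors are immutable. ..."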
diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h
index d55aa5d..15117af 100644
--- a/aten/src/ATen/core/TensorBase.h
+++ b/aten/src/ATen/core/TensorBase.h
@@ -304,6 +304,14 @@
     return impl_->storage().is_alias_of(other.storage());
   }
 
+  inline bool _is_zerotensor() const {
+    return impl_->_is_zerotensor();
+  }
+
+  inline void _set_zero(bool zero) const {
+    impl_->_set_zero(zero);
+  }
+
   inline bool is_conj() const {
     return impl_->is_conj();
   }
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 8a7ac46..2d86f06 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -7,7 +7,8 @@
 #include <ATen/MemoryOverlap.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/native/TensorIterator.h>
-
+#include <ATen/ExpandUtils.h>
+#include <ATen/RedispatchFunctions.h>
 #include <torch/library.h>
 
 namespace at {
@@ -625,6 +626,45 @@
   return at::mul_out(self, wrapped_scalar_tensor(other), self); // redispatch!
 }
 
+Device correct_out_device(const Tensor& self, const Tensor& other) {
+  if (self.device() == at::kCPU){
+      return other.device();
+  } else {
+    return self.device();
+  }
+}
+
+Tensor mul_zerotensor(const Tensor& self, const Tensor& other) {
+  auto out_device = correct_out_device(self, other);
+  // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
+  auto device_ = Device(DeviceType::Meta);
+  auto meta_out = at::redispatch::mul(c10::DispatchKeySet(at::DispatchKey::Meta), self.to(device_), other.to(device_));
+  return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
+}
+
+Tensor add_zerotensor(const Tensor& self, const Tensor& other, const Scalar& alpha) {
+  auto out_device = correct_out_device(self, other);
+  // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
+  auto device_ = Device(DeviceType::Meta);
+  auto meta_out = at::redispatch::add(c10::DispatchKeySet(at::DispatchKey::Meta), self.to(device_), other.to(device_));
+
+  auto get_out_like = [&] (const Tensor& tensor)
+  {
+      auto sizes = meta_out.sizes();
+      return at::_to_copy(tensor.expand(sizes), meta_out.options().device(out_device));
+  };
+
+  if (self._is_zerotensor()) {
+    if (other._is_zerotensor()) {
+      return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
+    }
+    auto res = get_out_like(other);
+    return alpha.equal(1) ? res : res.mul(alpha);
+  } else {
+    return get_out_like(self);
+  }
+}
+
 // multiply, alias for mul
 Tensor& multiply_out(const Tensor& self, const Tensor& other, Tensor& result) {
   return at::mul_out(result, self, other);
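
mul_zerotensor and add_zerotensor above redispatch to the Meta backend purely to reuse TensorIterator's broadcasting and type-promotion logic; no arithmetic is performed. A hedged sketch of the expected behavior, assuming the kernels behave as written above:

    import torch

    z = torch._efficientzerotensor((4, 1), dtype=torch.float32)
    x = torch.randn(3, dtype=torch.float64)

    out = x * z
    print(out.shape, out.dtype)     # torch.Size([4, 3]) torch.float64 (broadcast + promotion)
    print(out._is_zerotensor())     # True: the result is itself a storage-less zero tensor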
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index c28ca2b..cb8e0bb 100644
--- a/aten/src/ATen/native/Copy.cpp
+++ b/aten/src/ATen/native/Copy.cpp
@@ -244,6 +244,12 @@
   auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
   {
     NoNamesGuard guard;
+    if (self._is_zerotensor()) {
+      TORCH_CHECK(false, "ZeroTensors are immutable. Please materialize the tensor using `.clone()` if you want a mutable zero tensor.");
+    }
+    if (src._is_zerotensor()) {
+      return self.zero_();
+    }
     copy_impl(self, src, non_blocking);
   }
   namedinference::propagate_names_if_nonempty(self, maybe_outnames);
diff --git a/aten/src/ATen/native/MathBitFallThroughLists.h b/aten/src/ATen/native/MathBitFallThroughLists.h
index ea3fed7..59cf6a3 100644
--- a/aten/src/ATen/native/MathBitFallThroughLists.h
+++ b/aten/src/ATen/native/MathBitFallThroughLists.h
@@ -3,7 +3,6 @@
 namespace at {
 // views and their in-place version ops
 #define TORCH_VIEW_FNS(m) \
-  m.impl("as_strided", torch::CppFunction::makeFallthrough()); \
   m.impl("as_strided_", torch::CppFunction::makeFallthrough()); \
   m.impl("detach", torch::CppFunction::makeFallthrough()); \
   m.impl("detach_", torch::CppFunction::makeFallthrough()); \
@@ -31,7 +30,6 @@
   m.impl("unfold", torch::CppFunction::makeFallthrough()); \
   m.impl("unsqueeze", torch::CppFunction::makeFallthrough()); \
   m.impl("unsqueeze_", torch::CppFunction::makeFallthrough()); \
-  m.impl("view", torch::CppFunction::makeFallthrough()); \
   m.impl("view_as", torch::CppFunction::makeFallthrough()); \
   m.impl("unbind.int", torch::CppFunction::makeFallthrough()); \
   m.impl("unbind.Dimname", torch::CppFunction::makeFallthrough()); \
@@ -67,3 +65,7 @@
   m.impl("is_floating_point", torch::CppFunction::makeFallthrough()); \
   m.impl("requires_grad_", torch::CppFunction::makeFallthrough());
 }
+
+#define TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m) \
+  m.impl("as_strided", torch::CppFunction::makeFallthrough()); \
+  m.impl("view", torch::CppFunction::makeFallthrough()); \
diff --git a/aten/src/ATen/native/NegateFallback.cpp b/aten/src/ATen/native/NegateFallback.cpp
index ff23b70..3f88b02 100644
--- a/aten/src/ATen/native/NegateFallback.cpp
+++ b/aten/src/ATen/native/NegateFallback.cpp
@@ -35,6 +35,7 @@
 
   TORCH_VIEW_FNS(m)
   TENSOR_UTILITIES_AND_CONSTRUCTORS(m)
+  TORCH_VIEW_FNS_NATIVE_FN_REGISTRATION(m)
 }
 
 }
diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp
index 1937a8b..67e0904 100644
--- a/aten/src/ATen/native/Resize.cpp
+++ b/aten/src/ATen/native/Resize.cpp
@@ -27,8 +27,7 @@
 
 static auto kFunctorchWrappedTensors = DispatchKeySet({
     DispatchKey::FuncTorchGradWrapper,
-    DispatchKey::FuncTorchBatched,
-    DispatchKey::FuncTorchPython});
+    DispatchKey::FuncTorchBatched});
 
 static bool is_functorch_wrapped_tensor(const Tensor& tensor) {
   auto key_set = tensor.unsafeGetTensorImpl()->key_set();
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 0dc86e4..2afb822 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -349,9 +349,10 @@
     namedinference::propagate_names(result, self.names());
   }
 
-  // never propagate Conjugate and Negative dispatch key
+  // never propagate Conjugate, Negative, and ZeroTensor dispatch key
   result._set_conj(false);
   result._set_neg(false);
+  result._set_zero(false);
   return result;
 }
 
@@ -435,6 +436,27 @@
 
 namespace {
 
+// The ZeroTensor allocator ignores whatever allocation is requested and always
+// gives you nullptr
+struct ZeroTensorAllocator final : public at::Allocator {
+  ZeroTensorAllocator(at::Device device) : device_(device) {};
+  ~ZeroTensorAllocator() override = default;
+  static void deleter(void* const pointer) {
+    TORCH_INTERNAL_ASSERT(!pointer);
+  }
+  DataPtr allocate(const size_t nbytes) const override {
+    return {nullptr, nullptr, &deleter, device_};
+  }
+  DeleterFnPtr raw_deleter() const override {
+    return deleter;
+  }
+  at::Device device_;
+};
+
+at::Allocator* GetZeroTensorAllocator(ZeroTensorAllocator& zt) {
+  return &zt;
+}
+
 // Performs dtype inference for full
 TensorOptions infer_full_options(
   const Scalar& fill_value,
@@ -1057,6 +1079,18 @@
   return result.zero_();
 }
 
+Tensor _efficientzerotensor(IntArrayRef size,
+    c10::optional<ScalarType> dtype,
+    c10::optional<Layout> layout,
+    c10::optional<Device> device,
+    c10::optional<bool> pin_memory) {
+    auto device_ = device_or_default(device);
+    auto allocator = ZeroTensorAllocator(device_);
+    auto dtype_ = dtype_or_default(dtype);
+    auto r = at::detail::empty_generic(size, GetZeroTensorAllocator(allocator), at::DispatchKey::ZeroTensor, dtype_, device_, c10::nullopt);
+    return r;
+}
+
 Tensor& zeros_out(IntArrayRef size, Tensor& result) {
   if (result.is_sparse()) {
     result.sparse_resize_and_clear_(size, size.size(), 0.);
@@ -1427,7 +1461,11 @@
     self = at::empty_like(src, src.options(), memory_format);
   }
 
-  self.copy_(src);
+  if (src._is_zerotensor()) {
+    self.zero_();
+  } else {
+    self.copy_(src);
+  }
   return self;
 }
 
diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp
index a84f48f..a49e2a5 100644
--- a/aten/src/ATen/native/TypeProperties.cpp
+++ b/aten/src/ATen/native/TypeProperties.cpp
@@ -30,6 +30,10 @@
   return self.is_signed();
 }
 
+bool _is_zerotensor(const Tensor& self) {
+  return self._is_zerotensor();
+}
+
 bool is_conj(const Tensor& self) {
   return self.is_conj();
 }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index fc6246b..258a766 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -439,6 +439,7 @@
     SparseCPU, SparseCUDA: add_sparse
     SparseCsrCPU, SparseCsrCUDA: add_sparse_csr
     MkldnnCPU: mkldnn_add
+    ZeroTensor: add_zerotensor
 
 - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
@@ -711,7 +712,7 @@
 - func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)
   variants: function, method
   dispatch:
-    CPU, CUDA, Meta: as_strided_tensorimpl
+    ZeroTensor, CPU, CUDA, Meta: as_strided_tensorimpl
     QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl
   device_check: NoCheck
   device_guard: False
@@ -2463,6 +2464,11 @@
   device_guard: False
   manual_cpp_binding: True
 
+- func: _is_zerotensor(Tensor self) -> bool
+  variants: function, method
+  device_guard: False
+  manual_cpp_binding: True
+
 - func: is_neg(Tensor self) -> bool
   variants: function, method
   device_guard: False
@@ -3199,6 +3205,7 @@
   dispatch:
     SparseCPU, SparseCUDA: mul_sparse
     MkldnnCPU: mkldnn_mul
+    ZeroTensor: mul_zerotensor
 
 - func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
@@ -3667,7 +3674,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: _reshape_alias
+    CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor: _reshape_alias
     # We don't need to support mkldnn since this is handled explicitly by the reshape operator.
 
 - func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
@@ -4827,6 +4834,10 @@
   device_check: NoCheck
   device_guard: False
 
+- func: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _efficientzerotensor
+
 - func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
 - func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)
@@ -5886,7 +5897,7 @@
   device_check: NoCheck
   device_guard: False
   dispatch:
-    CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: view
+    ZeroTensor, CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: view
     MkldnnCPU: mkldnn_view
 
 # Warning: If you want to change the name or overload name of this
diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h
index a7563a1..1a1c1d7 100644
--- a/aten/src/ATen/templates/Functions.h
+++ b/aten/src/ATen/templates/Functions.h
@@ -221,6 +221,10 @@
   return tensor.is_inference();
 }
 
+inline bool _is_zerotensor(const Tensor& tensor) {
+  return tensor._is_zerotensor();
+}
+
 inline bool is_conj(const Tensor& tensor) {
   return tensor.is_conj();
 }
diff --git a/c10/core/DispatchKey.cpp b/c10/core/DispatchKey.cpp
index fc40019..96b714a 100644
--- a/c10/core/DispatchKey.cpp
+++ b/c10/core/DispatchKey.cpp
@@ -111,6 +111,9 @@
       return "AutogradPrivateUse3";
     case DispatchKey::AutogradOther:
       return "AutogradOther";
+
+    case DispatchKey::ZeroTensor:
+      return "ZeroTensor";
     case DispatchKey::BackendSelect:
       return "BackendSelect";
     case DispatchKey::Named:
@@ -149,8 +152,6 @@
     // https://github.com/zou3519/functorch
     // We plan on eventually upstreaming the prototype into core, at which
     // point it will have a different design that should use fewer keys.
-    case DispatchKey::FuncTorchPython:
-      return "FuncTorchPython";
     case DispatchKey::FuncTorchDynamicLayerBackMode:
       return "FuncTorchDynamicLayerBackMode";
     case DispatchKey::FuncTorchDynamicLayerFrontMode:
@@ -242,10 +243,10 @@
       {"PrivateUse3", c10::DispatchKey::PrivateUse3},
       {"BackendSelect", c10::DispatchKey::BackendSelect},
       {"Python", c10::DispatchKey::Python},
-      {"FuncTorchPython", c10::DispatchKey::FuncTorchPython},
       {"Named", c10::DispatchKey::Named},
       {"Conjugate", c10::DispatchKey::Conjugate},
       {"Negative", c10::DispatchKey::Negative},
+      {"ZeroTensor", c10::DispatchKey::ZeroTensor},
       {"FuncTorchDynamicLayerBackMode",
        c10::DispatchKey::FuncTorchDynamicLayerBackMode},
       {"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
index 2cf41af..1bb8268 100644
--- a/c10/core/DispatchKey.h
+++ b/c10/core/DispatchKey.h
@@ -139,7 +139,6 @@
   BackendSelect,
 
   Python,
-  FuncTorchPython, // See Note [Out-of-tree vmap+grad prototype]
 
   // The named dispatch key is set for any tensors with named dimensions.
   // Although we have a dispatch key for named tensors, for historical reasons,
@@ -165,6 +164,8 @@
   // This is implemented at a dispatch level right before any backends run
   Negative,
 
+  ZeroTensor, // registered at build/aten/src/ATen/RegisterZeroTensor.cpp
+
   // See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
   // is to insert code after the "autograd subsystem" runs, so this key should
   // be directly after ADInplaceOrView and all of the autograd keys.
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 857e28f..bee4da9 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -1057,6 +1057,24 @@
   }
 
   /**
+   * Whether or not the tensor is a ZeroTensor
+   */
+  inline bool _is_zerotensor() const {
+    return key_set_.has(DispatchKey::ZeroTensor);
+  }
+
+  /**
+   * Set whether or not the tensor is a zero tensor
+   */
+  void _set_zero(bool value) {
+    if (value) {
+      TORCH_INTERNAL_ASSERT(false, "Please call `torch._efficientzerotensor` if you want to create a tensor with no storage.");
+    } else {
+      key_set_ = key_set_.remove(DispatchKey::ZeroTensor);
+    }
+  }
+
+  /**
    * Whether or not the tensor should be negated
    */
   inline bool is_neg() const {
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 3daf6d7..16521d7 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2433,3 +2433,6 @@
 
 - name: _test_warn_in_autograd(Tensor self) -> Tensor
   self: warn_backwards(grad)
+
+- name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  output_differentiability: [False]
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 0952a33..803097f 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -319,7 +319,8 @@
 
 FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate("""\
 auto ${inp}_t_raw = toNonOptFwGrad(${inp});
-auto ${inp}_t = ${inp}_t_raw.defined() ? ${inp}_t_raw : at::zeros_like(toNonOptTensor(${inp}));
+auto ${inp}_tensor = toNonOptTensor(${inp});
+auto ${inp}_t = ${inp}_t_raw.defined() ? ${inp}_t_raw : at::_efficientzerotensor(${inp}_tensor.sizes(), ${inp}_tensor.options());
 """)
 
 FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate("""\
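
With this template change, an input whose forward-mode tangent is undefined gets an efficient zero tensor instead of a materialized zeros_like, so derivative formulas that multiply by it stay cheap. A hedged forward-AD sketch using the standard torch.autograd.forward_ad API:

    import torch
    import torch.autograd.forward_ad as fwAD

    a = torch.randn(3, dtype=torch.double)
    b = torch.randn(3, dtype=torch.double)

    with fwAD.dual_level():
        # only `a` carries a tangent; `b`'s undefined tangent becomes an efficient zero tensor
        a_dual = fwAD.make_dual(a, torch.ones(3, dtype=torch.double))
        _, tangent = fwAD.unpack_dual(a_dual * b)
        print(torch.allclose(tangent, b))   # True: d(a*b) = b * da + a * 0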
diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py
index 072b83c..77a1ef6 100644
--- a/tools/codegen/gen.py
+++ b/tools/codegen/gen.py
@@ -1307,6 +1307,7 @@
         # Meta is a magic key: it is automatically generated for structured
         # kernels
         DispatchKey.Meta,
+        DispatchKey.ZeroTensor,
     ]
     # Only a limited set of dispatch keys get CPUFunctions.h headers generated
     # for them; this is the set
diff --git a/tools/codegen/model.py b/tools/codegen/model.py
index 1643089..edcaf95 100644
--- a/tools/codegen/model.py
+++ b/tools/codegen/model.py
@@ -79,6 +79,7 @@
     PrivateUse3 = auto()
     EndOfBackendKeys = PrivateUse3
 
+    ZeroTensor = auto()
     Meta = auto()
     BackendSelect = auto()
     Named = auto()
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
index 7996f58..bc9ad2d 100644
--- a/torch/_tensor_str.py
+++ b/torch/_tensor_str.py
@@ -234,6 +234,9 @@
 
     summarize = self.numel() > PRINT_OPTS.threshold
 
+    if self._is_zerotensor():
+        self = self.clone()
+
     # handle the negative bit
     if self.is_neg():
         self = self.resolve_neg()
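
The printer cannot read storage that was never allocated, so _tensor_str clones (and thereby materializes) a ZeroTensor before formatting it, mirroring how the negative bit is resolved just below. A hedged sketch:

    import torch

    z = torch._efficientzerotensor((2, 2))
    print(z)   # prints a materialized 2x2 tensor of zeros; the clone leaves z itself untouched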
diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index 5c09b06..3d1c7ec 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -871,8 +871,59 @@
             raise GradcheckError('grad is incorrect size')
     return True
 
+def _test_undefined_forward_mode(func, outputs, inputs):
+    fwAD = torch.autograd.forward_ad
 
-def _test_undefined_grad(func, outputs, inputs) -> bool:
+    inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
+    all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=True)
+
+    tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad)
+
+    with fwAD.dual_level():
+        fw_grads = []
+        dual_inputs = []
+        for i, inp in enumerate(inputs):
+            if is_tensor_like(inp) and inp.requires_grad:
+                if inp.layout == torch._mkldnn:  # type: ignore[attr-defined]
+                    raise ValueError("MKLDNN inputs are not support for forward AD gradcheck.")
+
+                inp = fwAD.make_dual(inp, torch.zeros_like(inp))
+                # If inp is a differentiable view, the dual might not be the tangent given to
+                # make_dual, so read it explicitly from the dual tensor
+                fw_grads.append(fwAD.unpack_dual(inp)[1])
+            dual_inputs.append(inp)
+
+        for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)):
+            fw_grad.copy_(u.view_as(fw_grad))
+
+        for idx, inp in enumerate(tensor_inputs):
+            dual_inp_obj = dual_inputs[idx]
+
+            # case 1 (Materialized Zero Tensor Tangent)
+            dual_inputs[idx] = fwAD.make_dual(inp, torch.zeros_like(inp))
+            raw_outputs = _as_tuple(func(*dual_inputs))
+            dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs)
+
+            # case 2 (Efficient Zero Tensor Tangent: no dual object is made, a regular tensor is passed instead)
+            dual_inputs[idx] = inp
+            raw_outputs = _as_tuple(func(*dual_inputs))
+            dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs)
+
+            # reset
+            dual_inputs[idx] = dual_inp_obj
+
+            for index_o, (d_o1, d_o2) in enumerate(zip(dual_outputs1, dual_outputs2)):
+                val1, res1 = fwAD.unpack_dual(d_o1)
+                val2, res2 = fwAD.unpack_dual(d_o2)
+
+                if not (res1 is None or res2 is None):
+                    if not torch.equal(res1, res2):
+                        raise GradcheckError("Mismatch in tangent values for output with index: ", index_o,
+                                             " when input: ", inp, " has an undefined tangent value. ",
+                                             " Got: ", res1, " but expected: ", res2)
+    return True
+
+def _test_undefined_backward_mode(func, outputs, inputs) -> bool:
     diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True))
     if not diff_input_list:
         raise GradcheckError("no Tensors requiring grad found in input")
@@ -984,7 +1035,8 @@
 
 
 def _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps, rtol,
-                         atol, check_grad_dtypes, check_forward_ad, check_backward_ad, nondet_tol):
+                         atol, check_grad_dtypes, check_forward_ad, check_backward_ad, nondet_tol,
+                         check_undefined_grad):
     complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()]
     has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out))
     if check_backward_ad:
@@ -1023,10 +1075,14 @@
             gradcheck_fn(real_fn, real_func_out, real_inputs, diff_real_func_out, eps,
                          rtol, atol, check_grad_dtypes, nondet_tol, complex_indices=complex_inp_indices,
                          use_forward_ad=True)
+            if check_undefined_grad:
+                _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs)
+                _test_undefined_forward_mode(real_fn, real_func_out, real_inputs)
         else:
             gradcheck_fn(func, func_out, tupled_inputs, outputs, eps,
                          rtol, atol, check_grad_dtypes, nondet_tol, use_forward_ad=True)
-
+            if check_undefined_grad:
+                _test_undefined_forward_mode(func, outputs, tupled_inputs)
 
 def _slow_gradcheck(func, func_out, tupled_inputs, outputs, eps, rtol, atol, check_grad_dtypes,
                     nondet_tol, *, use_forward_ad=False, complex_indices=None, test_imag=False):
@@ -1210,6 +1266,13 @@
                     atol, check_grad_dtypes, nondet_tol, *, use_forward_ad=False, complex_indices=None, test_imag=False):
     # See https://github.com/pytorch/pytorch/issues/53876 for details
     inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs)
+    # Backward mode computes v^T * J (VJP)
+    # Since we computed J * u (JVP) through the finite difference method, we perform an equality
+    # check between (v^T * J) * u and v^T * (J * u)
+    # ----
+    # Forward mode computes J * u (JVP) directly
+    # Since we already compute the JVP through the finite difference method,
+    # we don't need v for the correctness check here, as asserted below
     all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=use_forward_ad)
 
     numerical_vJu = _get_numerical_vJu(func, inputs, inp_tensors_idx, func_out, all_u, all_v, eps, is_forward_ad=use_forward_ad)
@@ -1348,7 +1411,8 @@
     gradcheck_fn = _fast_gradcheck if fast_mode else _slow_gradcheck
     _gradcheck_real_imag(gradcheck_fn, func, func_out, tupled_inputs, outputs, eps,
                          rtol, atol, check_grad_dtypes, check_forward_ad=check_forward_ad,
-                         check_backward_ad=check_backward_ad, nondet_tol=nondet_tol)
+                         check_backward_ad=check_backward_ad, nondet_tol=nondet_tol,
+                         check_undefined_grad=check_undefined_grad)
 
     if check_batched_forward_grad:
         _test_batched_grad_forward_ad(func, tupled_inputs)
@@ -1364,7 +1428,7 @@
     _test_backward_mul_by_grad_output(outputs, tupled_inputs, check_sparse_nnz)
 
     if check_undefined_grad:
-        _test_undefined_grad(func, outputs, tupled_inputs)
+        _test_undefined_backward_mode(func, outputs, tupled_inputs)
     return True
 
 
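
The new _test_undefined_forward_mode check is wired into _gradcheck_real_imag via the check_undefined_grad flag, so forward-mode gradcheck now also compares the materialized-zero and efficient-zero tangent paths. A hedged usage sketch with arguments that appear in the signatures touched above:

    import torch
    from torch.autograd import gradcheck

    x = torch.randn(3, dtype=torch.double, requires_grad=True)
    y = torch.randn(3, dtype=torch.double, requires_grad=True)

    # for each input in turn, the tangent is left undefined and the two zero-tangent
    # code paths (zeros_like dual vs. no dual at all) must produce matching outputs
    gradcheck(torch.mul, (x, y), check_forward_ad=True, check_undefined_grad=True)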
diff --git a/torch/overrides.py b/torch/overrides.py
index fcdcaff..31b8c3c 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -243,6 +243,7 @@
         Tensor._conj,
         Tensor._conj_physical,
         Tensor._neg_view,
+        Tensor._is_zerotensor,
     }
 
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 3f8ba49..0d325ea 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9218,7 +9218,13 @@
                     supports_forward_ad=True,
                     promotes_int_to_float=True,
                     assert_autodiffed=True,
-                    rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
+                    skips=(
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
+                                     device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
+                                     device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+                    ),),
     BinaryUfuncInfo('div',
                     aliases=('divide',),
                     variant_test_name='trunc_rounding',
@@ -9227,7 +9233,13 @@
                     supports_forward_ad=True,
                     promotes_int_to_float=True,
                     assert_autodiffed=True,
-                    rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
+                    skips=(
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
+                                     device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
+                                     device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+                    ),),
     BinaryUfuncInfo('div',
                     aliases=('divide',),
                     variant_test_name='floor_rounding',
@@ -9236,7 +9248,13 @@
                     supports_forward_ad=True,
                     promotes_int_to_float=True,
                     assert_autodiffed=True,
-                    rhs_make_tensor_kwargs=dict(exclude_zero=True)),
+                    rhs_make_tensor_kwargs=dict(exclude_zero=True),
+                    skips=(
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_forward_mode_AD',
+                                     device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+                        DecorateInfo(unittest.skip("Skipped!"), 'TestGradients', 'test_inplace_forward_mode_AD',
+                                     device_type='cuda', dtypes=[torch.double, torch.cdouble]),
+                    ),),
     BinaryUfuncInfo('true_divide',
                     dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
                     supports_forward_ad=True,