Rename prim::Undefined to prim::AutogradZero (#17611)

Summary:
supersedes #17245

Renames the JIT autograd placeholders so their names describe the autograd semantics (a zero gradient) rather than the undefined-tensor representation used at runtime:

- prim::Undefined -> prim::AutogradZero (still materialized by the interpreter as an undefined at::Tensor())
- prim::AnyDefined -> prim::AutogradAnyNonZero (true iff any input gradient is non-zero, i.e. any input tensor is defined)
- UndefinedTensorType -> AutogradZeroTensorType; Graph::createUndefined() -> Graph::createAutogradZero()
- the specialize_undef pass -> specialize_autogradzero (specializeUndef -> specializeAutogradZero, _jit_pass_specialize_undef -> _jit_pass_specialize_autogradzero)

The Python printer no longer treats these nodes as constant-like, graph_fuser no longer special-cases prim::Undefined when checking for defined tensors, and the prim::None interned symbol is dropped.
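For out-of-tree code that constructs JIT graphs directly, the rename is mechanical. Below is a minimal sketch of the new spellings; the buildZeroGradGraph helper is illustrative only (not part of this PR), while the Graph and pass entry points it calls are the renamed ones from this diff:

    #include <memory>
    #include <torch/csrc/jit/ir.h>
    #include <torch/csrc/jit/passes/specialize_autogradzero.h>

    using namespace torch::jit;

    // Illustrative only: build a graph whose single output is a zero-gradient
    // placeholder, then run the renamed specialization pass over it.
    void buildZeroGradGraph(const std::shared_ptr<Graph>& graph) {
      // Before this PR: graph->createUndefined() produced a prim::Undefined node.
      Node* zero = graph->insertNode(graph->createAutogradZero());
      graph->registerOutput(zero->output());

      // Before this PR: specializeUndef(*graph). The pass folds AutogradZero
      // through GradOf blocks and simplifies AutogradAdd when an input is zero.
      specializeAutogradZero(*graph);
    }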
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17611

Differential Revision: D14283581

Pulled By: wanchaol

fbshipit-source-id: 8022d02b8a021ea2fee9a18a2c8920eb123200c5
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index d58fdc1..2bd28fb 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -24,7 +24,6 @@
   _(prim, BroadcastSizes)          \
   _(prim, Constant)                \
   _(prim, ChunkSizes)              \
-  _(prim, None)                    \
   _(prim, Drop)                    \
   _(prim, Eval)                    \
   _(prim, Expand) /* onnx */       \
@@ -46,7 +45,8 @@
   _(prim, Reverse)                 \
   _(prim, Return)                  \
   _(prim, Store)                   \
-  _(prim, Undefined)               \
+  _(prim, AutogradZero)            \
+  _(prim, AutogradAnyNonZero)      \
   _(prim, Starred)                 \
   _(prim, TupleConstruct)          \
   _(prim, TupleUnpack)             \
@@ -67,7 +67,6 @@
   _(prim, requires_grad)           \
   _(prim, AutogradAdd)             \
   _(prim, GradOf)                  \
-  _(prim, AnyDefined)              \
   _(prim, FusedConcat)             \
   _(prim, ConstantChunk)           \
   _(prim, MMTreeReduce)            \
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 066fb13..3759566 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -29,7 +29,7 @@
 _(TensorType) \
 _(DimensionedTensorType) \
 _(CompleteTensorType) \
-_(UndefinedTensorType) \
+_(AutogradZeroTensorType) \
 _(TupleType) \
 _(ListType) \
 _(DictType) \
@@ -255,9 +255,9 @@
 struct TensorType;
 using TensorTypePtr = std::shared_ptr<TensorType>;
 // This type represents a single Tensor, with an unknown shape.
-// Subtype hierarchy for Tensor Types (DynamicType as the base type):
-// CompleteTensorType <: TensorType <: DynamicType
-// UndefinedTensorType <: DynamicType
+// Subtype hierarchy for Tensor Types (TensorType as the base type):
+// CompleteTensorType <: DimensionedTensorType <: TensorType
+// AutogradZeroTensorType <: TensorType
 struct CAFFE2_API TensorType : public Type {
   static TensorTypePtr create() {
     return TensorTypePtr(new TensorType()); // NOLINT(modernize-make-shared)
@@ -280,15 +280,15 @@
   : Type(kind) {}
 };
 
-struct UndefinedTensorType;
-using UndefinedTensorTypePtr = std::shared_ptr<UndefinedTensorType>;
+struct AutogradZeroTensorType;
+using AutogradZeroTensorTypePtr = std::shared_ptr<AutogradZeroTensorType>;
 // This type represents an undefined tensor.
-struct CAFFE2_API UndefinedTensorType : public TensorType {
-  static UndefinedTensorTypePtr create() {
-    return UndefinedTensorTypePtr(new UndefinedTensorType()); // NOLINT(modernize-make-shared)
+struct CAFFE2_API AutogradZeroTensorType : public TensorType {
+  static AutogradZeroTensorTypePtr create() {
+    return AutogradZeroTensorTypePtr(new AutogradZeroTensorType()); // NOLINT(modernize-make-shared)
   }
 
-  DEFINE_IS_SUBCLASS(UndefinedTensorType);
+  DEFINE_IS_SUBCLASS(AutogradZeroTensorType);
 
   bool requires_grad() const override { return false; }
 
@@ -297,18 +297,18 @@
   }
   bool isSubtypeOf(const TypePtr rhs) const override {
     return rhs->kind() == TypeKind::TensorType ||
-           rhs->kind() == TypeKind::UndefinedTensorType ||
+           rhs->kind() == TypeKind::AutogradZeroTensorType ||
            TensorType::isSubtypeOf(rhs);
   }
   std::string str() const override {
     return "UndefinedTensor";
   }
 
-  static const TypeKind Kind = TypeKind::UndefinedTensorType;
+  static const TypeKind Kind = TypeKind::AutogradZeroTensorType;
   // global singleton
-  static UndefinedTensorTypePtr get();
+  static AutogradZeroTensorTypePtr get();
 protected:
-  UndefinedTensorType(): TensorType(TypeKind::UndefinedTensorType) {}
+  AutogradZeroTensorType(): TensorType(TypeKind::AutogradZeroTensorType) {}
 };
 
 struct DimensionedTensorType;
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index b17b05a..e8892db 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -59,8 +59,8 @@
   static auto value = TensorType::create();
   return value;
 }
-UndefinedTensorTypePtr UndefinedTensorType::get() {
-  static auto value = UndefinedTensorType::create();
+AutogradZeroTensorTypePtr AutogradZeroTensorType::get() {
+  static auto value = AutogradZeroTensorType::create();
   return value;
 }
 NumberTypePtr NumberType::get() {
diff --git a/test/cpp/jit/test_alias_analysis.h b/test/cpp/jit/test_alias_analysis.h
index c1d2b57..1f5bfc4 100644
--- a/test/cpp/jit/test_alias_analysis.h
+++ b/test/cpp/jit/test_alias_analysis.h
@@ -51,7 +51,7 @@
     for (const auto name : inputNames) {
       inputs.push_back(nodes.at(name)->output());
     }
-    auto node = graph->appendNode(graph->create(prim::Undefined, inputs));
+    auto node = graph->appendNode(graph->create(prim::AutogradZero, inputs));
     node->output()->setUniqueName(name);
     nodes[name] = node;
 
@@ -63,7 +63,7 @@
       }
 
       auto block = node->blocks().at(0);
-      block->appendNode(graph->create(prim::Undefined, blockDeps));
+      block->appendNode(graph->create(prim::AutogradZero, blockDeps));
     }
   }
 
diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h
index 36c22dc..e2ea958 100644
--- a/test/cpp/jit/test_misc.h
+++ b/test/cpp/jit/test_misc.h
@@ -1524,10 +1524,10 @@
 void testTopologicalIndex() {
   {
     Graph graph;
-    auto node1 = graph.create(prim::Undefined);
-    auto node2 = graph.create(prim::Undefined);
-    auto node3 = graph.create(prim::Undefined);
-    auto node4 = graph.create(prim::Undefined);
+    auto node1 = graph.create(prim::AutogradZero);
+    auto node2 = graph.create(prim::AutogradZero);
+    auto node3 = graph.create(prim::AutogradZero);
+    auto node4 = graph.create(prim::AutogradZero);
 
     graph.appendNode(node4);
     graph.prependNode(node1);
@@ -1552,12 +1552,12 @@
     //      \      ...
     //      C    block2
     auto block1 = node3->addBlock();
-    auto A = graph.create(prim::Undefined);
+    auto A = graph.create(prim::AutogradZero);
     block1->appendNode(A);
-    auto B = graph.create(prim::Undefined);
+    auto B = graph.create(prim::AutogradZero);
     block1->appendNode(B);
     auto block2 = B->addBlock();
-    auto C = graph.create(prim::Undefined);
+    auto C = graph.create(prim::AutogradZero);
     block2->appendNode(C);
 
     // Check isAfter on different block levels
@@ -1567,7 +1567,7 @@
 
     // make sure things don't blow up on deletions
     node2->destroy();
-    auto node2p = graph.create(prim::Undefined);
+    auto node2p = graph.create(prim::AutogradZero);
     node2p->insertAfter(node1);
     ASSERT_TRUE(node1->isBefore(node2p));
     ASSERT_TRUE(node2p->isBefore(node3));
@@ -1577,11 +1577,11 @@
     Graph graph;
     std::map<size_t, Node*> nodes;
 
-    auto anchor = graph.create(prim::Undefined);
+    auto anchor = graph.create(prim::AutogradZero);
     graph.appendNode(anchor);
     // Inserting to the same place a lot will trigger reindexing
     for (auto i = 0; i < 100; ++i) {
-      auto n = graph.create(prim::Undefined);
+      auto n = graph.create(prim::AutogradZero);
       n->insertAfter(anchor);
       nodes[i] = n;
     }
diff --git a/tools/build_variables.py b/tools/build_variables.py
index 63953ab..ebf6a9e 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -85,7 +85,7 @@
     "torch/csrc/jit/passes/remove_expands.cpp",
     "torch/csrc/jit/passes/requires_grad_analysis.cpp",
     "torch/csrc/jit/passes/shape_analysis.cpp",
-    "torch/csrc/jit/passes/specialize_undef.cpp",
+    "torch/csrc/jit/passes/specialize_autogradzero.cpp",
     "torch/csrc/jit/passes/utils/subgraph_utils.cpp",
     "torch/csrc/jit/passes/utils/memory_dag.cpp",
     "torch/csrc/jit/register_prim_ops.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 159e8bc..a06ac07 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -161,7 +161,7 @@
   ${TORCH_SRC_DIR}/csrc/jit/passes/remove_inplace_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/shape_analysis.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp
-  ${TORCH_SRC_DIR}/csrc/jit/passes/specialize_undef.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/passes/specialize_autogradzero.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h
index 72155cb..ac959aa 100644
--- a/torch/csrc/jit/argument_spec.h
+++ b/torch/csrc/jit/argument_spec.h
@@ -156,7 +156,7 @@
     if (original->isSubtypeOf(TensorType::get())) {
       auto& arg = args.at(offset++);
       if (!arg.defined())
-        return UndefinedTensorType::get();
+        return AutogradZeroTensorType::get();
       return DimensionedTensorType::create(
           arg.type(),
           ConvertIntToCPUOrCUDA(arg.device()),
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index a92f982..b1b1da2 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -122,7 +122,7 @@
   // perf "aten::atan2(Tensor self) -> Tensor", "aten::max(Tensor self) ->
   // Tensor", "aten::min(Tensor self) -> Tensor"
 
-  if (n->kind() == prim::Constant || n->kind() == prim::Undefined ||
+  if (n->kind() == prim::Constant || n->kind() == prim::AutogradZero ||
       n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk)
     return true;
   if (differentiable_ops.find(n))
@@ -718,7 +718,7 @@
         node->matches(
             "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
       auto graph = node->owningGraph();
-      auto total_weight = graph->insertNode(graph->createUndefined());
+      auto total_weight = graph->insertNode(graph->createAutogradZero());
       auto weight = graph->insertNode(graph->createNone(TensorType::get()));
       auto backward_value = graph->insert(
           aten::nll_loss_backward,
@@ -748,7 +748,7 @@
       return {backward_value->node()->output(0), nullptr};
 
     } else if (
-        node->kind() == prim::Constant || node->kind() == prim::Undefined) {
+        node->kind() == prim::Constant || node->kind() == prim::AutogradZero) {
       return {};
     }
     throw std::runtime_error(
@@ -834,7 +834,7 @@
   const auto get_grad = [&](Value* v) -> Value* {
     auto it = grad_map.find(v);
     if (it == grad_map.end()) {
-      auto undef = graph.insertNode(graph.createUndefined());
+      auto undef = graph.insertNode(graph.createAutogradZero());
       std::tie(it, std::ignore) = grad_map.emplace(v, undef->output());
     }
     return it->second;
@@ -947,7 +947,7 @@
     AT_ASSERT(
         top_node->kind() == prim::GradOf ||
         top_node->kind() == prim::AutogradAdd ||
-        top_node->kind() == prim::Undefined);
+        top_node->kind() == prim::AutogradZero);
     if (top_node->kind() != prim::GradOf)
       continue;
     Block* grad_body = top_node->blocks().at(0);
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 2332dd1..eaaf1d1 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -26,7 +26,7 @@
 #include <torch/csrc/jit/passes/remove_expands.h>
 #include <torch/csrc/jit/passes/requires_grad_analysis.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
-#include <torch/csrc/jit/passes/specialize_undef.h>
+#include <torch/csrc/jit/passes/specialize_autogradzero.h>
 #include <torch/csrc/jit/symbolic_variable.h>
 #include <torch/csrc/jit/tracer.h>
 
@@ -635,7 +635,7 @@
 }
 
 void runRequiredPasses(const std::shared_ptr<Graph>& g) {
-  specializeUndef(*g);
+  specializeAutogradZero(*g);
   LowerGradOf(*g);
   // implicit inserted expand nodes are not necessarily always valid
   // when used inside script methods that might have unstable shapes
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 4cdf4ad..ad0e0d9 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -29,7 +29,7 @@
 #include <torch/csrc/jit/passes/remove_expands.h>
 #include <torch/csrc/jit/passes/remove_inplace_ops.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
-#include <torch/csrc/jit/passes/specialize_undef.h>
+#include <torch/csrc/jit/passes/specialize_autogradzero.h>
 #include <torch/csrc/jit/passes/to_batch.h>
 #include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
 #include <torch/csrc/jit/pybind_utils.h>
@@ -188,7 +188,7 @@
       .def("_jit_pass_onnx_block", BlockToONNX)
       .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
       .def("_jit_pass_canonicalize_ops", CanonicalizeOps)
-      .def("_jit_pass_specialize_undef", specializeUndef)
+      .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
       .def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
       .def(
           "_jit_differentiate",
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 79c04e6..613da86 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -1194,8 +1194,8 @@
   return n;
 }
 
-Node* Graph::createUndefined() {
-  return create(prim::Undefined);
+Node* Graph::createAutogradZero() {
+  return create(prim::AutogradZero);
 }
 
 Node* Graph::createNone(TypePtr typ) {
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 1cd0fc0..4ef6883 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -1041,7 +1041,7 @@
 
   TORCH_API Node* createNone(
       TypePtr typ); // value of None with type Optional[typ]
-  TORCH_API Node* createUndefined();
+  TORCH_API Node* createAutogradZero();
   TORCH_API Node* createFusionGroup();
   TORCH_API Node* createDifferentiableSubgraph();
   TORCH_API Node* createTuple(
diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp
index 3a9ac17..8ebec19 100644
--- a/torch/csrc/jit/passes/alias_analysis.cpp
+++ b/torch/csrc/jit/passes/alias_analysis.cpp
@@ -369,7 +369,7 @@
     case prim::DictConstruct:
     case prim::ListConstruct:
     case prim::TupleConstruct:
-    case prim::Undefined:
+    case prim::AutogradZero:
     case prim::FusedConcat:
     case prim::MMTreeReduce:
     case prim::MMBatchSide:
@@ -1128,11 +1128,10 @@
       prim::DictConstruct,
       prim::ListConstruct,
       prim::TupleConstruct,
-      prim::Undefined,
+      prim::AutogradZero,
       prim::FusedConcat,
       prim::MMTreeReduce,
       prim::MMBatchSide,
-      prim::None,
       prim::BroadcastSizes,
       prim::ChunkSizes,
       prim::Function,
@@ -1163,7 +1162,7 @@
       prim::Drop,
       at::onnx::Reshape,
       at::onnx::Shape,
-      prim::AnyDefined,
+      prim::AutogradAnyNonZero,
       prim::AutogradAdd,
   };
 
diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp
index 56a4db7..5b7652b 100644
--- a/torch/csrc/jit/passes/constant_propagation.cpp
+++ b/torch/csrc/jit/passes/constant_propagation.cpp
@@ -18,7 +18,7 @@
     prim::If,
     prim::Loop,
     prim::Constant,
-    prim::Undefined,
+    prim::AutogradZero,
     prim::unchecked_unwrap_optional, // TODO remove
     // TODO (zach): we should consider skipping tensor factories in the cases
     // where the constant tensor would be large but cheap to create.
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index b217dce..123a4fe 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -136,8 +136,7 @@
   if (tensor->type()->isSubtypeOf(TensorType::get())) {
     return true;
   }
-  if (tensor->node()->mustBeNone() ||
-      tensor->node()->kind() == prim::Undefined) {
+  if (tensor->node()->mustBeNone()) {
     return false;
   }
   return {};
diff --git a/torch/csrc/jit/passes/lower_grad_of.cpp b/torch/csrc/jit/passes/lower_grad_of.cpp
index f64d66d..944bef0 100644
--- a/torch/csrc/jit/passes/lower_grad_of.cpp
+++ b/torch/csrc/jit/passes/lower_grad_of.cpp
@@ -9,9 +9,9 @@
       // if any_defined(inputs):
       //  outputs = <original_computation>
       // else:
-      //  outputs = undefineds
+      //  outputs = autograd zero tensors
       WithInsertPoint guard(*it);
-      auto cond = g.insertNode(g.create(prim::AnyDefined, it->inputs()))
+      auto cond = g.insertNode(g.create(prim::AutogradAnyNonZero, it->inputs()))
                       ->output()
                       ->setType(IntType::get());
       auto if_stat =
@@ -19,7 +19,7 @@
       if_stat->addBlock()->cloneFrom(
           it->blocks().at(0), [](Value* v) { return v; });
       auto else_block = if_stat->addBlock();
-      auto undef = g.createUndefined()
+      auto undef = g.createAutogradZero()
                        ->insertBefore(else_block->return_node())
                        ->output();
       for (size_t i = 0; i < it->outputs().size(); ++i) {
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 43291ac..22061fd 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -239,16 +239,6 @@
   // The inductive step is that the right-most input should be produced by the
  // node immediately before the current node if it is in tree order.
 
-  bool isConstantLike(Node* n) {
-    switch (n->kind()) {
-      case prim::Constant:
-      case prim::Undefined:
-        return true;
-      default:
-        return false;
-    }
-  }
-
   bool canInline(Value* v) {
     Node* n = v->node();
     // there must be only 1 values, otherwise we need an assignment to handle
@@ -280,7 +270,7 @@
   // block_point's output.
   Node* scanValue(Node* block_point, Value* v) {
     Node* n = v->node();
-    AT_ASSERT(isConstantLike(n) || output_inline_.count(n) == 0);
+    AT_ASSERT(n->kind() == prim::Constant || output_inline_.count(n) == 0);
 
     if (n == block_point &&
         canInline(v)) { // the node must be at the expected point of the typical
@@ -288,7 +278,7 @@
       // recursively see if we can inline the inputs to this input
       block_point = scanNode(block_point);
       output_inline_.insert(n);
-    } else if (isConstantLike(n)) {
+    } else if (n->kind() == prim::Constant) {
       // constant nodes can always be inlined, we will de-dup them on parsing
       // and put them at the top of the function regardless
       output_inline_.insert(n);
@@ -298,7 +288,7 @@
   Node* previousNonConstant(Node* n) {
     do {
       n = n->prev();
-    } while (isConstantLike(n));
+    } while (n->kind() == prim::Constant);
     return n;
   }
 
@@ -343,7 +333,7 @@
   std::unordered_set<Node*> seen_constants;
   void buildConstantList(Node* n, std::vector<Node*>& constants) {
     for (auto input : n->inputs()) {
-      if (isConstantLike(input->node()) &&
+      if (input->node()->kind() == prim::Constant &&
           seen_constants.count(input->node()) == 0) {
         constants.push_back(input->node());
         seen_constants.insert(input->node());
@@ -602,7 +592,7 @@
   }
 
   bool isNonConstantInline(Value* input) {
-    return !isConstantLike(input->node()) &&
+    return input->node()->kind() != prim::Constant &&
         output_inline_.count(input->node());
   }
 
@@ -649,7 +639,7 @@
   }
 
   void printNode(Node* node, bool print_const) {
-    if (!print_const && isConstantLike(node))
+    if (!print_const && node->kind() == prim::Constant)
       return;
     if (node->kind() == prim::PythonOp) {
       auto value = static_cast<const PythonOp*>(node);
@@ -701,7 +691,7 @@
         // it is not safe to do the same thing for non-constants here
         // because of [reordering of inlines]
         if (output_inline_.count(node) == 0 ||
-            (isConstantLike(node) && isLongLine(ss.str()))) {
+            (node->kind() == prim::Constant && isLongLine(ss.str()))) {
           printOutputDefinition(node, ss.str());
         } else {
           // this node is safe to inline, so assign the output value
@@ -801,8 +791,7 @@
         value->writeScalars(stmt);
         printValueList(stmt, node->inputs(), "(", ")");
       } break;
-      case prim::Constant:
-      case prim::Undefined: {
+      case prim::Constant: {
         if (node->kind() == prim::Constant && !node->mustBeNone()) {
           IValue v = toIValue(node->output()).value();
           printConstant(stmt, v);
@@ -1136,7 +1125,6 @@
       prim::DictIndex,
       prim::TupleSlice,
       prim::TupleUnpack,
-      prim::Undefined,
       prim::CreateObject,
       prim::GetAttr,
       prim::SetAttr,
@@ -1149,7 +1137,8 @@
   const static std::unordered_set<Symbol> unneeded = {
       c10::onnx::Reshape, // only used in onnx
       c10::onnx::Shape, // only used in onnx
-      prim::AnyDefined, // temporarily inserted by autograd
+      prim::AutogradZero, // temporarily inserted by autograd
+      prim::AutogradAnyNonZero, // temporarily inserted by autograd
       prim::AutogradAdd, // temporarily inserted by autograd
       prim::ConstantChunk, // optimization pass adds it
       prim::DifferentiableGraph, // optimization pass adds it
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index bc97018..0b4eb39 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -499,7 +499,7 @@
         }
         return;
       }
-      case prim::Undefined: {
+      case prim::AutogradZero: {
         setUnshapedType(node);
         return;
       }
diff --git a/torch/csrc/jit/passes/specialize_undef.cpp b/torch/csrc/jit/passes/specialize_autogradzero.cpp
similarity index 65%
rename from torch/csrc/jit/passes/specialize_undef.cpp
rename to torch/csrc/jit/passes/specialize_autogradzero.cpp
index 07ffa48..06f252d 100644
--- a/torch/csrc/jit/passes/specialize_undef.cpp
+++ b/torch/csrc/jit/passes/specialize_autogradzero.cpp
@@ -1,25 +1,25 @@
-#include <torch/csrc/jit/passes/specialize_undef.h>
+#include <torch/csrc/jit/passes/specialize_autogradzero.h>
 #include <torch/csrc/jit/symbolic_variable.h>
 
 namespace torch {
 namespace jit {
 
-// propagate undefined information through a gradient graph and
+// propagate autograd zero information through a gradient graph and
 // remove grad_of blocks if present.
-// Note: this is a very limited pass. It only propagates undefines for
+// Note: this is a very limited pass. It only propagates autograd zeros for
 // operations generated by the symbolic autodiff code and cleans up
 // AutogradAdds when possible. Outputs of other nodes are conservatively
 // marked Unknown and not optimized.
-void specializeUndef(Graph& g) {
-  enum class State { Defined, Undefined, Unknown };
+void specializeAutogradZero(Graph& g) {
+  enum class State { Nonzero, Zero, Unknown };
   std::unordered_map<Value*, State> state;
 
   for (Value* input : g.inputs()) {
     const auto& tp = input->type();
-    if (tp->isSubtypeOf(UndefinedTensorType::get())) {
-      state[input] = State::Undefined;
+    if (tp->isSubtypeOf(AutogradZeroTensorType::get())) {
+      state[input] = State::Zero;
     } else if (tp->isSubtypeOf(TensorType::get())) {
-      state[input] = State::Defined;
+      state[input] = State::Nonzero;
     } else {
       state[input] = State::Unknown;
     }
@@ -29,29 +29,29 @@
     auto n = *it;
     switch (n->kind()) {
       case prim::GradOf: {
-        auto all_undefined =
+        auto all_zeros =
             std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
-              return state[v] == State::Undefined;
+              return state[v] == State::Zero;
             });
-        // Property 1: if all the gradInputs to the GradOf are undefined
+        // Property 1: if all the gradInputs to the GradOf are Zero
         // then the gradOutputs are also zero and will be represented as
-        // undefined nodes
-        if (all_undefined) {
-          auto undef = g.createUndefined()->insertAfter(n)->output();
+        // AutogradZero nodes
+        if (all_zeros) {
+          auto zero = g.createAutogradZero()->insertAfter(n)->output();
           for (auto o : n->outputs()) {
-            o->replaceAllUsesWith(undef);
+            o->replaceAllUsesWith(zero);
           }
         } else {
           // Property 2: GradOfs are required to correctly handle combinations
-          // of defined and undefined inputs. They are expected to produce
-          // defined output tensors in this case.
+          // of Nonzero and zero inputs. They are expected to produce
+          // Nonzero output tensors in this case.
 
           // Remove the GradOf, splicing its body back into the surrounding
           // block
           auto body = n->blocks().at(0);
           for (auto input : n->inputs()) {
             // we should never get into a situation when specializing a GradOf
-            // where we do not know if a value is defined since at the top level
+            // where we do not know if a value is Nonzero since at the top level
             // a gradient graph is composed of Linear nodes and AutogradAdds
             // and LinearNodes only appear in these graphs
             AT_ASSERT(state[input] != State::Unknown);
@@ -70,32 +70,32 @@
       case prim::AutogradAdd: {
         auto a = n->input(0);
         auto b = n->input(1);
-        // if one is undefined, we can just drop the add
-        if (state[a] == State::Undefined) {
-          // Undef + b == b
+        // if one is Autograd zero, we can just drop the add
+        if (state[a] == State::Zero) {
+          // Zero + b == b
           n->output()->replaceAllUsesWith(b);
           it.destroyCurrent();
-        } else if (state[b] == State::Undefined) {
-          // a + Undef == a
+        } else if (state[b] == State::Zero) {
+          // a + Zero == a
           n->output()->replaceAllUsesWith(a);
           it.destroyCurrent();
-        } else if (state[a] == State::Defined && state[b] == State::Defined) {
-          // when both are defined, we can use a normal, optimizable add
+        } else if (state[a] == State::Nonzero && state[b] == State::Nonzero) {
+          // when both are Nonzero, we can use a normal, optimizable add
           // instruction
           WithInsertPoint guard(n);
           Value* new_add = toVar(a) + toVar(b);
-          state[new_add] = State::Defined;
+          state[new_add] = State::Nonzero;
           n->output()->replaceAllUsesWith(new_add);
           it.destroyCurrent();
         } else {
-          // otherwise we have conditionally-defined things, and we need
-          // to actually run an AutogradAdd which will guard for undefs
+          // otherwise we have conditionally-Nonzero things, and we need
+          // to actually run an AutogradAdd which will guard for Zeros
           // so we leave the op as is
           state[n->output()] = State::Unknown;
         }
       } break;
-      case prim::Undefined: {
-        state[n->output()] = State::Undefined;
+      case prim::AutogradZero: {
+        state[n->output()] = State::Zero;
       } break;
       default:
         for (auto o : n->outputs()) {
diff --git a/torch/csrc/jit/passes/specialize_undef.h b/torch/csrc/jit/passes/specialize_autogradzero.h
similarity index 63%
rename from torch/csrc/jit/passes/specialize_undef.h
rename to torch/csrc/jit/passes/specialize_autogradzero.h
index f829570..dfc5cfb 100644
--- a/torch/csrc/jit/passes/specialize_undef.h
+++ b/torch/csrc/jit/passes/specialize_autogradzero.h
@@ -5,13 +5,13 @@
 namespace torch {
 namespace jit {
 
-// propagate undefined information through a gradient graph and
+// propagate autograd zero information through a gradient graph and
 // remove grad_of blocks if present.
-// Note: this is a very limited pass. It only propagates undefines for
+// Note: this is a very limited pass. It only propagates autograd zeros for
 // operations generated by the symbolic autodiff code and cleans up
 // AutogradAdds when possible. Outputs of other nodes are conservatively
 // marked Unknown and not optimized.
-TORCH_API void specializeUndef(Graph& g);
+TORCH_API void specializeAutogradZero(Graph& g);
 
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index badaa02..f5a08cb 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -134,8 +134,8 @@
     c10::optional<int32_t> N) {
   switch (type->kind()) {
     case TypeKind::TensorType:
+    case TypeKind::AutogradZeroTensorType:
     case TypeKind::DimensionedTensorType:
-    case TypeKind::UndefinedTensorType:
     case TypeKind::CompleteTensorType: {
       auto var = py::cast<autograd::Variable>(obj);
       if (var.is_sparse()) {
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 526e008..257e09a 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -374,7 +374,7 @@
            return 0;
          }),
      Operator(
-         "prim::Undefined() -> Tensor",
+         "prim::AutogradZero() -> Tensor",
          [](const Node* node) {
            return [](Stack& stack) {
              stack.emplace_back(at::Tensor());
@@ -458,7 +458,6 @@
              return 0;
            };
          }),
-
      Operator(
          "prim::RaiseException(str msg) -> ()",
          [](Stack& stack) {
@@ -526,25 +525,23 @@
              return 0;
            };
          }),
-
-     Operator(
-         prim::AnyDefined,
-         [](const Node* node) {
-           size_t num_inputs = node->inputs().size();
-           return [=](Stack& stack) {
-             bool result = false;
-             for (const IValue& t : last(stack, num_inputs)) {
-               if (t.toTensor().defined()) {
-                 result = true;
-                 break;
-               }
-             }
-             drop(stack, num_inputs);
-             stack.emplace_back(result);
-             return 0;
-           };
-         }),
-
+    Operator(
+        prim::AutogradAnyNonZero,
+        [](const Node* node) {
+          size_t num_inputs = node->inputs().size();
+          return [=](Stack& stack) {
+            bool result = false;
+            for (const IValue& t : last(stack, num_inputs)) {
+              if (t.toTensor().defined()) {
+                result = true;
+                break;
+              }
+            }
+            drop(stack, num_inputs);
+            stack.emplace_back(result);
+            return 0;
+          };
+        }),
      Operator(
          prim::AutogradAdd,
          [](const Node* node) {
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index fe20db9..bfa3b96 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -28,9 +28,9 @@
 #   transparently dispatched to their non inplace versions in
 #   'run_symbolic_function'.   See Note [Export inplace]
 #
-# ---------------------------------------------------------------------
+# ----------------------------------------------------------------------------------
 # A note on Tensor types
-# ---------------------------------------------------------------------
+# ----------------------------------------------------------------------------------
 #
 # In general, we should avoid depending on the type of Tensor Values contained
 # within the trace graph. However, this is sometimes unavoidable (due to ONNX
@@ -41,17 +41,16 @@
 # TensorType - This is a Tensor, but we don't know anything about its
 #               properties (e.g. scalar type, # dims, shapes).
 #               Appears as `Tensor` in graph print-outs.
-# UndefinedTensorType <: TensorType - Denotes an undefined Tensor
 # DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
 #                             type and number of dimensions, but not the concrete
 #                             shapes. For example, appears as 'Float(*, *)' in
 #                             graph print-outs. Useful accessor methods include
 #                             dim() and scalarType()
-# CompleteTensorType <: TensorType - Denotes a Tensor for which we know the
-#                                    concrete sizes in addition to the information
-#                                    contained in TensorTyper. This adds a sizes()
-#                                    method which can be used to retrieve the
-#                                    concrete sizes.
+# CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we know the
+#                                               concrete sizes in addition to the information
+#                                               contained in TensorType. This adds a sizes()
+#                                               method which can be used to retrieve the
+#                                               concrete sizes.
 #
 # In general, we should prefer to rely on the least specific information possible.
 # For example, not relying on tensor properties at all is better than relying
@@ -59,9 +58,9 @@
 # concrete shapes (CompleteTensorType). Doing so will make the export symbolics
 # more robust to different graphs.
 
-# ---------------------------------------------------------------------
+# ---------------------------------------------------------------------------------
 # Helper functions
-# ---------------------------------------------------------------------
+# ---------------------------------------------------------------------------------
 
 # Save some builtins as locals, because we'll shadow them below
 _sum = sum