Rename prim::Undefined to prim::AutogradZero (#17611)
Summary:
Renames the JIT's prim::Undefined node to prim::AutogradZero and prim::AnyDefined to prim::AutogradAnyNonZero, with the matching renames of UndefinedTensorType to AutogradZeroTensorType, Graph::createUndefined() to createAutogradZero(), and the specialize_undef pass to specialize_autogradzero, so the names reflect what these constructs actually represent: autograd's zero gradients rather than generic undefined tensors.
Supersedes #17245
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17611
Differential Revision: D14283581
Pulled By: wanchaol
fbshipit-source-id: 8022d02b8a021ea2fee9a18a2c8920eb123200c5
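For reviewers, a minimal usage sketch of the renamed entry points (illustrative only, not part of this patch; it assumes the usual JIT headers and the torch::jit namespace, and the helper name buildAndSpecialize is made up):

    #include <torch/csrc/jit/ir.h>
    #include <torch/csrc/jit/passes/specialize_autogradzero.h>

    void buildAndSpecialize(const std::shared_ptr<torch::jit::Graph>& g) {
      using namespace torch::jit;
      // Previously g->createUndefined(); the node kind is now prim::AutogradZero.
      Node* zero = g->insertNode(g->createAutogradZero());
      (void)zero;
      // Previously specializeUndef(*g); same pass, clearer name.
      specializeAutogradZero(*g);
    }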
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index d58fdc1..2bd28fb 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -24,7 +24,6 @@
_(prim, BroadcastSizes) \
_(prim, Constant) \
_(prim, ChunkSizes) \
- _(prim, None) \
_(prim, Drop) \
_(prim, Eval) \
_(prim, Expand) /* onnx */ \
@@ -46,7 +45,8 @@
_(prim, Reverse) \
_(prim, Return) \
_(prim, Store) \
- _(prim, Undefined) \
+ _(prim, AutogradZero) \
+ _(prim, AutogradAnyNonZero) \
_(prim, Starred) \
_(prim, TupleConstruct) \
_(prim, TupleUnpack) \
@@ -67,7 +67,6 @@
_(prim, requires_grad) \
_(prim, AutogradAdd) \
_(prim, GradOf) \
- _(prim, AnyDefined) \
_(prim, FusedConcat) \
_(prim, ConstantChunk) \
_(prim, MMTreeReduce) \
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 066fb13..3759566 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -29,7 +29,7 @@
_(TensorType) \
_(DimensionedTensorType) \
_(CompleteTensorType) \
-_(UndefinedTensorType) \
+_(AutogradZeroTensorType) \
_(TupleType) \
_(ListType) \
_(DictType) \
@@ -255,9 +255,9 @@
struct TensorType;
using TensorTypePtr = std::shared_ptr<TensorType>;
// This type represents a single Tensor, with an unknown shape.
-// Subtype hierarchy for Tensor Types (DynamicType as the base type):
-// CompleteTensorType <: TensorType <: DynamicType
-// UndefinedTensorType <: DynamicType
+// Subtype hierarchy for Tensor Types (TensorType as the base type):
+// CompleteTensorType <: DimensionedTensorType <: TensorType
+// AutogradZeroTensorType <: TensorType
struct CAFFE2_API TensorType : public Type {
static TensorTypePtr create() {
return TensorTypePtr(new TensorType()); // NOLINT(modernize-make-shared)
@@ -280,15 +280,15 @@
: Type(kind) {}
};
-struct UndefinedTensorType;
-using UndefinedTensorTypePtr = std::shared_ptr<UndefinedTensorType>;
+struct AutogradZeroTensorType;
+using AutogradZeroTensorTypePtr = std::shared_ptr<AutogradZeroTensorType>;
// This type represents an undefined tensor.
-struct CAFFE2_API UndefinedTensorType : public TensorType {
- static UndefinedTensorTypePtr create() {
- return UndefinedTensorTypePtr(new UndefinedTensorType()); // NOLINT(modernize-make-shared)
+struct CAFFE2_API AutogradZeroTensorType : public TensorType {
+ static AutogradZeroTensorTypePtr create() {
+ return AutogradZeroTensorTypePtr(new AutogradZeroTensorType()); // NOLINT(modernize-make-shared)
}
- DEFINE_IS_SUBCLASS(UndefinedTensorType);
+ DEFINE_IS_SUBCLASS(AutogradZeroTensorType);
bool requires_grad() const override { return false; }
@@ -297,18 +297,18 @@
}
bool isSubtypeOf(const TypePtr rhs) const override {
return rhs->kind() == TypeKind::TensorType ||
- rhs->kind() == TypeKind::UndefinedTensorType ||
+ rhs->kind() == TypeKind::AutogradZeroTensorType ||
TensorType::isSubtypeOf(rhs);
}
std::string str() const override {
return "UndefinedTensor";
}
- static const TypeKind Kind = TypeKind::UndefinedTensorType;
+ static const TypeKind Kind = TypeKind::AutogradZeroTensorType;
// global singleton
- static UndefinedTensorTypePtr get();
+ static AutogradZeroTensorTypePtr get();
protected:
- UndefinedTensorType(): TensorType(TypeKind::UndefinedTensorType) {}
+ AutogradZeroTensorType(): TensorType(TypeKind::AutogradZeroTensorType) {}
};
struct DimensionedTensorType;
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index b17b05a..e8892db 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -59,8 +59,8 @@
static auto value = TensorType::create();
return value;
}
-UndefinedTensorTypePtr UndefinedTensorType::get() {
- static auto value = UndefinedTensorType::create();
+AutogradZeroTensorTypePtr AutogradZeroTensorType::get() {
+ static auto value = AutogradZeroTensorType::create();
return value;
}
NumberTypePtr NumberType::get() {
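A quick sanity sketch of the renamed type's behavior (hedged: namespace qualifiers are elided, the surrounding scaffolding is assumed, and the assertions are illustrative rather than taken from this patch):

    #include <ATen/core/jit_type.h>

    // AutogradZeroTensorType is still a subtype of the plain TensorType,
    // exactly as UndefinedTensorType was before the rename.
    TypePtr zero_t = AutogradZeroTensorType::get();        // global singleton
    AT_ASSERT(zero_t->kind() == TypeKind::AutogradZeroTensorType);
    AT_ASSERT(zero_t->isSubtypeOf(TensorType::get()));     // per isSubtypeOf above
    AT_ASSERT(!AutogradZeroTensorType::get()->requires_grad()); // zero grads never require grad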
diff --git a/test/cpp/jit/test_alias_analysis.h b/test/cpp/jit/test_alias_analysis.h
index c1d2b57..1f5bfc4 100644
--- a/test/cpp/jit/test_alias_analysis.h
+++ b/test/cpp/jit/test_alias_analysis.h
@@ -51,7 +51,7 @@
for (const auto name : inputNames) {
inputs.push_back(nodes.at(name)->output());
}
- auto node = graph->appendNode(graph->create(prim::Undefined, inputs));
+ auto node = graph->appendNode(graph->create(prim::AutogradZero, inputs));
node->output()->setUniqueName(name);
nodes[name] = node;
@@ -63,7 +63,7 @@
}
auto block = node->blocks().at(0);
- block->appendNode(graph->create(prim::Undefined, blockDeps));
+ block->appendNode(graph->create(prim::AutogradZero, blockDeps));
}
}
diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h
index 36c22dc..e2ea958 100644
--- a/test/cpp/jit/test_misc.h
+++ b/test/cpp/jit/test_misc.h
@@ -1524,10 +1524,10 @@
void testTopologicalIndex() {
{
Graph graph;
- auto node1 = graph.create(prim::Undefined);
- auto node2 = graph.create(prim::Undefined);
- auto node3 = graph.create(prim::Undefined);
- auto node4 = graph.create(prim::Undefined);
+ auto node1 = graph.create(prim::AutogradZero);
+ auto node2 = graph.create(prim::AutogradZero);
+ auto node3 = graph.create(prim::AutogradZero);
+ auto node4 = graph.create(prim::AutogradZero);
graph.appendNode(node4);
graph.prependNode(node1);
@@ -1552,12 +1552,12 @@
// \ ...
// C block2
auto block1 = node3->addBlock();
- auto A = graph.create(prim::Undefined);
+ auto A = graph.create(prim::AutogradZero);
block1->appendNode(A);
- auto B = graph.create(prim::Undefined);
+ auto B = graph.create(prim::AutogradZero);
block1->appendNode(B);
auto block2 = B->addBlock();
- auto C = graph.create(prim::Undefined);
+ auto C = graph.create(prim::AutogradZero);
block2->appendNode(C);
// Check isAfter on different block levels
@@ -1567,7 +1567,7 @@
// make sure things don't blow up on deletions
node2->destroy();
- auto node2p = graph.create(prim::Undefined);
+ auto node2p = graph.create(prim::AutogradZero);
node2p->insertAfter(node1);
ASSERT_TRUE(node1->isBefore(node2p));
ASSERT_TRUE(node2p->isBefore(node3));
@@ -1577,11 +1577,11 @@
Graph graph;
std::map<size_t, Node*> nodes;
- auto anchor = graph.create(prim::Undefined);
+ auto anchor = graph.create(prim::AutogradZero);
graph.appendNode(anchor);
// Inserting to the same place a lot will trigger reindexing
for (auto i = 0; i < 100; ++i) {
- auto n = graph.create(prim::Undefined);
+ auto n = graph.create(prim::AutogradZero);
n->insertAfter(anchor);
nodes[i] = n;
}
diff --git a/tools/build_variables.py b/tools/build_variables.py
index 63953ab..ebf6a9e 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -85,7 +85,7 @@
"torch/csrc/jit/passes/remove_expands.cpp",
"torch/csrc/jit/passes/requires_grad_analysis.cpp",
"torch/csrc/jit/passes/shape_analysis.cpp",
- "torch/csrc/jit/passes/specialize_undef.cpp",
+ "torch/csrc/jit/passes/specialize_autogradzero.cpp",
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
"torch/csrc/jit/passes/utils/memory_dag.cpp",
"torch/csrc/jit/register_prim_ops.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 159e8bc..a06ac07 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -161,7 +161,7 @@
${TORCH_SRC_DIR}/csrc/jit/passes/remove_inplace_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/shape_analysis.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp
- ${TORCH_SRC_DIR}/csrc/jit/passes/specialize_undef.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/passes/specialize_autogradzero.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h
index 72155cb..ac959aa 100644
--- a/torch/csrc/jit/argument_spec.h
+++ b/torch/csrc/jit/argument_spec.h
@@ -156,7 +156,7 @@
if (original->isSubtypeOf(TensorType::get())) {
auto& arg = args.at(offset++);
if (!arg.defined())
- return UndefinedTensorType::get();
+ return AutogradZeroTensorType::get();
return DimensionedTensorType::create(
arg.type(),
ConvertIntToCPUOrCUDA(arg.device()),
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index a92f982..b1b1da2 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -122,7 +122,7 @@
// perf "aten::atan2(Tensor self) -> Tensor", "aten::max(Tensor self) ->
// Tensor", "aten::min(Tensor self) -> Tensor"
- if (n->kind() == prim::Constant || n->kind() == prim::Undefined ||
+ if (n->kind() == prim::Constant || n->kind() == prim::AutogradZero ||
n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk)
return true;
if (differentiable_ops.find(n))
@@ -718,7 +718,7 @@
node->matches(
"aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
auto graph = node->owningGraph();
- auto total_weight = graph->insertNode(graph->createUndefined());
+ auto total_weight = graph->insertNode(graph->createAutogradZero());
auto weight = graph->insertNode(graph->createNone(TensorType::get()));
auto backward_value = graph->insert(
aten::nll_loss_backward,
@@ -748,7 +748,7 @@
return {backward_value->node()->output(0), nullptr};
} else if (
- node->kind() == prim::Constant || node->kind() == prim::Undefined) {
+ node->kind() == prim::Constant || node->kind() == prim::AutogradZero) {
return {};
}
throw std::runtime_error(
@@ -834,7 +834,7 @@
const auto get_grad = [&](Value* v) -> Value* {
auto it = grad_map.find(v);
if (it == grad_map.end()) {
- auto undef = graph.insertNode(graph.createUndefined());
+ auto undef = graph.insertNode(graph.createAutogradZero());
std::tie(it, std::ignore) = grad_map.emplace(v, undef->output());
}
return it->second;
@@ -947,7 +947,7 @@
AT_ASSERT(
top_node->kind() == prim::GradOf ||
top_node->kind() == prim::AutogradAdd ||
- top_node->kind() == prim::Undefined);
+ top_node->kind() == prim::AutogradZero);
if (top_node->kind() != prim::GradOf)
continue;
Block* grad_body = top_node->blocks().at(0);
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 2332dd1..eaaf1d1 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -26,7 +26,7 @@
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
-#include <torch/csrc/jit/passes/specialize_undef.h>
+#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/symbolic_variable.h>
#include <torch/csrc/jit/tracer.h>
@@ -635,7 +635,7 @@
}
void runRequiredPasses(const std::shared_ptr<Graph>& g) {
- specializeUndef(*g);
+ specializeAutogradZero(*g);
LowerGradOf(*g);
// implicit inserted expand nodes are not necessarily always valid
// when used inside script methods that might have unstable shapes
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 4cdf4ad..ad0e0d9 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -29,7 +29,7 @@
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
-#include <torch/csrc/jit/passes/specialize_undef.h>
+#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/to_batch.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/pybind_utils.h>
@@ -188,7 +188,7 @@
.def("_jit_pass_onnx_block", BlockToONNX)
.def("_jit_pass_fixup_onnx_loops", FixupONNXLoops)
.def("_jit_pass_canonicalize_ops", CanonicalizeOps)
- .def("_jit_pass_specialize_undef", specializeUndef)
+ .def("_jit_pass_specialize_autogradzero", specializeAutogradZero)
.def("_jit_override_can_fuse_on_cpu", &overrideCanFuseOnCPU)
.def(
"_jit_differentiate",
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index 79c04e6..613da86 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -1194,8 +1194,8 @@
return n;
}
-Node* Graph::createUndefined() {
- return create(prim::Undefined);
+Node* Graph::createAutogradZero() {
+ return create(prim::AutogradZero);
}
Node* Graph::createNone(TypePtr typ) {
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 1cd0fc0..4ef6883 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -1041,7 +1041,7 @@
TORCH_API Node* createNone(
TypePtr typ); // value of None with type Optional[typ]
- TORCH_API Node* createUndefined();
+ TORCH_API Node* createAutogradZero();
TORCH_API Node* createFusionGroup();
TORCH_API Node* createDifferentiableSubgraph();
TORCH_API Node* createTuple(
diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp
index 3a9ac17..8ebec19 100644
--- a/torch/csrc/jit/passes/alias_analysis.cpp
+++ b/torch/csrc/jit/passes/alias_analysis.cpp
@@ -369,7 +369,7 @@
case prim::DictConstruct:
case prim::ListConstruct:
case prim::TupleConstruct:
- case prim::Undefined:
+ case prim::AutogradZero:
case prim::FusedConcat:
case prim::MMTreeReduce:
case prim::MMBatchSide:
@@ -1128,11 +1128,10 @@
prim::DictConstruct,
prim::ListConstruct,
prim::TupleConstruct,
- prim::Undefined,
+ prim::AutogradZero,
prim::FusedConcat,
prim::MMTreeReduce,
prim::MMBatchSide,
- prim::None,
prim::BroadcastSizes,
prim::ChunkSizes,
prim::Function,
@@ -1163,7 +1162,7 @@
prim::Drop,
at::onnx::Reshape,
at::onnx::Shape,
- prim::AnyDefined,
+ prim::AutogradAnyNonZero,
prim::AutogradAdd,
};
diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp
index 56a4db7..5b7652b 100644
--- a/torch/csrc/jit/passes/constant_propagation.cpp
+++ b/torch/csrc/jit/passes/constant_propagation.cpp
@@ -18,7 +18,7 @@
prim::If,
prim::Loop,
prim::Constant,
- prim::Undefined,
+ prim::AutogradZero,
prim::unchecked_unwrap_optional, // TODO remove
// TODO (zach): we should consider skipping tensor factories in the cases
// where the constant tensor would be large but cheap to create.
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index b217dce..123a4fe 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -136,8 +136,7 @@
if (tensor->type()->isSubtypeOf(TensorType::get())) {
return true;
}
- if (tensor->node()->mustBeNone() ||
- tensor->node()->kind() == prim::Undefined) {
+ if (tensor->node()->mustBeNone()) {
return false;
}
return {};
diff --git a/torch/csrc/jit/passes/lower_grad_of.cpp b/torch/csrc/jit/passes/lower_grad_of.cpp
index f64d66d..944bef0 100644
--- a/torch/csrc/jit/passes/lower_grad_of.cpp
+++ b/torch/csrc/jit/passes/lower_grad_of.cpp
@@ -9,9 +9,9 @@
// if any_defined(inputs):
// outputs = <original_computation>
// else:
- // outputs = undefineds
+ // outputs = autograd zero tensors
WithInsertPoint guard(*it);
- auto cond = g.insertNode(g.create(prim::AnyDefined, it->inputs()))
+ auto cond = g.insertNode(g.create(prim::AutogradAnyNonZero, it->inputs()))
->output()
->setType(IntType::get());
auto if_stat =
@@ -19,7 +19,7 @@
if_stat->addBlock()->cloneFrom(
it->blocks().at(0), [](Value* v) { return v; });
auto else_block = if_stat->addBlock();
- auto undef = g.createUndefined()
+ auto undef = g.createAutogradZero()
->insertBefore(else_block->return_node())
->output();
for (size_t i = 0; i < it->outputs().size(); ++i) {
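For context, the structure this pass emits now reads roughly as follows in the JIT's textual IR (a hedged sketch; value names and the single output are illustrative):

    %any : int = prim::AutogradAnyNonZero(%grad_output)
    %out : Tensor = prim::If(%any)
      block0():
        %g : Tensor = ... original gradient computation ...
        -> (%g)
      block1():
        %z : Tensor = prim::AutogradZero()
        -> (%z)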
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 43291ac..22061fd 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -239,16 +239,6 @@
// The inductive step is that the right-most input should be produced by the
// node immediatly before the current node if it is in tree order.
- bool isConstantLike(Node* n) {
- switch (n->kind()) {
- case prim::Constant:
- case prim::Undefined:
- return true;
- default:
- return false;
- }
- }
-
bool canInline(Value* v) {
Node* n = v->node();
// there must be only 1 values, otherwise we need an assignment to handle
@@ -280,7 +270,7 @@
// block_point's output.
Node* scanValue(Node* block_point, Value* v) {
Node* n = v->node();
- AT_ASSERT(isConstantLike(n) || output_inline_.count(n) == 0);
+ AT_ASSERT(n->kind() == prim::Constant || output_inline_.count(n) == 0);
if (n == block_point &&
canInline(v)) { // the node must be at the expected point of the typical
@@ -288,7 +278,7 @@
// recursively see if we can inline the inputs to this input
block_point = scanNode(block_point);
output_inline_.insert(n);
- } else if (isConstantLike(n)) {
+ } else if (n->kind() == prim::Constant) {
// constant nodes can always be inlined, we will de-dup them on parsing
// and put them at the top of the function regardless
output_inline_.insert(n);
@@ -298,7 +288,7 @@
Node* previousNonConstant(Node* n) {
do {
n = n->prev();
- } while (isConstantLike(n));
+ } while (n->kind() == prim::Constant);
return n;
}
@@ -343,7 +333,7 @@
std::unordered_set<Node*> seen_constants;
void buildConstantList(Node* n, std::vector<Node*>& constants) {
for (auto input : n->inputs()) {
- if (isConstantLike(input->node()) &&
+ if (input->node()->kind() == prim::Constant &&
seen_constants.count(input->node()) == 0) {
constants.push_back(input->node());
seen_constants.insert(input->node());
@@ -602,7 +592,7 @@
}
bool isNonConstantInline(Value* input) {
- return !isConstantLike(input->node()) &&
+ return input->node()->kind() != prim::Constant &&
output_inline_.count(input->node());
}
@@ -649,7 +639,7 @@
}
void printNode(Node* node, bool print_const) {
- if (!print_const && isConstantLike(node))
+ if (!print_const && node->kind() == prim::Constant)
return;
if (node->kind() == prim::PythonOp) {
auto value = static_cast<const PythonOp*>(node);
@@ -701,7 +691,7 @@
// it is not safe to do the same thing for non-constants here
// because of [reordering of inlines]
if (output_inline_.count(node) == 0 ||
- (isConstantLike(node) && isLongLine(ss.str()))) {
+ (node->kind() == prim::Constant && isLongLine(ss.str()))) {
printOutputDefinition(node, ss.str());
} else {
// this node is safe to inline, so assign the output value
@@ -801,8 +791,7 @@
value->writeScalars(stmt);
printValueList(stmt, node->inputs(), "(", ")");
} break;
- case prim::Constant:
- case prim::Undefined: {
+ case prim::Constant: {
if (node->kind() == prim::Constant && !node->mustBeNone()) {
IValue v = toIValue(node->output()).value();
printConstant(stmt, v);
@@ -1136,7 +1125,6 @@
prim::DictIndex,
prim::TupleSlice,
prim::TupleUnpack,
- prim::Undefined,
prim::CreateObject,
prim::GetAttr,
prim::SetAttr,
@@ -1149,7 +1137,8 @@
const static std::unordered_set<Symbol> unneeded = {
c10::onnx::Reshape, // only used in onnx
c10::onnx::Shape, // only used in onnx
- prim::AnyDefined, // temporarily inserted by autograd
+ prim::AutogradZero, // temporarily inserted by autograd
+ prim::AutogradAnyNonZero, // temporarily inserted by autograd
prim::AutogradAdd, // temporarily inserted by autograd
prim::ConstantChunk, // optimization pass adds it
prim::DifferentiableGraph, // optimization pass adds it
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index bc97018..0b4eb39 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -499,7 +499,7 @@
}
return;
}
- case prim::Undefined: {
+ case prim::AutogradZero: {
setUnshapedType(node);
return;
}
diff --git a/torch/csrc/jit/passes/specialize_undef.cpp b/torch/csrc/jit/passes/specialize_autogradzero.cpp
similarity index 65%
rename from torch/csrc/jit/passes/specialize_undef.cpp
rename to torch/csrc/jit/passes/specialize_autogradzero.cpp
index 07ffa48..06f252d 100644
--- a/torch/csrc/jit/passes/specialize_undef.cpp
+++ b/torch/csrc/jit/passes/specialize_autogradzero.cpp
@@ -1,25 +1,25 @@
-#include <torch/csrc/jit/passes/specialize_undef.h>
+#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/symbolic_variable.h>
namespace torch {
namespace jit {
-// propagate undefined information through a gradient graph and
+// propagate autograd zero information through a gradient graph and
// remove grad_of blocks if present.
-// Note: this is a very limited pass. It only propagates undefines for
+// Note: this is a very limited pass. It only propagates autograd zeros for
// operations generated by the symbolic autodiff code and cleans up
// AutogradAdds when possible. Outputs of other nodes are conservatively
// marked Unknown and not optimized.
-void specializeUndef(Graph& g) {
- enum class State { Defined, Undefined, Unknown };
+void specializeAutogradZero(Graph& g) {
+ enum class State { Nonzero, Zero, Unknown };
std::unordered_map<Value*, State> state;
for (Value* input : g.inputs()) {
const auto& tp = input->type();
- if (tp->isSubtypeOf(UndefinedTensorType::get())) {
- state[input] = State::Undefined;
+ if (tp->isSubtypeOf(AutogradZeroTensorType::get())) {
+ state[input] = State::Zero;
} else if (tp->isSubtypeOf(TensorType::get())) {
- state[input] = State::Defined;
+ state[input] = State::Nonzero;
} else {
state[input] = State::Unknown;
}
@@ -29,29 +29,29 @@
auto n = *it;
switch (n->kind()) {
case prim::GradOf: {
- auto all_undefined =
+ auto all_zeros =
std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
- return state[v] == State::Undefined;
+ return state[v] == State::Zero;
});
- // Property 1: if all the gradInputs to the GradOf are undefined
+ // Property 1: if all the gradInputs to the GradOf are Zero
// then the gradOutputs are also zero and will be represented as
- // undefined nodes
- if (all_undefined) {
- auto undef = g.createUndefined()->insertAfter(n)->output();
+ // AutogradZero nodes
+ if (all_zeros) {
+ auto zero = g.createAutogradZero()->insertAfter(n)->output();
for (auto o : n->outputs()) {
- o->replaceAllUsesWith(undef);
+ o->replaceAllUsesWith(zero);
}
} else {
// Property 2: GradOfs are required to correctly handle combinations
- // of defined and undefined inputs. They are expected to produce
- // defined output tensors in this case.
+ // of Nonzero and Zero inputs. They are expected to produce
+ // Nonzero output tensors in this case.
// Remove the GradOf, splicing its body back into the surrounding
// block
auto body = n->blocks().at(0);
for (auto input : n->inputs()) {
// we should never get into a situation when specializing a GradOf
- // where we do not know if a value is defined since at the top level
+ // where we do not know if a value is Nonzero since at the top level
// a gradient graph is composed of Linear nodes and AutogradAdds
// and LinearNodes only appear in these graphs
AT_ASSERT(state[input] != State::Unknown);
@@ -70,32 +70,32 @@
case prim::AutogradAdd: {
auto a = n->input(0);
auto b = n->input(1);
- // if one is undefined, we can just drop the add
- if (state[a] == State::Undefined) {
- // Undef + b == b
+ // if one input is Zero, we can just drop the add
+ if (state[a] == State::Zero) {
+ // Zero + b == b
n->output()->replaceAllUsesWith(b);
it.destroyCurrent();
- } else if (state[b] == State::Undefined) {
- // a + Undef == a
+ } else if (state[b] == State::Zero) {
+ // a + Zero == a
n->output()->replaceAllUsesWith(a);
it.destroyCurrent();
- } else if (state[a] == State::Defined && state[b] == State::Defined) {
- // when both are defined, we can use a normal, optimizable add
+ } else if (state[a] == State::Nonzero && state[b] == State::Nonzero) {
+ // when both are Nonzero, we can use a normal, optimizable add
// instruction
WithInsertPoint guard(n);
Value* new_add = toVar(a) + toVar(b);
- state[new_add] = State::Defined;
+ state[new_add] = State::Nonzero;
n->output()->replaceAllUsesWith(new_add);
it.destroyCurrent();
} else {
- // otherwise we have conditionally-defined things, and we need
- // to actually run an AutogradAdd which will guard for undefs
+ // otherwise we have conditionally-Nonzero things, and we need
+ // to actually run an AutogradAdd which will guard for Zeros
// so we leave the op as is
state[n->output()] = State::Unknown;
}
} break;
- case prim::Undefined: {
- state[n->output()] = State::Undefined;
+ case prim::AutogradZero: {
+ state[n->output()] = State::Zero;
} break;
default:
for (auto o : n->outputs()) {
diff --git a/torch/csrc/jit/passes/specialize_undef.h b/torch/csrc/jit/passes/specialize_autogradzero.h
similarity index 63%
rename from torch/csrc/jit/passes/specialize_undef.h
rename to torch/csrc/jit/passes/specialize_autogradzero.h
index f829570..dfc5cfb 100644
--- a/torch/csrc/jit/passes/specialize_undef.h
+++ b/torch/csrc/jit/passes/specialize_autogradzero.h
@@ -5,13 +5,13 @@
namespace torch {
namespace jit {
-// propagate undefined information through a gradient graph and
+// propagate autograd zero information through a gradient graph and
// remove grad_of blocks if present.
-// Note: this is a very limited pass. It only propagates undefines for
+// Note: this is a very limited pass. It only propagates autograd zeros for
// operations generated by the symbolic autodiff code and cleans up
// AutogradAdds when possible. Outputs of other nodes are conservatively
// marked Unknown and not optimized.
-TORCH_API void specializeUndef(Graph& g);
+TORCH_API void specializeAutogradZero(Graph& g);
} // namespace jit
} // namespace torch
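To make the new state names concrete, a hedged sketch of the simplification specializeAutogradZero performs when one AutogradAdd operand is statically known to be zero (namespace qualifiers and setup are elided; the graph construction here is illustrative, not part of this patch):

    auto g = std::make_shared<Graph>();
    Value* gz = g->addInput()->setType(AutogradZeroTensorType::get()); // State::Zero
    Value* x  = g->addInput()->setType(TensorType::get());             // State::Nonzero
    Node* add = g->insertNode(g->create(prim::AutogradAdd, {gz, x}));
    g->registerOutput(add->output());

    specializeAutogradZero(*g);
    // The AutogradAdd is destroyed and its use is rewired straight to x
    // (the "Zero + b == b" case above), so the graph now simply returns
    // its second input.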
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index badaa02..f5a08cb 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -134,8 +134,8 @@
c10::optional<int32_t> N) {
switch (type->kind()) {
case TypeKind::TensorType:
+ case TypeKind::AutogradZeroTensorType:
case TypeKind::DimensionedTensorType:
- case TypeKind::UndefinedTensorType:
case TypeKind::CompleteTensorType: {
auto var = py::cast<autograd::Variable>(obj);
if (var.is_sparse()) {
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 526e008..257e09a 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -374,7 +374,7 @@
return 0;
}),
Operator(
- "prim::Undefined() -> Tensor",
+ "prim::AutogradZero() -> Tensor",
[](const Node* node) {
return [](Stack& stack) {
stack.emplace_back(at::Tensor());
@@ -458,7 +458,6 @@
return 0;
};
}),
-
Operator(
"prim::RaiseException(str msg) -> ()",
[](Stack& stack) {
@@ -526,25 +525,23 @@
return 0;
};
}),
-
- Operator(
- prim::AnyDefined,
- [](const Node* node) {
- size_t num_inputs = node->inputs().size();
- return [=](Stack& stack) {
- bool result = false;
- for (const IValue& t : last(stack, num_inputs)) {
- if (t.toTensor().defined()) {
- result = true;
- break;
- }
- }
- drop(stack, num_inputs);
- stack.emplace_back(result);
- return 0;
- };
- }),
-
+ Operator(
+ prim::AutogradAnyNonZero,
+ [](const Node* node) {
+ size_t num_inputs = node->inputs().size();
+ return [=](Stack& stack) {
+ bool result = false;
+ for (const IValue& t : last(stack, num_inputs)) {
+ if (t.toTensor().defined()) {
+ result = true;
+ break;
+ }
+ }
+ drop(stack, num_inputs);
+ stack.emplace_back(result);
+ return 0;
+ };
+ }),
Operator(
prim::AutogradAdd,
[](const Node* node) {
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index fe20db9..bfa3b96 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -28,9 +28,9 @@
# transparently dispatched to their non inplace versions in
# 'run_symbolic_function'. See Note [Export inplace]
#
-# ---------------------------------------------------------------------
+# ----------------------------------------------------------------------------------
# A note on Tensor types
-# ---------------------------------------------------------------------
+# ----------------------------------------------------------------------------------
#
# In general, we should avoid depending on the type of Tensor Values contained
# within the trace graph. However, this is sometimes unavoidable (due to ONNX
@@ -41,17 +41,16 @@
# TensorType - This is a Tensor, but we don't know anything about its
# properties (e.g. scalar type, # dims, shapes).
# Appears as `Tensor` in graph print-outs.
-# UndefinedTensorType <: TensorType - Denotes an undefined Tensor
# DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
# type and number of dimensions, but not the concrete
# shapes. For example, appears as 'Float(*, *)' in
# graph print-outs. Useful accessor methods include
# dim() and scalarType()
-# CompleteTensorType <: TensorType - Denotes a Tensor for which we know the
-# concrete sizes in addition to the information
-# contained in TensorTyper. This adds a sizes()
-# method which can be used to retrieve the
-# concrete sizes.
+# CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we know the
+# concrete sizes in addition to the information
+# contained in DimensionedTensorType. This adds a sizes()
+# method which can be used to retrieve the
+# concrete sizes.
#
# In general, we should prefer to rely on the least specific information possible.
# For example, not relying on tensor properties at all is better than relying
@@ -59,9 +58,9 @@
# concrete shapes (CompleteTensorType). Doing so will make the export symbolics
# more robust to different graphs.
-# ---------------------------------------------------------------------
+# ---------------------------------------------------------------------------------
# Helper functions
-# ---------------------------------------------------------------------
+# ---------------------------------------------------------------------------------
# Save some builtins as locals, because we'll shadown them below
_sum = sum