Add an Undefined node for null arguments to tensors.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
diff --git a/test/expect/TestJit.test_conv.expect b/test/expect/TestJit.test_conv.expect
index 2207e2b..3ed2e85 100644
--- a/test/expect/TestJit.test_conv.expect
+++ b/test/expect/TestJit.test_conv.expect
@@ -1,6 +1,6 @@
graph(%1 : Double(13, 16, 3, 3)
%2 : Double(20, 16, 50, 40)) {
- %4 : UNKNOWN_TYPE = Constant[Value=<Tensor>](), uses = [%3.i2];
+ %4 : UNKNOWN_TYPE = Undefined(), uses = [%3.i2];
%3.0 : Double(20, 13, 48, 38), %3.1 : Handle = CppOp[ConvForward](%2, %1, %4), uses = [[%0.i0], []];
return (%3.0);
}
diff --git a/torch/csrc/autograd/functions/toffee/convolution.cpp b/torch/csrc/autograd/functions/toffee/convolution.cpp
index 71da56e..b60010e 100644
--- a/torch/csrc/autograd/functions/toffee/convolution.cpp
+++ b/torch/csrc/autograd/functions/toffee/convolution.cpp
@@ -14,11 +14,10 @@
jit::node_list ConvForward::primspec(PrimSpecContext* ctx, jit::node_list inputs) {
auto & g = ctx->graph;
- auto n = g->appendNode(g->create(!transposed ? jit::kConv : jit::kConvTranspose, {inputs.at(0),inputs.at(1)}));
+ auto n = g->appendNode(g->create(!transposed ? jit::kConv : jit::kConvTranspose,
+ {inputs.at(0), inputs.at(1)}));
- // TODO: Factor this logic into a helper, and make sure it gets applied
- // consistently. See also batch_normalization.cpp
- if (inputs.at(2)->kind() != jit::kConstant || inputs.at(2)->t(jit::kValue).defined()) {
+ if (inputs.at(2)->kind() != jit::kUndefined) {
n->addInput(inputs.at(2));
}
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index d8bb365..b932a98 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -21,6 +21,7 @@
_(Sigmoid) \
_(Tanh) \
_(Constant) \
+_(Undefined) \
_(FusionGroup) \
_(Chunk) \
_(NumChunks) \
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 26c96bc..e001f96 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -750,9 +750,7 @@
return n;
}
Node * createUndefined() {
- auto n = create(kConstant);
- n->t_(kValue,at::Tensor());
- return n;
+ return create(kUndefined);
}
Node * createConstant(const at::Tensor& ref) {
JIT_ASSERT(ref.defined());
diff --git a/torch/csrc/toffee/export.cpp b/torch/csrc/toffee/export.cpp
index 1b76b27..068d147 100644
--- a/torch/csrc/toffee/export.cpp
+++ b/torch/csrc/toffee/export.cpp
@@ -137,7 +137,7 @@
}
setOutputs(node, outputs);
IR_ELSE()
- if(node->kind() == kConstant && node->t(kValue).defined()) {
+ if(node->kind() == kConstant) {
throw std::runtime_error("Constant not supported yet");
}
auto n_ = ctx.graph->createClone(node, envFn);
@@ -238,7 +238,7 @@
// of the select invariant
continue;
}
- if (node->kind() == kConstant && !node->t(kValue).defined() && node->uses().empty()) {
+ if (node->kind() == kUndefined && node->uses().empty()) {
// Undefined nodes never show up in ToffeeIR; they're just a tool
// to help primspecs do the right thing.
continue;