Add prim::TypeCheck operation (#43026)
Summary:
TypeCheck is a new operation to check the shape of tensors against
expectd shapes. TypeCheck is a variadic operation. An example,
%t0 : Tensor = ...
%t1 : Tensor = ...
%2 : FLOAT(20, 20), %3 : FLOAT(30, 30), %1 : bool =
prim::TypeCheck(%t1, %t2)
prim::If(%1)
Fixes #{issue number}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43026
Reviewed By: ZolotukhinM
Differential Revision: D23115830
Pulled By: bzinodev
fbshipit-source-id: fbf142126002173d2d865cf4b932dea3864466b4
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 185c18c..3741dcc 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -94,6 +94,7 @@
_(aten, backward) \
_(prim, Guard) \
_(prim, BailOut) \
+ _(prim, TypeCheck) \
_(prim, FusedConcat) \
_(prim, ConstantChunk) \
_(prim, MMTreeReduce) \
diff --git a/test/cpp/jit/test_interpreter.cpp b/test/cpp/jit/test_interpreter.cpp
index 4c3036e..0f424e1 100644
--- a/test/cpp/jit/test_interpreter.cpp
+++ b/test/cpp/jit/test_interpreter.cpp
@@ -1,9 +1,108 @@
#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
+#include <stdexcept>
namespace torch {
namespace jit {
+void testTypeCheck() {
+ {
+ auto graph = std::make_shared<Graph>();
+ std::unordered_map<std::string, Value*> vmap;
+ parseIR(
+ R"IR(
+graph(%a.1 : Tensor,
+ %b.1 : Tensor):
+ %t0 : Float(2:2, 2:1, device=cpu, requires_grad=1), %t1 : Float(3:3, 3:1), %type_matched : bool = prim::TypeCheck(%a.1, %b.1)
+ return (%t0, %t1, %type_matched)
+ )IR",
+ &*graph,
+ vmap);
+
+ Code function(graph, "");
+ InterpreterState interp(function);
+ {
+ // TypeCheck yields to true! Shape, grad and device matches.
+ auto a = at::zeros({2, 2}, at::kFloat);
+ auto b = at::ones({3, 3}, at::kFloat);
+ a.set_requires_grad(true);
+ a = a.to(at::kCPU);
+ std::vector<IValue> stack({a, b});
+ interp.run(stack);
+ ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a));
+ ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b));
+ ASSERT_TRUE(stack[2].toBool());
+ }
+ {
+ auto a = at::zeros({2, 2}, at::kFloat);
+ auto b = at::ones({2, 2}, at::kFloat); // Size mismatch
+ a.set_requires_grad(true);
+ a = a.to(at::kCPU);
+ std::vector<IValue> stack({a, b});
+ interp.run(stack);
+ ASSERT_FALSE(stack[2].toBool());
+ }
+ {
+ auto a = at::zeros({2, 2}, at::kFloat);
+ auto b = at::ones({3, 3}, at::kFloat);
+ a = a.to(at::kCPU);
+ a.set_requires_grad(false); // Gradient mismatch
+ std::vector<IValue> stack({a, b});
+ interp.run(stack);
+ ASSERT_FALSE(stack[2].toBool());
+ }
+ {
+ auto a = at::zeros({2, 2}, at::kFloat);
+ auto b = at::ones({3, 3}, at::kFloat);
+ a = a.to(at::kCPU);
+ a.set_requires_grad(true);
+ a = a.to(at::kInt); // Scalar type mismatch
+ std::vector<IValue> stack({a, b});
+ interp.run(stack);
+ ASSERT_FALSE(stack[2].toBool());
+ }
+ {
+ auto a = at::zeros({2, 2}, at::kFloat);
+ auto b = at::ones({3, 3}, at::kFloat);
+ a.set_requires_grad(true);
+ a = a.to(at::kCUDA); // Device mismatch
+ std::vector<IValue> stack({a, b});
+ interp.run(stack);
+ ASSERT_FALSE(stack[2].toBool());
+ }
+ }
+
+ try { // Test empty Typecheck raises an internal assertion
+ auto graph = std::make_shared<Graph>();
+ std::unordered_map<std::string, Value*> vmap;
+ parseIR(
+ R"IR(
+graph(%a.1 : Tensor,
+ %b.1 : Tensor):
+ %type_matched : bool = prim::TypeCheck()
+ return (%type_matched)
+ )IR",
+ &*graph,
+ vmap);
+ ASSERT_TRUE(false);
+ } catch (const std::exception& e) {
+ }
+ try { // Test for assertion if num_inputs + 1 != num_outputs
+ auto graph = std::make_shared<Graph>();
+ std::unordered_map<std::string, Value*> vmap;
+ parseIR(
+ R"IR(
+graph(%a.1 : Tensor,
+ %b.1 : Tensor):
+ %type_matched : bool = prim::TypeCheck(%a.1)
+ return (%type_matched)
+ )IR",
+ &*graph,
+ vmap);
+ ASSERT_TRUE(false);
+ } catch (const std::exception& e) {
+ }
+}
void testInterp() {
constexpr int batch_size = 4;
constexpr int input_size = 256;
@@ -23,7 +122,6 @@
auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
- // std::cout << almostEqual(outputs[0],hx) << "\n";
ASSERT_TRUE(exactlyEqual(outputs[0], hx));
ASSERT_TRUE(exactlyEqual(outputs[1], cx));
}
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index 665969c..66dfb64 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -120,6 +120,7 @@
_(GraphExecutor) \
_(ModuleConversion) \
_(Interp) \
+ _(TypeCheck) \
_(GPU_IrGraphGenerator) \
_(GPU_FusionDispatch) \
_(GPU_FusionClear) \
@@ -225,7 +226,8 @@
_(Fusion) \
_(GraphExecutor) \
_(ModuleConversion) \
- _(Interp)
+ _(Interp) \
+ _(TypeCheck)
#endif
#define DECLARE_JIT_TEST(name) void test##name();
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index 11ca26e..aadfa8d 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -527,6 +527,13 @@
makePointerTo(node->output(), node->inputs().at(0));
}
return;
+ case prim::TypeCheck: {
+ auto num_inputs = node->inputs().size();
+ for (size_t i = 0; i < num_inputs; i++) {
+ makePointerTo(node->outputs().at(i), node->inputs().at(i));
+ }
+ return;
+ }
case prim::BailOut:
TORCH_INTERNAL_ASSERT(
node->inputs().at(0)->node()->kind() == prim::BailoutTemplate);
diff --git a/torch/csrc/jit/runtime/instruction.h b/torch/csrc/jit/runtime/instruction.h
index 7f6fc7a..8cfbb17 100644
--- a/torch/csrc/jit/runtime/instruction.h
+++ b/torch/csrc/jit/runtime/instruction.h
@@ -18,41 +18,42 @@
// S - index into object slots
// C - index into code table
-#define FORALL_OPCODES(_) \
- _(OP, "O") /* invoke operator X */ \
- _(OPN, "OI") /* invoke vararg operator X with N arguments */ \
- _(LOAD, "R") /* push a value from a register X */ \
- _(MOVE, "R") /* push a value from register X, clearing the register */ \
- _(STOREN, "RI") /* store N values to registers [X, X+N) */ \
- _(STORE, "R") /* store 1 value to registers X */ \
- _(DROP, "") /* drop 1 value from the top of the stack */ \
- _(DROPR, "R") /* clear register X */ \
- _(LOADC, "C") /* push the constant X */ \
- _(JF, "P") /* pop the top of the stack, if false, branch to P */ \
- _(JMP, "P") /* unconditional branch to X */ \
- _(LOOP, "PI") /* perform a loop, X is where to branch if cond is false */ \
- _(RET, "") /* exit execution */ \
- _(WAIT, "") /* wait for a future to be complete */ \
- _(CALL, "F") /* call function X */ \
- _(GUARD, "T") /* check a guard against type_table, true if passes */ \
- _(FAIL_GUARD, "T") /* fail a guard, patch back to GUARD */ \
- _(PROFILE_OP, "F") /* get a callback from profile_function_table at X */ \
- _(TAIL_CALL, "F") /* replace current frame with function F */ \
- _(INTERFACE_CALL, "CI") /* call method X on the first argument (of N) */ \
- _(GET_ATTR, "S") /* get attribute from slot X in an Object */ \
- _(SET_ATTR, "S") /* set attribute to slot X in an Object */ \
- _(LIST_UNPACK, "I") /* unpack list expecting length I */ \
- _(TUPLE_CONSTRUCT, "I") /* construct a tuple using X inputs */ \
- _(NAMED_TUPLE_CONSTRUCT, \
- "TI") /* construct a tuple of type X, using N inputs */ \
- _(LIST_CONSTRUCT, "TI") /* construct a list of type X, using N inputs */ \
- _(DICT_CONSTRUCT, "TI") /* construct a dict of type X, using N inputs */ \
- _(CREATE_OBJECT, "T") /* create an object of type X */ \
- _(ISINSTANCE, "TI") /* check object is one of types[X:X+N] */ \
- _(TUPLE_SLICE, "II") /* slice tup[X:(X+N)] */ \
- _(FORK, "CN") /* launch a thread to run code entry x with N inputs */ \
- _(WARN, "") /* emit a warning with line information */ \
- _(ENTER, "EN") /* enter scope of a contextmanager */ \
+#define FORALL_OPCODES(_) \
+ _(OP, "O") /* invoke operator X */ \
+ _(OPN, "OI") /* invoke vararg operator X with N arguments */ \
+ _(LOAD, "R") /* push a value from a register X */ \
+ _(MOVE, "R") /* push a value from register X, clearing the register */ \
+ _(STOREN, "RI") /* store N values to registers [X, X+N) */ \
+ _(STORE, "R") /* store 1 value to registers X */ \
+ _(DROP, "") /* drop 1 value from the top of the stack */ \
+ _(DROPR, "R") /* clear register X */ \
+ _(LOADC, "C") /* push the constant X */ \
+ _(JF, "P") /* pop the top of the stack, if false, branch to P */ \
+ _(JMP, "P") /* unconditional branch to X */ \
+ _(LOOP, "PI") /* perform a loop, X is where to branch if cond is false */ \
+ _(RET, "") /* exit execution */ \
+ _(WAIT, "") /* wait for a future to be complete */ \
+ _(CALL, "F") /* call function X */ \
+ _(GUARD, "T") /* check a guard against type_table, true if passes */ \
+ _(TYPECHECK, "TN") /* check each type of input[i] against type_table[X+N] */ \
+ _(FAIL_GUARD, "T") /* fail a guard, patch back to GUARD */ \
+ _(PROFILE_OP, "F") /* get a callback from profile_function_table at X */ \
+ _(TAIL_CALL, "F") /* replace current frame with function F */ \
+ _(INTERFACE_CALL, "CI") /* call method X on the first argument (of N) */ \
+ _(GET_ATTR, "S") /* get attribute from slot X in an Object */ \
+ _(SET_ATTR, "S") /* set attribute to slot X in an Object */ \
+ _(LIST_UNPACK, "I") /* unpack list expecting length I */ \
+ _(TUPLE_CONSTRUCT, "I") /* construct a tuple using X inputs */ \
+ _(NAMED_TUPLE_CONSTRUCT, \
+ "TI") /* construct a tuple of type X, using N inputs */ \
+ _(LIST_CONSTRUCT, "TI") /* construct a list of type X, using N inputs */ \
+ _(DICT_CONSTRUCT, "TI") /* construct a dict of type X, using N inputs */ \
+ _(CREATE_OBJECT, "T") /* create an object of type X */ \
+ _(ISINSTANCE, "TI") /* check object is one of types[X:X+N] */ \
+ _(TUPLE_SLICE, "II") /* slice tup[X:(X+N)] */ \
+ _(FORK, "CN") /* launch a thread to run code entry x with N inputs */ \
+ _(WARN, "") /* emit a warning with line information */ \
+ _(ENTER, "EN") /* enter scope of a contextmanager */ \
_(EXIT, "EX") /* exit the last entered contextmanager */
enum OpCode : uint8_t {
diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp
index 928d113..52c32c0 100644
--- a/torch/csrc/jit/runtime/interpreter.cpp
+++ b/torch/csrc/jit/runtime/interpreter.cpp
@@ -701,6 +701,22 @@
return r;
}
+ void emitTypeCheck(Node* node) {
+ auto num_inputs = node->inputs().size();
+
+ // Check that TypeCheck has at least one input.
+ TORCH_INTERNAL_ASSERT(
+ num_inputs && num_inputs + 1 == node->outputs().size());
+ emitLoadInputs(node->inputs());
+
+ // Emit the expected type.
+ size_t types_start = type_table_.size();
+ for (size_t i = 0; i < num_inputs; i++) {
+ emitType(node->outputs()[i]->type());
+ }
+ insertInstruction(TYPECHECK, types_start, num_inputs);
+ }
+
size_t emitGuard(Node* node) {
// unoptimized graph is at index 0
// guarded input is at index 1
@@ -880,6 +896,9 @@
emitInterfaceCall(node->s(attr::name), node->inputs());
}
break;
+ case prim::TypeCheck:
+ emitTypeCheck(node);
+ break;
case prim::BailOut:
emitBailOut(node);
break;
@@ -1345,6 +1364,30 @@
++af.pc;
break;
}
+ case TYPECHECK: {
+ int num_inputs = inst.N, i = 0;
+ TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
+ // Check every input's shape against profiled (expected) shape.
+ for (i = 0; i < num_inputs; i++) {
+ auto& input = peek(stack, i, num_inputs);
+ TORCH_INTERNAL_ASSERT(input.isTensor());
+ auto t = input.toTensor();
+ const TypePtr& expected = af.types[inst.X + i];
+ auto expected_type = expected->cast<TensorType>();
+ if (t.defined() &&
+ (!frames.back().symbols2dims.bindSymbolicShapes(
+ t.sizes(), expected_type->symbolic_sizes()) ||
+ !expected_type->matchTensor(t))) {
+ push(stack, false);
+ break;
+ }
+ }
+ if (i == num_inputs) {
+ push(stack, true);
+ }
+ ++af.pc;
+ break;
+ }
case GUARD: {
if (!stack.back().isTensor()) {
// stack.back() is an Uninitialized IValue and this is a guard
diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp
index 04a9261..2d11b69 100644
--- a/torch/csrc/jit/runtime/operator.cpp
+++ b/torch/csrc/jit/runtime/operator.cpp
@@ -239,6 +239,7 @@
prim::MMBatchSide, // used as an optimization
prim::Store, // used in interpreter only
prim::profile, // used in interpreter only
+ prim::TypeCheck, // used in interpreter only
};
@@ -293,6 +294,7 @@
prim::GetAttr,
prim::SetAttr,
prim::profile,
+ prim::TypeCheck,
prim::Print,
prim::CallFunction,
prim::CallMethod,
diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
index c8ac72b..e47ab44 100644
--- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
@@ -46,6 +46,14 @@
},
aliasAnalysisSpecialCase()),
Operator(
+ prim::TypeCheck /* (...) -> (..., bool) */,
+ [](const Node * /* node */) -> Operation {
+ return [](Stack* /* stack */) {
+ AT_ERROR("prim::TypeCheck not yet implemented"); // NOLINT
+ };
+ },
+ aliasAnalysisSpecialCase()),
+ Operator(
"prim::Guard(Tensor(a) t) -> Tensor(a)",
[](Stack* stack) { AT_ERROR("Should be replaced by prim::BailOut"); },
aliasAnalysisFromSchema()),