Add prim::TypeCheck operation (#43026)

Summary:
TypeCheck is a new operation to check the shape of tensors against
 expected shapes. TypeCheck is a variadic operation. An example:

 %t0 : Tensor = ...
 %t1 : Tensor = ...
 %2 : FLOAT(20, 20), %3 : FLOAT(30, 30), %1 : bool =
 prim::TypeCheck(%t0, %t1)
 prim::If(%1)

Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43026

Reviewed By: ZolotukhinM

Differential Revision: D23115830

Pulled By: bzinodev

fbshipit-source-id: fbf142126002173d2d865cf4b932dea3864466b4
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 185c18c..3741dcc 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -94,6 +94,7 @@
   _(aten, backward)                  \
   _(prim, Guard)                     \
   _(prim, BailOut)                   \
+  _(prim, TypeCheck)                 \
   _(prim, FusedConcat)               \
   _(prim, ConstantChunk)             \
   _(prim, MMTreeReduce)              \
diff --git a/test/cpp/jit/test_interpreter.cpp b/test/cpp/jit/test_interpreter.cpp
index 4c3036e..0f424e1 100644
--- a/test/cpp/jit/test_interpreter.cpp
+++ b/test/cpp/jit/test_interpreter.cpp
@@ -1,9 +1,108 @@
 #include "test/cpp/jit/test_base.h"
 #include "test/cpp/jit/test_utils.h"
 
+#include <stdexcept>
 namespace torch {
 namespace jit {
 
+void testTypeCheck() {
+  {
+    auto graph = std::make_shared<Graph>();
+    std::unordered_map<std::string, Value*> vmap;
+    parseIR(
+        R"IR(
+graph(%a.1 : Tensor,
+      %b.1 : Tensor):
+  %t0 : Float(2:2, 2:1, device=cpu, requires_grad=1), %t1 : Float(3:3, 3:1), %type_matched : bool = prim::TypeCheck(%a.1, %b.1)
+  return (%t0, %t1, %type_matched)
+  )IR",
+        &*graph,
+        vmap);
+
+    Code function(graph, "");
+    InterpreterState interp(function);
+    {
+      // TypeCheck yields to true! Shape, grad and device matches.
+      auto a = at::zeros({2, 2}, at::kFloat);
+      auto b = at::ones({3, 3}, at::kFloat);
+      a.set_requires_grad(true);
+      a = a.to(at::kCPU);
+      std::vector<IValue> stack({a, b});
+      interp.run(stack);
+      ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a));
+      ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b));
+      ASSERT_TRUE(stack[2].toBool());
+    }
+    {
+      auto a = at::zeros({2, 2}, at::kFloat);
+      auto b = at::ones({2, 2}, at::kFloat); // Size mismatch
+      a.set_requires_grad(true);
+      a = a.to(at::kCPU);
+      std::vector<IValue> stack({a, b});
+      interp.run(stack);
+      ASSERT_FALSE(stack[2].toBool());
+    }
+    {
+      auto a = at::zeros({2, 2}, at::kFloat);
+      auto b = at::ones({3, 3}, at::kFloat);
+      a = a.to(at::kCPU);
+      a.set_requires_grad(false); // Gradient mismatch
+      std::vector<IValue> stack({a, b});
+      interp.run(stack);
+      ASSERT_FALSE(stack[2].toBool());
+    }
+    {
+      auto a = at::zeros({2, 2}, at::kFloat);
+      auto b = at::ones({3, 3}, at::kFloat);
+      a = a.to(at::kCPU);
+      a.set_requires_grad(true);
+      a = a.to(at::kInt); // Scalar type mismatch
+      std::vector<IValue> stack({a, b});
+      interp.run(stack);
+      ASSERT_FALSE(stack[2].toBool());
+    }
+    {
+      auto a = at::zeros({2, 2}, at::kFloat);
+      auto b = at::ones({3, 3}, at::kFloat);
+      a.set_requires_grad(true);
+      a = a.to(at::kCUDA); // Device mismatch
+      std::vector<IValue> stack({a, b});
+      interp.run(stack);
+      ASSERT_FALSE(stack[2].toBool());
+    }
+  }
+
+  try { // Test empty Typecheck raises an internal assertion
+    auto graph = std::make_shared<Graph>();
+    std::unordered_map<std::string, Value*> vmap;
+    parseIR(
+        R"IR(
+graph(%a.1 : Tensor,
+      %b.1 : Tensor):
+  %type_matched : bool = prim::TypeCheck()
+  return (%type_matched)
+  )IR",
+        &*graph,
+        vmap);
+    ASSERT_TRUE(false);
+  } catch (const std::exception& e) {
+  }
+  try { // Test for assertion if num_inputs + 1 != num_outputs
+    auto graph = std::make_shared<Graph>();
+    std::unordered_map<std::string, Value*> vmap;
+    parseIR(
+        R"IR(
+graph(%a.1 : Tensor,
+      %b.1 : Tensor):
+  %type_matched : bool = prim::TypeCheck(%a.1)
+  return (%type_matched)
+  )IR",
+        &*graph,
+        vmap);
+    ASSERT_TRUE(false);
+  } catch (const std::exception& e) {
+  }
+}
 void testInterp() {
   constexpr int batch_size = 4;
   constexpr int input_size = 256;
@@ -23,7 +122,6 @@
   auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
   std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
 
-  // std::cout << almostEqual(outputs[0],hx) << "\n";
   ASSERT_TRUE(exactlyEqual(outputs[0], hx));
   ASSERT_TRUE(exactlyEqual(outputs[1], cx));
 }
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index 665969c..66dfb64 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -120,6 +120,7 @@
   _(GraphExecutor)                                  \
   _(ModuleConversion)                               \
   _(Interp)                                         \
+  _(TypeCheck)                                      \
   _(GPU_IrGraphGenerator)                           \
   _(GPU_FusionDispatch)                             \
   _(GPU_FusionClear)                                \
@@ -225,7 +226,8 @@
   _(Fusion)                     \
   _(GraphExecutor)              \
   _(ModuleConversion)           \
-  _(Interp)
+  _(Interp)                     \
+  _(TypeCheck)
 #endif
 
 #define DECLARE_JIT_TEST(name) void test##name();
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index 11ca26e..aadfa8d 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -527,6 +527,13 @@
         makePointerTo(node->output(), node->inputs().at(0));
       }
       return;
+    case prim::TypeCheck: {
+      auto num_inputs = node->inputs().size();
+      for (size_t i = 0; i < num_inputs; i++) {
+        makePointerTo(node->outputs().at(i), node->inputs().at(i));
+      }
+      return;
+    }
     case prim::BailOut:
       TORCH_INTERNAL_ASSERT(
           node->inputs().at(0)->node()->kind() == prim::BailoutTemplate);
diff --git a/torch/csrc/jit/runtime/instruction.h b/torch/csrc/jit/runtime/instruction.h
index 7f6fc7a..8cfbb17 100644
--- a/torch/csrc/jit/runtime/instruction.h
+++ b/torch/csrc/jit/runtime/instruction.h
@@ -18,41 +18,42 @@
 // S - index into object slots
 // C - index into code table
 
-#define FORALL_OPCODES(_)                                                   \
-  _(OP, "O") /* invoke operator X */                                        \
-  _(OPN, "OI") /* invoke vararg operator X with N arguments */              \
-  _(LOAD, "R") /* push a value from a register X */                         \
-  _(MOVE, "R") /* push a value from register X, clearing the register */    \
-  _(STOREN, "RI") /* store N values to registers [X, X+N) */                \
-  _(STORE, "R") /* store 1 value to registers X */                          \
-  _(DROP, "") /* drop 1 value from the top of the stack */                  \
-  _(DROPR, "R") /* clear register X */                                      \
-  _(LOADC, "C") /* push the constant X */                                   \
-  _(JF, "P") /* pop the top of the stack, if false, branch to P */          \
-  _(JMP, "P") /* unconditional branch to X */                               \
-  _(LOOP, "PI") /* perform a loop, X is where to branch if cond is false */ \
-  _(RET, "") /* exit execution */                                           \
-  _(WAIT, "") /* wait for a future to be complete */                        \
-  _(CALL, "F") /* call function X */                                        \
-  _(GUARD, "T") /* check a guard against type_table, true if passes */      \
-  _(FAIL_GUARD, "T") /* fail a guard, patch back to GUARD */                \
-  _(PROFILE_OP, "F") /* get a callback from profile_function_table at X */  \
-  _(TAIL_CALL, "F") /* replace current frame with function F */             \
-  _(INTERFACE_CALL, "CI") /* call method X on the first argument (of N) */  \
-  _(GET_ATTR, "S") /* get attribute from slot X in an Object */             \
-  _(SET_ATTR, "S") /* set attribute to slot X in an Object */               \
-  _(LIST_UNPACK, "I") /* unpack list expecting length I */                  \
-  _(TUPLE_CONSTRUCT, "I") /* construct a tuple using X inputs */            \
-  _(NAMED_TUPLE_CONSTRUCT,                                                  \
-    "TI") /* construct a tuple of type X, using N inputs */                 \
-  _(LIST_CONSTRUCT, "TI") /* construct a list of type X, using N inputs */  \
-  _(DICT_CONSTRUCT, "TI") /* construct a dict of type X, using N inputs */  \
-  _(CREATE_OBJECT, "T") /* create an object of type X */                    \
-  _(ISINSTANCE, "TI") /* check object is one of  types[X:X+N]  */           \
-  _(TUPLE_SLICE, "II") /* slice tup[X:(X+N)] */                             \
-  _(FORK, "CN") /* launch a thread to run code entry x with N inputs  */    \
-  _(WARN, "") /* emit a warning with line information */                    \
-  _(ENTER, "EN") /* enter scope of a contextmanager */                      \
+#define FORALL_OPCODES(_)                                                      \
+  _(OP, "O") /* invoke operator X */                                           \
+  _(OPN, "OI") /* invoke vararg operator X with N arguments */                 \
+  _(LOAD, "R") /* push a value from a register X */                            \
+  _(MOVE, "R") /* push a value from register X, clearing the register */       \
+  _(STOREN, "RI") /* store N values to registers [X, X+N) */                   \
+  _(STORE, "R") /* store 1 value to registers X */                             \
+  _(DROP, "") /* drop 1 value from the top of the stack */                     \
+  _(DROPR, "R") /* clear register X */                                         \
+  _(LOADC, "C") /* push the constant X */                                      \
+  _(JF, "P") /* pop the top of the stack, if false, branch to P */             \
+  _(JMP, "P") /* unconditional branch to X */                                  \
+  _(LOOP, "PI") /* perform a loop, X is where to branch if cond is false */    \
+  _(RET, "") /* exit execution */                                              \
+  _(WAIT, "") /* wait for a future to be complete */                           \
+  _(CALL, "F") /* call function X */                                           \
+  _(GUARD, "T") /* check a guard against type_table, true if passes */         \
+  _(TYPECHECK, "TN") /* check each type of input[i] against type_table[X+N] */ \
+  _(FAIL_GUARD, "T") /* fail a guard, patch back to GUARD */                   \
+  _(PROFILE_OP, "F") /* get a callback from profile_function_table at X */     \
+  _(TAIL_CALL, "F") /* replace current frame with function F */                \
+  _(INTERFACE_CALL, "CI") /* call method X on the first argument (of N) */     \
+  _(GET_ATTR, "S") /* get attribute from slot X in an Object */                \
+  _(SET_ATTR, "S") /* set attribute to slot X in an Object */                  \
+  _(LIST_UNPACK, "I") /* unpack list expecting length I */                     \
+  _(TUPLE_CONSTRUCT, "I") /* construct a tuple using X inputs */               \
+  _(NAMED_TUPLE_CONSTRUCT,                                                     \
+    "TI") /* construct a tuple of type X, using N inputs */                    \
+  _(LIST_CONSTRUCT, "TI") /* construct a list of type X, using N inputs */     \
+  _(DICT_CONSTRUCT, "TI") /* construct a dict of type X, using N inputs */     \
+  _(CREATE_OBJECT, "T") /* create an object of type X */                       \
+  _(ISINSTANCE, "TI") /* check object is one of  types[X:X+N]  */              \
+  _(TUPLE_SLICE, "II") /* slice tup[X:(X+N)] */                                \
+  _(FORK, "CN") /* launch a thread to run code entry x with N inputs  */       \
+  _(WARN, "") /* emit a warning with line information */                       \
+  _(ENTER, "EN") /* enter scope of a contextmanager */                         \
   _(EXIT, "EX") /* exit the last entered contextmanager */
 
 enum OpCode : uint8_t {
diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp
index 928d113..52c32c0 100644
--- a/torch/csrc/jit/runtime/interpreter.cpp
+++ b/torch/csrc/jit/runtime/interpreter.cpp
@@ -701,6 +701,22 @@
     return r;
   }
 
+  void emitTypeCheck(Node* node) {
+    auto num_inputs = node->inputs().size();
+
+    // Check that TypeCheck has at least one input.
+    TORCH_INTERNAL_ASSERT(
+        num_inputs && num_inputs + 1 == node->outputs().size());
+    emitLoadInputs(node->inputs());
+
+    // Emit the expected type.
+    size_t types_start = type_table_.size();
+    for (size_t i = 0; i < num_inputs; i++) {
+      emitType(node->outputs()[i]->type());
+    }
+    insertInstruction(TYPECHECK, types_start, num_inputs);
+  }
+
   size_t emitGuard(Node* node) {
     // unoptimized graph is at index 0
     // guarded input is at index 1
@@ -880,6 +896,9 @@
           emitInterfaceCall(node->s(attr::name), node->inputs());
         }
         break;
+      case prim::TypeCheck:
+        emitTypeCheck(node);
+        break;
       case prim::BailOut:
         emitBailOut(node);
         break;
@@ -1345,6 +1364,30 @@
             ++af.pc;
             break;
           }
+          case TYPECHECK: {
+            int num_inputs = inst.N, i = 0;
+            TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
+            // Check every input's shape against profiled (expected) shape.
+            for (i = 0; i < num_inputs; i++) {
+              auto& input = peek(stack, i, num_inputs);
+              TORCH_INTERNAL_ASSERT(input.isTensor());
+              auto t = input.toTensor();
+              const TypePtr& expected = af.types[inst.X + i];
+              auto expected_type = expected->cast<TensorType>();
+              if (t.defined() &&
+                  (!frames.back().symbols2dims.bindSymbolicShapes(
+                       t.sizes(), expected_type->symbolic_sizes()) ||
+                   !expected_type->matchTensor(t))) {
+                push(stack, false);
+                break;
+              }
+            }
+            if (i == num_inputs) {
+              push(stack, true);
+            }
+            ++af.pc;
+            break;
+          }
           case GUARD: {
             if (!stack.back().isTensor()) {
               // stack.back() is an Uninitialized IValue and this is a guard
diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp
index 04a9261..2d11b69 100644
--- a/torch/csrc/jit/runtime/operator.cpp
+++ b/torch/csrc/jit/runtime/operator.cpp
@@ -239,6 +239,7 @@
       prim::MMBatchSide, // used as an optimization
       prim::Store, // used in interpreter only
       prim::profile, // used in interpreter only
+      prim::TypeCheck, // used in interpreter only
 
   };
 
@@ -293,6 +294,7 @@
       prim::GetAttr,
       prim::SetAttr,
       prim::profile,
+      prim::TypeCheck,
       prim::Print,
       prim::CallFunction,
       prim::CallMethod,
diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
index c8ac72b..e47ab44 100644
--- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
+++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp
@@ -46,6 +46,14 @@
          },
          aliasAnalysisSpecialCase()),
      Operator(
+         prim::TypeCheck /* (...)  -> (..., bool) */,
+         [](const Node * /* node */) -> Operation {
+           return [](Stack* /* stack */) {
+             AT_ERROR("prim::TypeCheck not yet implemented"); // NOLINT
+           };
+         },
+         aliasAnalysisSpecialCase()),
+     Operator(
          "prim::Guard(Tensor(a) t) -> Tensor(a)",
          [](Stack* stack) { AT_ERROR("Should be replaced by prim::BailOut"); },
          aliasAnalysisFromSchema()),