| #include <gtest/gtest.h> | 
 |  | 
 | #include <torch/csrc/jit/ir/ir.h> | 
 | #include <torch/csrc/jit/runtime/custom_operator.h> | 
 | #include <torch/csrc/jit/testing/file_check.h> | 
 | #include <torch/jit.h> | 
 |  | 
 | #include <sstream> | 
 | #include <string> | 
 |  | 
 | namespace torch { | 
 | namespace jit { | 
 |  | 
 | TEST(SchemaMatchingTest, VarType) { | 
 |   RegisterOperators reg({ | 
 |       Operator( | 
 |           "aten::test_vartype(t[] a, t b) -> (t)", | 
 |           [](Stack& stack) { | 
 |             c10::List<double> list; | 
 |             // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |             double a; | 
 |             pop(stack, list, a); | 
 |             push(stack, a); | 
 |           }, | 
 |           c10::AliasAnalysisKind::FROM_SCHEMA), | 
 |   }); | 
 |   Module m("m"); | 
 |   m.define(R"( | 
 |       def test(self): | 
 |         a = (1.0, 2.0) | 
 |         return torch.test_vartype(a, 2.0) | 
 |     )"); | 
 |   auto result = m.run_method("test"); | 
 |   TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0); | 
 |  | 
 |   const std::string error_example = R"JIT( | 
 |       def test_2(self): | 
 |           a = (1.0, 2.0) | 
 |           non_float = (1, 1) | 
 |           return torch.test_vartype(a, non_float) | 
 |     )JIT"; | 
 |  | 
 |   std::string err = ""; | 
 |   try { | 
 |     m.define(error_example); | 
 |   } catch (const std::exception& e) { | 
 |     err = e.what(); | 
 |   } | 
 |   TORCH_INTERNAL_ASSERT( | 
 |       err.find("previously matched to type") != std::string::npos); | 
 | } | 
 |  | 
 | TEST(SchemaMatchingTest, VarType2) { | 
 |   RegisterOperators reg({ | 
 |       Operator( | 
 |           "aten::test_vartype2(t a, t[] b) -> (t[])", | 
 |           [](Stack& stack) { | 
 |             // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |             double a; | 
 |             c10::List<double> list; | 
 |             pop(stack, a, list); | 
 |             push(stack, a); | 
 |           }, | 
 |           AliasAnalysisKind::FROM_SCHEMA), | 
 |   }); | 
 |   Module m("m"); | 
 |   m.define(R"JIT( | 
 |       def test(self): | 
 |           a = (1.0, 2.0) | 
 |           return torch.test_vartype2(3.0, a) | 
 |     )JIT"); | 
 |   auto result = m.run_method("test"); | 
 |   TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0); | 
 |  | 
 |   static const auto error_exam2 = R"JIT( | 
 |       def test_2(self): | 
 |           a = (1, 2) | 
 |           return torch.test_vartype2(3.0, a) | 
 |     )JIT"; | 
 |  | 
 |   std::string err = ""; | 
 |   try { | 
 |     m.define(error_exam2); | 
 |   } catch (const std::exception& e) { | 
 |     err = e.what(); | 
 |   } | 
 |   TORCH_INTERNAL_ASSERT( | 
 |       err.find("previously matched to type") != std::string::npos); | 
 | } | 
 | } // namespace jit | 
 | } // namespace torch |