Add test operator in upgrader entry (#69427)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69427

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D32867984

Pulled By: tugsbayasgalan

fbshipit-source-id: 25810fc2fd4b943911f950618968af067c04da5c
diff --git a/test/jit/test_upgraders.py b/test/jit/test_upgraders.py
index a02dfb2..aaea353 100644
--- a/test/jit/test_upgraders.py
+++ b/test/jit/test_upgraders.py
@@ -37,6 +37,23 @@
         self.assertTrue(upgraders_size == upgraders_size_second_time)
         self.assertTrue(upgraders_dump == upgraders_dump_second_time)
 
+    def test_add_value_to_version_map(self):
+        map_before_test = torch._C._get_operator_version_map()
+
+        upgrader_bumped_version = 3
+        upgrader_name = "_test_serialization_subcmul_0_2"
+        upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
+        dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema)
+
+        torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry)
+        map_after_test = torch._C._get_operator_version_map()
+        self.assertTrue("aten::_test_serialization_subcmul" in map_after_test)
+        self.assertTrue(len(map_after_test) - len(map_before_test) == 1)
+        torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul")
+        map_after_remove_test = torch._C._get_operator_version_map()
+        self.assertTrue("aten::_test_serialization_subcmul" not in map_after_remove_test)
+        self.assertEqual(len(map_after_remove_test), len(map_before_test))
+
     def test_populated_test_upgrader_graph(self):
         @torch.jit.script
         def f():
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index e281626..f77803c 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -455,9 +455,14 @@
     bumped_at_version: _int
     upgrader_name: str
     old_schema: str
+    def __init__(self, bumped_at_version: _int, upgrader_name: str, old_schema: str) -> None: ...
 
 def _get_operator_version_map() -> Dict[str, List[_UpgraderEntry]]: ...
 
+def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ...
+
+def _test_only_remove_entry_to_op_version(op_name: str) -> None: ...
+
 # Defined in torch/csrc/jit/python/script_init.cpp
 class ScriptModuleSerializer(object):
     def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
diff --git a/torch/csrc/jit/operator_upgraders/version_map.h b/torch/csrc/jit/operator_upgraders/version_map.h
index b4b7539..d48d1a2 100644
--- a/torch/csrc/jit/operator_upgraders/version_map.h
+++ b/torch/csrc/jit/operator_upgraders/version_map.h
@@ -12,7 +12,7 @@
   std::string old_schema;
 };
 
-const static std::unordered_map<std::string, std::vector<UpgraderEntry>> kOperatorVersionMap(
+static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
     {{"aten::div.Tensor",
       {{4,
         "div_Tensor_0_3",
@@ -42,9 +42,17 @@
         "full_out_0_4",
         "aten::full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)"}}}});
 
-std::unordered_map<std::string, std::vector<UpgraderEntry>>
+const std::unordered_map<std::string, std::vector<UpgraderEntry>>&
 get_operator_version_map() {
-  return kOperatorVersionMap;
+  return operatorVersionMap;
+}
+
+void test_only_add_entry(std::string op_name, UpgraderEntry entry) {
+  operatorVersionMap[op_name].push_back(entry);
+}
+
+void test_only_remove_entry(std::string op_name) {
+  operatorVersionMap.erase(op_name);
 }
 
 } // namespace jit
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 4345db0..a484b99 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -1017,6 +1017,7 @@
   py::class_<DeepCopyMemoTable>(m, "DeepCopyMemoTable");
 
   py::class_<UpgraderEntry>(m, "_UpgraderEntry")
+      .def(py::init<int, std::string, std::string>())
       .def_property_readonly(
           "bumped_at_version",
           [](const UpgraderEntry& self) { return self.bumped_at_version; })
@@ -1738,6 +1739,8 @@
 
   m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
   m.def("_get_operator_version_map", &get_operator_version_map);
+  m.def("_test_only_add_entry_to_op_version_map", &test_only_add_entry);
+  m.def("_test_only_remove_entry_to_op_version_map", &test_only_remove_entry);
   m.def(
       "import_ir_module",
       [](std::shared_ptr<CompilationUnit> cu,
diff --git a/torch/jit/operator_upgraders.py b/torch/jit/operator_upgraders.py
index 75a18da..dc2bb65 100644
--- a/torch/jit/operator_upgraders.py
+++ b/torch/jit/operator_upgraders.py
@@ -86,7 +86,9 @@
     # in the torch/csrc/operator_upgraders/version_map.h
 
     entries = globals()
-    version_map = torch._C._get_operator_version_map()
+    # ignore test operators
+    version_map = {k : v for k, v in torch._C._get_operator_version_map().items()
+                   if not k.startswith("aten::_test")}
 
     # 1. Check if everything in version_map.h is defined here
     available_upgraders_in_version_map = set()
@@ -101,7 +103,7 @@
     for entry in entries:
         if isinstance(entries[entry], torch.jit.ScriptFunction):
             if entry not in available_upgraders_in_version_map:
-                raise AssertionError("The upgrader {} is not registered in the version_map.h")
+                raise AssertionError("The upgrader {} is not registered in the version_map.h".format(entry))
 
     return available_upgraders_in_version_map