Add test operator in upgrader entry (#69427)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69427
Test Plan: Imported from OSS
Reviewed By: gmagogsfm
Differential Revision: D32867984
Pulled By: tugsbayasgalan
fbshipit-source-id: 25810fc2fd4b943911f950618968af067c04da5c
diff --git a/test/jit/test_upgraders.py b/test/jit/test_upgraders.py
index a02dfb2..aaea353 100644
--- a/test/jit/test_upgraders.py
+++ b/test/jit/test_upgraders.py
@@ -37,6 +37,23 @@
self.assertTrue(upgraders_size == upgraders_size_second_time)
self.assertTrue(upgraders_dump == upgraders_dump_second_time)
+ def test_add_value_to_version_map(self):
+ map_before_test = torch._C._get_operator_version_map()
+
+ upgrader_bumped_version = 3
+ upgrader_name = "_test_serialization_subcmul_0_2"
+ upgrader_schema = "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=2) -> Tensor"
+ dummy_entry = torch._C._UpgraderEntry(upgrader_bumped_version, upgrader_name, upgrader_schema)
+
+ torch._C._test_only_add_entry_to_op_version_map("aten::_test_serialization_subcmul", dummy_entry)
+ map_after_test = torch._C._get_operator_version_map()
+ self.assertTrue("aten::_test_serialization_subcmul" in map_after_test)
+ self.assertTrue(len(map_after_test) - len(map_before_test) == 1)
+ torch._C._test_only_remove_entry_to_op_version_map("aten::_test_serialization_subcmul")
+ map_after_remove_test = torch._C._get_operator_version_map()
+ self.assertTrue("aten::_test_serialization_subcmul" not in map_after_remove_test)
+ self.assertEqual(len(map_after_remove_test), len(map_before_test))
+
def test_populated_test_upgrader_graph(self):
@torch.jit.script
def f():
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index e281626..f77803c 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -455,9 +455,14 @@
bumped_at_version: _int
upgrader_name: str
old_schema: str
+ def __init__(self, bumped_at_version: _int, upgrader_name: str, old_schema: str) -> None: ...
def _get_operator_version_map() -> Dict[str, List[_UpgraderEntry]]: ...
+def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ...
+
+def _test_only_remove_entry_to_op_version(op_name: str) -> None: ...
+
# Defined in torch/csrc/jit/python/script_init.cpp
class ScriptModuleSerializer(object):
def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
diff --git a/torch/csrc/jit/operator_upgraders/version_map.h b/torch/csrc/jit/operator_upgraders/version_map.h
index b4b7539..d48d1a2 100644
--- a/torch/csrc/jit/operator_upgraders/version_map.h
+++ b/torch/csrc/jit/operator_upgraders/version_map.h
@@ -12,7 +12,7 @@
std::string old_schema;
};
-const static std::unordered_map<std::string, std::vector<UpgraderEntry>> kOperatorVersionMap(
+static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
{{"aten::div.Tensor",
{{4,
"div_Tensor_0_3",
@@ -42,9 +42,17 @@
"full_out_0_4",
"aten::full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)"}}}});
-std::unordered_map<std::string, std::vector<UpgraderEntry>>
+const std::unordered_map<std::string, std::vector<UpgraderEntry>>&
get_operator_version_map() {
- return kOperatorVersionMap;
+ return operatorVersionMap;
+}
+
+void test_only_add_entry(std::string op_name, UpgraderEntry entry) {
+ operatorVersionMap[op_name].push_back(entry);
+}
+
+void test_only_remove_entry(std::string op_name) {
+ operatorVersionMap.erase(op_name);
}
} // namespace jit
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 4345db0..a484b99 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -1017,6 +1017,7 @@
py::class_<DeepCopyMemoTable>(m, "DeepCopyMemoTable");
py::class_<UpgraderEntry>(m, "_UpgraderEntry")
+ .def(py::init<int, std::string, std::string>())
.def_property_readonly(
"bumped_at_version",
[](const UpgraderEntry& self) { return self.bumped_at_version; })
@@ -1738,6 +1739,8 @@
m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
m.def("_get_operator_version_map", &get_operator_version_map);
+ m.def("_test_only_add_entry_to_op_version_map", &test_only_add_entry);
+ m.def("_test_only_remove_entry_to_op_version_map", &test_only_remove_entry);
m.def(
"import_ir_module",
[](std::shared_ptr<CompilationUnit> cu,
diff --git a/torch/jit/operator_upgraders.py b/torch/jit/operator_upgraders.py
index 75a18da..dc2bb65 100644
--- a/torch/jit/operator_upgraders.py
+++ b/torch/jit/operator_upgraders.py
@@ -86,7 +86,9 @@
# in the torch/csrc/operator_upgraders/version_map.h
entries = globals()
- version_map = torch._C._get_operator_version_map()
+ # ignore test operators
+ version_map = {k : v for k, v in torch._C._get_operator_version_map().items()
+ if not k.startswith("aten::_test")}
# 1. Check if everything in version_map.h is defined here
available_upgraders_in_version_map = set()
@@ -101,7 +103,7 @@
for entry in entries:
if isinstance(entries[entry], torch.jit.ScriptFunction):
if entry not in available_upgraders_in_version_map:
- raise AssertionError("The upgrader {} is not registered in the version_map.h")
+ raise AssertionError("The upgrader {} is not registered in the version_map.h".format(entry))
return available_upgraders_in_version_map