Bump op_version_set (#19812)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19812
ghimport-source-id: 7cdb24c6a7501c6ec5f0eae07325512746f5abb9
Differential Revision: D15102803
Pulled By: zdevito
fbshipit-source-id: bedf0bd6e1170fa294c65c87df75b82d8694f89c
diff --git a/test/cpp/jit/test_class_import.h b/test/cpp/jit/test_class_import.h
index 20ac2dd..d8b3f84 100644
--- a/test/cpp/jit/test_class_import.h
+++ b/test/cpp/jit/test_class_import.h
@@ -9,7 +9,7 @@
namespace script {
static const auto classSrcs1 = R"JIT(
-op_version_set = 0
+op_version_set = 1
class FooNestedTest:
def __init__(self, y):
self.y = y
@@ -25,7 +25,7 @@
)JIT";
static const auto classSrcs2 = R"JIT(
-op_version_set = 0
+op_version_set = 1
class FooTest:
def __init__(self, x):
self.dx = x
diff --git a/test/test_jit.py b/test/test_jit.py
index f5fa866..7382c1d 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -2996,6 +2996,15 @@
cu = torch.jit.CompilationUnit()._import(r, [])
self.assertExpected(cu.foo.code)
+ def test_import_way_too_new(self):
+ @torch.jit.script
+ def foo(x, y):
+ return 2 * x + y
+
+ r, _ = _jit_python_print(foo)
+ with self.assertRaisesRegex(RuntimeError, "generated from a newer version"):
+ torch.jit.CompilationUnit()._import(r, [], op_version_set=10000)
+
def test_function_default_values(self):
outer_var = torch.tensor(20)
outer_var2 = torch.tensor(30)
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index d19c30e..1508381 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -541,7 +541,6 @@
OrderedDict<ClassTypePtr, std::string> converted_classes_;
std::unordered_map<ClassTypePtr, std::vector<ClassTypePtr>> class_to_deps_;
- static const size_t op_version_set = 0;
};
// ScriptModuleSerializer's methods
@@ -612,7 +611,7 @@
const std::string& src = fileToSrc.at(filename).str();
std::ostringstream lib_stream;
- lib_stream << "op_version_set = " << op_version_set << "\n";
+ lib_stream << "op_version_set = " << CURRENT_OP_VERSION_SET << "\n";
lib_stream << src;
std::string lib_str = lib_stream.str();
writer_.writeRecord(filename, lib_str.c_str(), lib_str.size());
@@ -771,7 +770,7 @@
if (module.get_methods().size() > 0) {
std::ostringstream methods;
- methods << "op_version_set = " << op_version_set << "\n";
+ methods << "op_version_set = " << CURRENT_OP_VERSION_SET << "\n";
PythonPrint(
methods,
module.class_compilation_unit(),
diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h
index 636fa2e..ad3a7f5 100644
--- a/torch/csrc/jit/export.h
+++ b/torch/csrc/jit/export.h
@@ -19,6 +19,8 @@
// file contents being the raw tensor data.
using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
+constexpr size_t CURRENT_OP_VERSION_SET = 1;
+
TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
const std::shared_ptr<Graph>& graph,
const std::map<std::string, at::Tensor>& initializers,
diff --git a/torch/csrc/jit/import_source.cpp b/torch/csrc/jit/import_source.cpp
index 02ea7fc..8809fe8 100644
--- a/torch/csrc/jit/import_source.cpp
+++ b/torch/csrc/jit/import_source.cpp
@@ -3,6 +3,7 @@
#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/jit/script/resolver.h>
+#include <torch/csrc/jit/export.h>
namespace torch {
namespace jit {
@@ -163,7 +164,18 @@
std::make_shared<SourceResolver>(lib_cu_, version_, constant_table_);
}
+ void checkVersionNumber() {
+ // note: this cannot be called in the constructor because it may throw
+ if (version_ > CURRENT_OP_VERSION_SET) {
+ throw ErrorReport(p_.lexer().cur().range)
+ << "Attempting to load a script generated from a newer version of PyTorch. Maximum supported TorchScript version is "
+ << CURRENT_OP_VERSION_SET << " but the script being loaded is version "
+ << version_ << ".";
+ }
+ }
+
void importLibs(CompilationUnit& owner, const std::string& class_qualifier) {
+ checkVersionNumber();
auto& L = p_.lexer();
while (L.cur().kind != TK_EOF) {
@@ -192,6 +204,7 @@
}
void importFunctions(CompilationUnit& cu, const Self& self) {
+ checkVersionNumber();
parseImportsAndDoCallback();
std::vector<Def> definitions;
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 8c9b684..0c3057e 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -854,9 +854,9 @@
raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
return r
- def _import(self, src, constants):
+ def _import(self, src, constants, op_version_set=1):
""" test import logic for single function, use only for testing """
- src = "op_version_set = 0\n{}".format(src)
+ src = "op_version_set = {}\n{}".format(op_version_set, src)
torch._C._jit_import_functions(self._c, src, constants, None)
return self