Bump op_version_set (#19812)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19812
ghimport-source-id: 7cdb24c6a7501c6ec5f0eae07325512746f5abb9

Differential Revision: D15102803

Pulled By: zdevito

fbshipit-source-id: bedf0bd6e1170fa294c65c87df75b82d8694f89c
diff --git a/test/cpp/jit/test_class_import.h b/test/cpp/jit/test_class_import.h
index 20ac2dd..d8b3f84 100644
--- a/test/cpp/jit/test_class_import.h
+++ b/test/cpp/jit/test_class_import.h
@@ -9,7 +9,7 @@
 namespace script {
 
 static const auto classSrcs1 = R"JIT(
-op_version_set = 0
+op_version_set = 1
 class FooNestedTest:
     def __init__(self, y):
         self.y = y
@@ -25,7 +25,7 @@
 )JIT";
 
 static const auto classSrcs2 = R"JIT(
-op_version_set = 0
+op_version_set = 1
 class FooTest:
     def __init__(self, x):
       self.dx = x
diff --git a/test/test_jit.py b/test/test_jit.py
index f5fa866..7382c1d 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -2996,6 +2996,15 @@
         cu = torch.jit.CompilationUnit()._import(r, [])
         self.assertExpected(cu.foo.code)
 
+    def test_import_way_too_new(self):
+        @torch.jit.script
+        def foo(x, y):
+            return 2 * x + y
+
+        r, _ = _jit_python_print(foo)
+        with self.assertRaisesRegex(RuntimeError, "generated from a newer version"):
+            torch.jit.CompilationUnit()._import(r, [], op_version_set=10000)
+
     def test_function_default_values(self):
         outer_var = torch.tensor(20)
         outer_var2 = torch.tensor(30)
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index d19c30e..1508381 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -541,7 +541,6 @@
   OrderedDict<ClassTypePtr, std::string> converted_classes_;
   std::unordered_map<ClassTypePtr, std::vector<ClassTypePtr>> class_to_deps_;
 
-  static const size_t op_version_set = 0;
 };
 
 // ScriptModuleSerializer's methods
@@ -612,7 +611,7 @@
     const std::string& src = fileToSrc.at(filename).str();
 
     std::ostringstream lib_stream;
-    lib_stream << "op_version_set = " << op_version_set << "\n";
+    lib_stream << "op_version_set = " << CURRENT_OP_VERSION_SET << "\n";
     lib_stream << src;
     std::string lib_str = lib_stream.str();
     writer_.writeRecord(filename, lib_str.c_str(), lib_str.size());
@@ -771,7 +770,7 @@
 
   if (module.get_methods().size() > 0) {
     std::ostringstream methods;
-    methods << "op_version_set = " << op_version_set << "\n";
+    methods << "op_version_set = " << CURRENT_OP_VERSION_SET << "\n";
     PythonPrint(
         methods,
         module.class_compilation_unit(),
diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h
index 636fa2e..ad3a7f5 100644
--- a/torch/csrc/jit/export.h
+++ b/torch/csrc/jit/export.h
@@ -19,6 +19,8 @@
 // file contents being the raw tensor data.
 using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
 
+constexpr size_t CURRENT_OP_VERSION_SET = 1;
+
 TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
     const std::map<std::string, at::Tensor>& initializers,
diff --git a/torch/csrc/jit/import_source.cpp b/torch/csrc/jit/import_source.cpp
index 02ea7fc..8809fe8 100644
--- a/torch/csrc/jit/import_source.cpp
+++ b/torch/csrc/jit/import_source.cpp
@@ -3,6 +3,7 @@
 #include <ATen/core/qualified_name.h>
 #include <torch/csrc/jit/script/parser.h>
 #include <torch/csrc/jit/script/resolver.h>
+#include <torch/csrc/jit/export.h>
 
 namespace torch {
 namespace jit {
@@ -163,7 +164,18 @@
         std::make_shared<SourceResolver>(lib_cu_, version_, constant_table_);
   }
 
+  void checkVersionNumber() {
+    // note: this cannot be called in the constructor because it may throw
+    if (version_ > CURRENT_OP_VERSION_SET) {
+      throw ErrorReport(p_.lexer().cur().range)
+          << "Attempting to load a script generated from a newer version of PyTorch. Maximum supported TorchScript version is "
+          << CURRENT_OP_VERSION_SET << " but the script being loaded is version "
+          << version_ << ".";
+    }
+  }
+
   void importLibs(CompilationUnit& owner, const std::string& class_qualifier) {
+    checkVersionNumber();
     auto& L = p_.lexer();
 
     while (L.cur().kind != TK_EOF) {
@@ -192,6 +204,7 @@
   }
 
   void importFunctions(CompilationUnit& cu, const Self& self) {
+    checkVersionNumber();
     parseImportsAndDoCallback();
 
     std::vector<Def> definitions;
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 8c9b684..0c3057e 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -854,9 +854,9 @@
             raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
         return r
 
-    def _import(self, src, constants):
+    def _import(self, src, constants, op_version_set=1):
         """ test import logic for single function, use only for testing """
-        src = "op_version_set = 0\n{}".format(src)
+        src = "op_version_set = {}\n{}".format(op_version_set, src)
         torch._C._jit_import_functions(self._c, src, constants, None)
         return self