Attribute serialization improvements (#18188)
Summary:
* adds attributes to `ScriptModule.__getattr__` so they can be accessed in Python after re-importing
* full support for all the possible values for an `int64_t`
* this necessitated a bunch more `pushWhatever` functions, so re-introduced a templated version to cut down on duplicate code
* tests to validate references / value sharing works
* adds `torch.jit.Unpickler` which people can use to de-serialize the pickle files into Python / have a quick reference on how to do this without PyTorch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18188
Differential Revision: D14527490
Pulled By: driazati
fbshipit-source-id: efd15579cc04aa2e28c4b2c9490d82d849dee559
diff --git a/test/test_jit.py b/test/test_jit.py
index 4599c96..3f33683 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -16,7 +16,7 @@
from torch.autograd.function import traceable
from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes
-from torch._six import inf, PY2, builtins
+from torch._six import inf, PY2, builtins, StringIO
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
freeze_rng_state, set_rng_seed, slowTest
@@ -37,7 +37,10 @@
import math
import types
import pickle
+import pickletools
import copy
+import zipfile
+
from common_methods_invocations import method_tests as autograd_method_tests
from common_methods_invocations import create_input, unpack_variables, \
@@ -10488,8 +10491,6 @@
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
def test_attribute_unpickling(self):
- import zipfile
-
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
@@ -10557,6 +10558,69 @@
imported_m = self.getExportImportCopy(m)
self.assertEqual(m(), imported_m())
+ def test_serialization_big_ints(self):
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M, self).__init__()
+ self.int32_max = torch.jit.Attribute(2**31 - 1, int)
+ self.int32_min = torch.jit.Attribute(-2**31, int)
+ self.uint32_max = torch.jit.Attribute(2**32, int)
+
+ self.int64_max = torch.jit.Attribute(2**63 - 1, int)
+ self.int64_min = torch.jit.Attribute(-2**63, int)
+
+ self.tensor = torch.nn.Parameter(torch.ones(2, 2))
+
+ @torch.jit.script_method
+ def forward(self, x):
+ # type: (int) -> (int)
+ return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min)
+
+ m = M()
+ imported = self.getExportImportCopy(m)
+ self.assertEqual(m(10), imported(10))
+
+ self.assertEqual(m.int32_max, imported.int32_max)
+ self.assertEqual(m.int32_min, imported.int32_min)
+ self.assertEqual(m.uint32_max, imported.uint32_max)
+ self.assertEqual(m.int64_max, imported.int64_max)
+ self.assertEqual(m.int64_min, imported.int64_min)
+
+ def test_serialization_sharing(self):
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M, self).__init__()
+ self.list = torch.jit.Attribute([], List[str])
+
+ @torch.jit.script_method
+ def forward(self, key):
+ # type: (str) -> List[str]
+ self.list.append(key)
+ self.list.append(key)
+ self.list.append(key)
+ return self.list
+
+        # strings are memoized, so each string's text should appear only once in the pickle
+ m = M()
+ s1 = "a long string"
+ s2 = "a different, even longer string"
+ self.assertEqual(m(s1), [s1] * 3)
+ self.assertEqual(m(s2), [s1] * 3 + [s2] * 3)
+ with TemporaryFileName() as fname:
+ m.save(fname)
+ archive_name = os.path.basename(os.path.normpath(fname))
+ archive = zipfile.ZipFile(fname, 'r')
+ pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
+
+ out = StringIO()
+ pickletools.dis(pickled_data, out=out)
+ disassembled = out.getvalue()
+
+ FileCheck().check_count(s1, 1, exactly=True) \
+ .check_count("BINGET", 2, exactly=True) \
+ .check_count(s2, 1, exactly=True) \
+ .check_count("BINGET", 2, exactly=True).run(out.getvalue())
+
def test_optional_tuple(self):
def fn(x=None):
# type: (Optional[Tuple[int, int]]) -> Tuple[int, int]
diff --git a/torch/_six.py b/torch/_six.py
index 5eeb18a..b062114 100644
--- a/torch/_six.py
+++ b/torch/_six.py
@@ -137,6 +137,13 @@
elif PY3:
import builtins
+if PY2:
+ import StringIO
+ StringIO = StringIO.StringIO
+elif PY3:
+ import io
+ StringIO = io.StringIO
+
# The codes below is not copied from the six package, so the copyright
# declaration at the beginning does not apply.
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 3e37054..d515014 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -827,9 +827,9 @@
if (enforce_importable_) {
throw script::ErrorReport(node->getSourceLocation())
<< "could not export python function call " << value->name()
- << ". Remove calls to Python functions before export."
- << "Did you forget add @script annotation? "
- << "If this is a modulelist, add it to __constants__.";
+ << ". Remove calls to Python functions before export. "
+ << "Did you forget add @script or @script_method annotation? "
+ << "If this is a nn.ModuleList, add it to __constants__.";
}
stmt << "^" << value->name();
diff --git a/torch/csrc/jit/pickler.cpp b/torch/csrc/jit/pickler.cpp
index 3da3132..df95bc4 100644
--- a/torch/csrc/jit/pickler.cpp
+++ b/torch/csrc/jit/pickler.cpp
@@ -37,18 +37,18 @@
}
void Pickler::start() {
- pushOpCode(OpCode::PROTO);
- pushUint8(2);
+ push<OpCode>(OpCode::PROTO);
+ push<uint8_t>(2);
// All attributes get pushed into a list and their indices saved in the
// module def
- pushOpCode(OpCode::EMPTY_LIST);
- pushOpCode(OpCode::MARK);
+ push<OpCode>(OpCode::EMPTY_LIST);
+ push<OpCode>(OpCode::MARK);
}
void Pickler::finish() {
- pushOpCode(OpCode::APPENDS);
- pushOpCode(OpCode::STOP);
+ push<OpCode>(OpCode::APPENDS);
+ push<OpCode>(OpCode::STOP);
}
void Pickler::addIValue(const IValue& ivalue) {
@@ -70,17 +70,12 @@
} else if (ivalue.isDouble()) {
pushDouble(ivalue);
} else if (ivalue.isInt()) {
- // TODO: use BININT1/BININT2/LONG if possible/necessary
- AT_ASSERT(
- ivalue.toInt() <= std::numeric_limits<int32_t>::max() &&
- ivalue.toInt() >= std::numeric_limits<int32_t>::min());
- pushOpCode(OpCode::BININT);
- pushInt32(ivalue.toInt());
+ pushInt(ivalue);
} else if (ivalue.isBool()) {
if (ivalue.toBool()) {
- pushOpCode(OpCode::NEWTRUE);
+ push<OpCode>(OpCode::NEWTRUE);
} else {
- pushOpCode(OpCode::NEWFALSE);
+ push<OpCode>(OpCode::NEWFALSE);
}
} else if (ivalue.isString()) {
pushMemoizedString(ivalue);
@@ -89,7 +84,7 @@
} else if (ivalue.isGenericDict()) {
pushDict(ivalue);
} else if (ivalue.isNone()) {
- pushOpCode(OpCode::NONE);
+ push<OpCode>(OpCode::NONE);
} else if (ivalue.isIntList()) {
pushIntList(ivalue);
} else {
@@ -113,22 +108,41 @@
return nullptr;
}
+void Pickler::pushInt(const IValue& ivalue) {
+ auto n = ivalue.toInt();
+ if (n >= std::numeric_limits<int8_t>::min() &&
+ n <= std::numeric_limits<int8_t>::max()) {
+ push<OpCode>(OpCode::BININT1);
+ push<int8_t>(n);
+ } else if (
+ n >= std::numeric_limits<int32_t>::min() &&
+ n <= std::numeric_limits<int32_t>::max()) {
+ push<OpCode>(OpCode::BININT);
+ push<int32_t>(n);
+ } else {
+    // Doesn't fit in 32 bits: emit LONG1 with a fixed length of 8 bytes
+ push<OpCode>(OpCode::LONG1);
+ push<uint8_t>(8);
+ push<int64_t>(n);
+ }
+}
+
void Pickler::pushBinGet(uint32_t memo_id) {
if (memo_id <= std::numeric_limits<uint8_t>::max()) {
- pushOpCode(OpCode::BINGET);
- pushUint8(memo_id);
+ push<OpCode>(OpCode::BINGET);
+ push<uint8_t>(memo_id);
} else {
// Memoized too many items, issue a LONG_BINGET instead
- pushOpCode(OpCode::LONG_BINGET);
- pushUint32(memo_id);
+ push<OpCode>(OpCode::LONG_BINGET);
+ push<uint32_t>(memo_id);
}
}
void Pickler::pushMemoizedString(const IValue& ivalue) {
const auto& string = ivalue.toStringRef();
- pushOpCode(OpCode::BINUNICODE);
- pushUint32(string.size());
+ push<OpCode>(OpCode::BINUNICODE);
+ push<uint32_t>(string.size());
pushString(string);
pushMemoization(ivalue);
}
@@ -142,7 +156,7 @@
// Write it to the tensor table
auto memo_entry = memo_.find(&name);
if (memo_entry == memo_.end()) {
- pushOpCode(OpCode::GLOBAL);
+ push<OpCode>(OpCode::GLOBAL);
// Module name + "\n"
pushString(getModuleName());
// Class name + "\n"
@@ -152,8 +166,8 @@
pushBinGet(memo_entry->second);
}
- pushOpCode(OpCode::EMPTY_TUPLE);
- pushOpCode(OpCode::NEWOBJ);
+ push<OpCode>(OpCode::EMPTY_TUPLE);
+ push<OpCode>(OpCode::NEWOBJ);
}
void Pickler::pushTensor(const IValue& ivalue) {
@@ -161,25 +175,25 @@
tensor_table_->push_back(ivalue.toTensor());
auto tensor_id = tensor_table_->size() - 1;
- pushOpCode(OpCode::BININT);
- pushUint32(tensor_id);
+ push<OpCode>(OpCode::BININT);
+ push<uint32_t>(tensor_id);
- pushOpCode(OpCode::BUILD);
+ push<OpCode>(OpCode::BUILD);
}
void Pickler::pushIntList(const IValue& ivalue) {
pushClass(PicklerClass::INTLIST);
- pushOpCode(OpCode::EMPTY_LIST);
+ push<OpCode>(OpCode::EMPTY_LIST);
pushMemoization(ivalue);
- pushOpCode(OpCode::MARK);
+ push<OpCode>(OpCode::MARK);
for (const auto& item : ivalue.toIntListRef()) {
addIValue(item);
}
- pushOpCode(OpCode::APPENDS);
- pushOpCode(OpCode::BUILD);
+ push<OpCode>(OpCode::APPENDS);
+ push<OpCode>(OpCode::BUILD);
}
void Pickler::pushDouble(const IValue& ivalue) {
@@ -187,9 +201,9 @@
AT_ASSERT(sizeof(double) == 8);
char* bytes = reinterpret_cast<char*>(&value);
- pushOpCode(OpCode::BINFLOAT);
+ push<OpCode>(OpCode::BINFLOAT);
for (size_t i = 0; i < 8; ++i) {
- pushUint8(bytes[8 - i - 1]);
+ push<uint8_t>(bytes[8 - i - 1]);
}
}
@@ -213,10 +227,10 @@
void Pickler::pushDict(const IValue& ivalue) {
auto dict = ivalue.toGenericDictRef();
- pushOpCode(OpCode::EMPTY_DICT);
+ push<OpCode>(OpCode::EMPTY_DICT);
pushMemoization(ivalue);
- pushOpCode(OpCode::MARK);
+ push<OpCode>(OpCode::MARK);
// Sort the dict for deterministic keys
std::vector<std::pair<IValue, IValue>> dict_items(dict.begin(), dict.end());
@@ -227,18 +241,18 @@
addIValue(pair.second);
}
- pushOpCode(OpCode::SETITEMS);
+ push<OpCode>(OpCode::SETITEMS);
}
void Pickler::pushMemoization(const void* item) {
AT_ASSERT(item != nullptr);
if (memo_id <= std::numeric_limits<uint8_t>::max()) {
- pushOpCode(OpCode::BINPUT);
- pushUint8(memo_id);
+ push<OpCode>(OpCode::BINPUT);
+ push<uint8_t>(memo_id);
} else {
// Memoized too many items, issue a LONG_BINPUT instead
- pushOpCode(OpCode::LONG_BINPUT);
- pushUint32(memo_id);
+ push<OpCode>(OpCode::LONG_BINPUT);
+ push<uint32_t>(memo_id);
}
memo_[item] = memo_id;
AT_ASSERT(memo_id <= std::numeric_limits<uint32_t>::max());
@@ -251,51 +265,31 @@
void Pickler::pushList(const IValue& ivalue) {
auto list = ivalue.toGenericListRef();
- pushOpCode(OpCode::EMPTY_LIST);
+ push<OpCode>(OpCode::EMPTY_LIST);
pushMemoization(ivalue);
- pushOpCode(OpCode::MARK);
+ push<OpCode>(OpCode::MARK);
for (const auto& item : list) {
addIValue(item);
}
- pushOpCode(OpCode::APPENDS);
+ push<OpCode>(OpCode::APPENDS);
}
void Pickler::pushTuple(const IValue& ivalue) {
// TODO: Small tuple unrolling (e.g. TUPLE3)
- pushOpCode(OpCode::MARK);
+ push<OpCode>(OpCode::MARK);
auto tuple = ivalue.toTuple()->elements();
for (const auto& item : tuple) {
addIValue(item);
}
- pushOpCode(OpCode::TUPLE);
+ push<OpCode>(OpCode::TUPLE);
pushMemoization(ivalue);
}
-void Pickler::pushUint8(uint8_t value) {
- const char* begin = reinterpret_cast<const char*>(&value);
- stack_.insert(stack_.end(), begin, begin + sizeof(uint8_t));
-}
-
-void Pickler::pushOpCode(OpCode value) {
- const char* begin = reinterpret_cast<const char*>(&value);
- stack_.insert(stack_.end(), begin, begin + sizeof(OpCode));
-}
-
-void Pickler::pushUint32(uint32_t value) {
- const char* begin = reinterpret_cast<const char*>(&value);
- stack_.insert(stack_.end(), begin, begin + sizeof(uint32_t));
-}
-
-void Pickler::pushInt32(int32_t value) {
- const char* begin = reinterpret_cast<const char*>(&value);
- stack_.insert(stack_.end(), begin, begin + sizeof(int32_t));
-}
-
std::vector<IValue> Unpickler::parse_ivalue_list() {
run();
AT_ASSERT(stack_.size() == 1);
@@ -367,10 +361,20 @@
// Mark location of the container ivalue in the stack
marks_.push_back(stack_.size());
} break;
+ case OpCode::BININT1: {
+ int8_t value = read<int8_t>();
+ stack_.emplace_back(int64_t(value));
+ } break;
case OpCode::BININT: {
int32_t value = read<int32_t>();
stack_.emplace_back(int64_t(value));
} break;
+ case OpCode::LONG1: {
+          // The matching Pickler only ever emits LONG1 with length 8, so reject others
+ uint8_t length = read<uint8_t>();
+ AT_ASSERT(length == 8);
+ stack_.emplace_back(int64_t(read<int64_t>()));
+ } break;
case OpCode::BINUNICODE: {
uint32_t length = read<uint32_t>();
const char* characters = reinterpret_cast<const char*>(bytes_);
diff --git a/torch/csrc/jit/pickler.h b/torch/csrc/jit/pickler.h
index 2439aa6..ab35a70 100644
--- a/torch/csrc/jit/pickler.h
+++ b/torch/csrc/jit/pickler.h
@@ -112,12 +112,18 @@
void pushTuple(const IValue& ivalue);
void pushDict(const IValue& ivalue);
void pushClass(PicklerClass cls);
+ void pushInt(const IValue& ivalue);
const void* getPointer(const IValue& ivalue);
- void pushUint8(uint8_t value);
- void pushOpCode(OpCode value);
- void pushUint32(uint32_t value);
- void pushInt32(int32_t value);
+ // These convert values to bytes and add them to the stack (NB: since T is to
+ // the left of a '::', its type cannot be deduced by the compiler so one must
+ // explicitly instantiate the template, i.e. push<int>(int) works, push(int)
+ // does not)
+ template<typename T>
+ void push(typename std::common_type<T>::type value) {
+ const char* begin = reinterpret_cast<const char*>(&value);
+ stack_.insert(stack_.end(), begin, begin + sizeof(T));
+ }
// Stack of opcodes/data
std::vector<char> stack_;
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 8175f35..de678e6 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -761,6 +761,7 @@
.def("_set_parameter", &Module::set_parameter)
.def("_get_parameter", &Module::get_parameter)
.def("_get_buffer", &Module::get_buffer)
+ .def("_get_attribute", &Module::get_attribute)
.def("_get_module", &Module::get_module)
.def(
"_get_modules",
@@ -801,6 +802,11 @@
return result;
})
.def(
+ "_has_attribute",
+ [](Module& self, const std::string& name) -> bool {
+ return self.find_attribute(name);
+ })
+ .def(
"_has_parameter",
[](Module& self, const std::string& name) -> bool {
return self.find_parameter(name);
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index 6d4d746..d4417a7 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -480,6 +480,9 @@
autograd::Variable get_buffer(const std::string& name) const {
return autograd::as_variable_ref(attributes.find(name)->slot()->toTensor());
}
+ IValue get_attribute(const std::string& name) const {
+ return *attributes.find(name)->slot();
+ }
// each module owns its method. The reference returned here
// is guarenteed to stay valid until this module has been destroyed
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index ec895cf..2671432 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -10,6 +10,7 @@
from torch._six import raise_from, with_metaclass, get_function_from_type, \
string_classes
from torch._jit_internal import ignore
+from torch.jit._pickle import Unpickler
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
_list_with_default
import torch.testing
@@ -1131,6 +1132,8 @@
return self._get_method(attr)
if attr == 'graph' and self._has_method('forward'):
return self.__getattr__('forward').graph
+ if self._has_attribute(attr):
+ return self._get_attribute(attr)
return Module.__getattr__(self, attr)
def __setattr__(self, attr, value):
diff --git a/torch/jit/_pickle.py b/torch/jit/_pickle.py
new file mode 100644
index 0000000..24e7bf3
--- /dev/null
+++ b/torch/jit/_pickle.py
@@ -0,0 +1,26 @@
+import torch
+import functools
+import pickle
+
+
+class TensorID(object):
+ def __setstate__(self, id):
+ self.id = id
+
+
+class IntList(object):
+ def __setstate__(self, data):
+ self.data = data
+
+
+class Unpickler(pickle.Unpickler):
+ def find_class(self, module, name):
+ if not module == '__main__':
+ return None
+
+ if name == 'TensorID':
+ return TensorID
+ elif name == 'IntList':
+ return IntList
+ elif name == 'LiteralTensor':
+ return LiteralTensor