Don't raise an error when retrieval of container's source code fails
diff --git a/torch/serialization.py b/torch/serialization.py
index a654a24..1704ee2 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -137,8 +137,14 @@
if obj in serialized_container_types:
return None
serialized_container_types[obj] = True
- source_file = inspect.getsourcefile(obj)
- source = inspect.getsource(obj)
+ source_file = source = None
+ try:
+ source_file = inspect.getsourcefile(obj)
+ source = inspect.getsource(obj)
+ except TypeError:
+ warnings.warn("Couldn't retrieve source code for container of "
+ "type " + obj.__name__ + ". It won't be checked "
+ "for correctness upon loading.")
return (obj, source_file, source)
if torch.is_tensor(obj):
serialized_tensors[obj._cdata] = obj
@@ -299,7 +305,9 @@
def persistent_load(saved_id):
if isinstance(saved_id, tuple):
- _check_container_source(*saved_id)
+ # Ignore containers that don't have any sources saved
+ if all(saved_id[1:]):
+ _check_container_source(*saved_id)
return saved_id[0]
return deserialized_objects[int(saved_id)]