Propagate map_location arg to torch.jit.load in torch.load (#78733)
Fixes #78331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78733
Approved by: https://github.com/davidberard98
diff --git a/torch/serialization.py b/torch/serialization.py
index e778817..8262b96 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -718,7 +718,7 @@
" dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
" silence this warning)", UserWarning)
opened_file.seek(orig_position)
- return torch.jit.load(opened_file)
+ return torch.jit.load(opened_file, map_location=map_location)
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)