Propagate map_location arg to torch.jit.load in torch.load (#78733)

Fixes #78331

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78733
Approved by: https://github.com/davidberard98
diff --git a/torch/serialization.py b/torch/serialization.py
index e778817..8262b96 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -718,7 +718,7 @@
                                   " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                                   " silence this warning)", UserWarning)
                     opened_file.seek(orig_position)
-                    return torch.jit.load(opened_file)
+                    return torch.jit.load(opened_file, map_location=map_location)
                 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)