blob: 9ac8406e8f8ea3150eed5fb08843e2c72305c950 [file] [log] [blame]
"""
Global flag controlling how ``nn.Module`` conversion methods treat parameters.

When ``_overwrite_module_params_on_conversion`` is true, converting an
``nn.Module`` assigns new tensors to its parameters rather than modifying the
existing parameters in-place. The conversion methods covered are:

1. ``module.cuda()`` / ``.cpu()`` -- moving ``module`` between devices
2. ``module.float()`` / ``.double()`` / ``.half()`` -- converting ``module``
   to a different dtype
3. ``module.to()`` / ``.type()`` -- changing ``module``'s device or dtype
4. ``module._apply(fn)`` -- applying a generic function to ``module``

Default: False
"""
# Module-private flag; read and written through the get_/set_ accessors.
_overwrite_module_params_on_conversion: bool = False
13
Huy Do12cb2652022-07-22 02:19:50 +000014
def set_overwrite_module_params_on_conversion(value: bool) -> None:
    """Set the ``_overwrite_module_params_on_conversion`` module-level flag.

    When the flag is true, converting an ``nn.Module`` via ``module.cuda()``,
    ``.cpu()``, ``.float()``, ``.double()``, ``.half()``, ``.to()``,
    ``.type()`` or ``module._apply(fn)`` assigns new tensors to the
    parameters instead of changing the existing parameters in-place.

    Args:
        value: The new flag value. It is stored exactly as given (no
            coercion to ``bool``) and is what
            ``get_overwrite_module_params_on_conversion`` returns afterwards.
    """
    # ``global`` is required so the assignment rebinds the module-level flag
    # instead of creating a function-local variable.
    global _overwrite_module_params_on_conversion
    _overwrite_module_params_on_conversion = value
18
Huy Do12cb2652022-07-22 02:19:50 +000019
def get_overwrite_module_params_on_conversion() -> bool:
    """Return the current value of ``_overwrite_module_params_on_conversion``.

    The flag defaults to ``False`` and is changed with
    ``set_overwrite_module_params_on_conversion``. When it is true,
    converting an ``nn.Module`` (``.cuda()``/``.cpu()``, ``.float()``/
    ``.double()``/``.half()``, ``.to()``/``.type()``, ``._apply(fn)``)
    assigns new tensors to the parameters instead of changing the existing
    parameters in-place.

    Returns:
        Whatever value was last stored by the setter (``False`` by default).
    """
    return _overwrite_module_params_on_conversion