"""
This global flag controls whether to assign new tensors to the parameters
instead of changing the existing parameters in-place when converting an `nn.Module`
using the following methods:
1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
4. `module._apply(fn)` (for generic functions applied to `module`)

Default: False
"""
_overwrite_module_params_on_conversion = False
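A minimal pure-Python sketch of the difference the flag makes. `Param`, `TinyModule`, and the `overwrite` argument are hypothetical stand-ins (not PyTorch classes), used so the example runs without `torch`. The real `nn.Module._apply` also converts gradients and buffers, which this sketch omits. With the flag off, the existing parameter object keeps its identity and only its data changes, so an outside reference (for example, one held by an optimizer) still sees the converted values. With the flag on, the key is rebound to a new object and the old reference goes stale.

```python
class Param:
    """Hypothetical stand-in for nn.Parameter: just carries a .data payload."""

    def __init__(self, data):
        self.data = data


class TinyModule:
    """Hypothetical stand-in for nn.Module holding a dict of parameters."""

    def __init__(self, **params):
        self._parameters = {name: Param(value) for name, value in params.items()}

    def _apply(self, fn, overwrite):
        # `overwrite` plays the role of the global flag in this sketch.
        for name, param in list(self._parameters.items()):
            new_data = fn(param.data)
            if overwrite:
                # Flag True: rebind the key to a brand-new parameter object.
                self._parameters[name] = Param(new_data)
            else:
                # Flag False: keep the same object and swap its data in place.
                param.data = new_data
        return self


m = TinyModule(weight=[1.0, 2.0])
held = m._parameters["weight"]  # e.g. a reference an optimizer might keep

m._apply(lambda xs: [x * 2 for x in xs], overwrite=False)
print(held is m._parameters["weight"], held.data)  # True [2.0, 4.0]

m._apply(lambda xs: [x + 1 for x in xs], overwrite=True)
print(held is m._parameters["weight"], held.data)  # False [2.0, 4.0]
```

After the second call, `held` still points at the old object, while the module now holds a new parameter whose data is `[3.0, 5.0]`.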


def set_overwrite_module_params_on_conversion(value):
    """Set whether converting an `nn.Module` assigns new tensors to its
    parameters (True) or changes the existing parameters in-place (False)."""
    global _overwrite_module_params_on_conversion
    _overwrite_module_params_on_conversion = value


def get_overwrite_module_params_on_conversion():
    """Return the current value of the flag (default: False)."""
    return _overwrite_module_params_on_conversion
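Because the flag is module-level global state, a caller may want to change it only for a limited scope. A minimal sketch, assuming a hypothetical context manager named `overwrite_module_params_on_conversion` (not part of the module above). The accessors are re-declared locally so the snippet runs standalone; with PyTorch installed they would presumably be imported from `torch.__future__` instead.

```python
from contextlib import contextmanager

# Local stand-ins mirroring the accessors above (assumption: in PyTorch these
# would be imported from torch.__future__ rather than redefined).
_overwrite_module_params_on_conversion = False


def set_overwrite_module_params_on_conversion(value):
    global _overwrite_module_params_on_conversion
    _overwrite_module_params_on_conversion = value


def get_overwrite_module_params_on_conversion():
    return _overwrite_module_params_on_conversion


@contextmanager
def overwrite_module_params_on_conversion(value=True):
    """Hypothetical helper: set the flag for a block, then restore it."""
    previous = get_overwrite_module_params_on_conversion()
    set_overwrite_module_params_on_conversion(value)
    try:
        yield
    finally:
        # Restore even if the block raises, so the global does not leak.
        set_overwrite_module_params_on_conversion(previous)
```

Usage: `with overwrite_module_params_on_conversion(True): model.to(device)` (where `model` and `device` are placeholders) converts with the flag on and then returns it to its previous value. The `try`/`finally` keeps the restore unconditional, which matters because any later conversion anywhere in the process reads the same global.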