Added validation of mode parameter in AveragedModel (#65921)
Summary:
Discussion: https://github.com/pytorch/pytorch/pull/65495#issuecomment-930460469
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65921
Reviewed By: albanD
Differential Revision: D31310105
Pulled By: prabhat00155
fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py
index a186e43..e87f10e 100644
--- a/torch/optim/swa_utils.py
+++ b/torch/optim/swa_utils.py
@@ -26,8 +26,8 @@
:class:`AveragedModel` parameter, the current value of :attr:`model`
parameter and the number of models already averaged; if None,
equally weighted average is used (default: None)
- mode (str, optional): whether to use parameters or state_dict for update
- (default: parameters)
+ mode (str, optional): whether to use ``'parameters'`` or ``'state_dict'`` for update
+ (default: ``'parameters'``)
Example:
>>> loader, optimizer, model, loss_fn = ...
@@ -98,6 +98,9 @@
return averaged_model_parameter + \
(model_parameter - averaged_model_parameter) / (num_averaged + 1)
self.avg_fn = avg_fn
+ modes = ['parameters', 'state_dict']
+ if mode not in modes:
+ raise ValueError(f'Invalid mode passed, valid values are {", ".join(modes)}.')
self.use_state_dict = mode == 'state_dict'
def forward(self, *args, **kwargs):