commit | c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6 | [log] [tgz] |
---|---|---|
author | Kim,Won-Joong <wonjoong11@naver.com> | Tue Mar 21 17:46:23 2023 +0000 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Tue Mar 21 17:46:23 2023 +0000 |
tree | 06e9396276db399e414bc4293bde02922582b3d2 | |
parent | a6bbeec2e1df840855dcbcff782ddee015a507bb [diff] |
Update parallel_apply.py for assertion error when len(modules) != len(inputs) (#94671) Print the result why it is wrong. Pull Request resolved: https://github.com/pytorch/pytorch/pull/94671 Approved by: https://github.com/ngimel, https://github.com/kit1980
diff --git a/torch/nn/parallel/parallel_apply.py b/torch/nn/parallel/parallel_apply.py index a114dfd..8f3d220 100644 --- a/torch/nn/parallel/parallel_apply.py +++ b/torch/nn/parallel/parallel_apply.py
@@ -35,7 +35,7 @@ element of :attr:`inputs` can either be a single object as the only argument to a module, or a collection of positional arguments. """ - assert len(modules) == len(inputs) + assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}' if kwargs_tup is not None: assert len(modules) == len(kwargs_tup) else: