commit | 5b9b816b1743918d2b75b406c2a875b0fc322b80 | [log] [tgz] |
---|---|---|
author | Edward Z. Yang <ezyang@meta.com> | Wed Aug 16 08:48:50 2023 -0700 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Thu Aug 17 00:31:16 2023 +0000 |
tree | a721ce31f191e11656dd72e7ff971e8785139a55 | |
parent | b234b94760c4c95282a2cb4dbd1e8b27fe73c2ff [diff] |
WAR by avoid querying device before env mutation (#107301) We should probably fix https://github.com/pytorch/pytorch/issues/107300 properly but this works around the problem Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/107301 Approved by: https://github.com/bdhirsh, https://github.com/H-Huang, https://github.com/albanD
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 42fa02b..6858e06 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py
@@ -3002,9 +3002,12 @@ f"--diff-branch: current branch is same as {args.diff_branch} branch, what are you diffing?" ) - device_count = torch.cuda.device_count() args.use_distributed = (args.ddp or args.fsdp) and args.only if args.multiprocess: + # NB: Do NOT query device count before CUDA initialization; we're + # going to overwrite CUDA_VISIBLE_DEVICES and this will result in + # https://github.com/pytorch/pytorch/issues/107300 + device_count = torch.cuda.device_count() if device_count <= 1: log.warning( "The use multiprocess flag is set but there are <= 1 devices available."