commit | 7caeac17183a9aee0ccce4a3470925c6fe7e5007 | [log] [tgz] |
---|---|---|
author | Soumith Chintala <soumith@gmail.com> | Fri Oct 21 06:36:13 2022 +0000 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Fri Oct 21 06:36:16 2022 +0000 |
tree | 626f8c2c1060fde984428762c110445824fbc9d6 | |
parent | 6b59d9b566001cd7036ac06497372eae6238cdd4 [diff] |
[inductor] Fix channels_last conv2d propagation when CuDNN is not found (#87266) Fixes https://github.com/pytorch/torchdynamo/issues/1701 cc @jansel @lezcano @fdrocha @mlazos @voznesenskym @yanboliang Pull Request resolved: https://github.com/pytorch/pytorch/pull/87266 Approved by: https://github.com/anijain2305, https://github.com/jansel, https://github.com/voznesenskym
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index a9b6979..13cf5d7 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py
@@ -3121,13 +3121,16 @@ # CUDA channels_last path depend on cudnn version, see # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvUtils.h. - valid_device = True + valid_cudnn = False if ( - x.get_device() == "cuda" - and torch.backends.cudnn.is_available() - and torch.backends.cudnn.version() < 8302 + torch.backends.cudnn.is_available() + and torch.backends.cudnn.version() >= 7603 ): - valid_device = False + valid_cudnn = True + + valid_device = x.get_device().type == "cpu" or ( + x.get_device().type == "cuda" and valid_cudnn + ) if ( valid_device and len(x.get_size()) == 4