[Py3.11] Remove skip logic from vmap and forward_ad (#91825)
Depends on https://github.com/pytorch/pytorch/pull/91805
Fixes https://github.com/pytorch/pytorch/issues/85506
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91825
Approved by: https://github.com/albanD
diff --git a/torch/_functorch/vmap.py b/torch/_functorch/vmap.py
index 0cae1b9..efb0f6e 100644
--- a/torch/_functorch/vmap.py
+++ b/torch/_functorch/vmap.py
@@ -12,7 +12,6 @@
from .pytree_hacks import tree_map_
from functools import partial
import os
-import sys
import itertools
from torch._C._functorch import (
@@ -226,8 +225,7 @@
return
DECOMPOSITIONS_LOADED = True
- if not (os.environ.get("PYTORCH_JIT", "1" if sys.version_info < (3, 11) else "0") == "1" and
- __debug__):
+ if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
return
# use an alternate way to register an operator into the decomposition table
# _register_jit_decomposition doesn't work for some operators, e.g. addr,
diff --git a/torch/autograd/forward_ad.py b/torch/autograd/forward_ad.py
index 5db1041..d702845 100644
--- a/torch/autograd/forward_ad.py
+++ b/torch/autograd/forward_ad.py
@@ -1,6 +1,5 @@
import torch
import os
-import sys
from .grad_mode import _DecoratorContextManager
from collections import namedtuple
@@ -87,9 +86,7 @@
# buffer = z
# return min - torch.log1p(z), buffer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
- # Currently broken for 3.11, see https://github.com/pytorch/pytorch/issues/85506
- if (os.environ.get("PYTORCH_JIT", "1" if sys.version_info < (3, 11) else "0") == "1" and
- __debug__):
+ if os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__:
from torch._decomp import decompositions_for_jvp # noqa: F401
if level is None: