| # mypy: allow-untyped-defs |
| import torch |
| from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode |
| from torch.overrides import TorchFunctionMode |
| |
| |
class AutogradStateOpsFailSafeguard(TorchFunctionMode):
    """
    Detect grad state ops during exporting the graph and fail the process by
    raising an error, to avoid unexpected behavior. Those grad mode ops could be:
    `torch.no_grad`
    `torch.enable_grad`
    `torch.set_grad_enabled`

    Detection keys on the functions listed in ``_UNSUPPORTED_GRAD_MODE_OPS``
    (``torch._C._set_grad_enabled``), which the context managers above are
    expected to call. Every other torch function is forwarded unchanged.

    Export with predispatch mode is exempted.
    """

    # Grad-state setters that must not flip the global autograd state while a
    # non-pre-dispatch proxy trace is active. Built once rather than per call,
    # since __torch_function__ runs for every torch function under this mode.
    _UNSUPPORTED_GRAD_MODE_OPS = (torch._C._set_grad_enabled,)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        """
        Forward ``func`` unchanged unless it would change grad state mid-trace.

        Args:
            func: the torch function being invoked.
            types: types participating in the override (not used here).
            args: positional arguments for ``func``.
            kwargs: keyword arguments for ``func``; ``None`` means none.

        Returns:
            Whatever ``func(*args, **kwargs)`` returns.

        Raises:
            RuntimeError: if a grad-state setter would change the current
                autograd state while a ``ProxyTorchDispatchMode`` that is not
                in pre-dispatch mode is active.
        """
        kwargs = kwargs or {}
        # Fast path: only the grad-state setter needs the grad-state and
        # dispatch-mode lookups below; everything else is passed straight on.
        if func in self._UNSUPPORTED_GRAD_MODE_OPS:
            assert len(args) == 1
            changed_state = args[0]
            current_state = torch._C.is_grad_enabled()
            # Only enforce while tracing, i.e. while a PROXY dispatch mode is
            # active. Outside tracing, autograd state may be toggled freely.
            mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
            # Pre-dispatch tracing is allowed to use autograd ops (e.g.
            # `torch.no_grad`), so only non-pre-dispatch proxy modes are
            # rejected. Re-setting the state to its current value is allowed.
            if (
                isinstance(mode, ProxyTorchDispatchMode)
                and not mode.pre_dispatch
                and changed_state != current_state
            ):
                raise RuntimeError(
                    f"Encountered autograd state manager op {func} trying to change global autograd state "
                    "while exporting. This is unsafe because we don't capture this op in torch.export "
                    "today, hence we can't reflect the user intention soundly. You can fix this by "
                    "adding a torch.no_grad() context around the export call."
                )
        return func(*args, **kwargs)