[adam] Add not torch.jit.is_scripting() as a requirement for switching to fused (#92181)
A "fix" following https://github.com/pytorch/pytorch/pull/90865. Realized that fused is not compatible with torch.jit.is_scripting() when looking at a later line.
Took the opportunity to make the code cleaner/slightly more performant (with the extends) as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92181
Approved by: https://github.com/albanD
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index c250f1f..f24a4b8 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -316,13 +316,16 @@
# and when differentiable=False.
# We still respect when the user inputs False for fused.
if fused is None:
- if not differentiable and all(
- p.is_cuda and torch.is_floating_point(p)
- for p in params + grads + exp_avgs + exp_avg_sqs + max_exp_avg_sqs + state_steps
- ):
- fused = True
- else:
- fused = False
+ all_tensors = []
+ all_tensors.extend(params)
+ all_tensors.extend(grads)
+ all_tensors.extend(exp_avgs)
+ all_tensors.extend(exp_avg_sqs)
+ all_tensors.extend(max_exp_avg_sqs)
+ all_tensors.extend(state_steps)
+ fused = not torch.jit.is_scripting() and not differentiable and all(
+ p.is_cuda and torch.is_floating_point(p) for p in all_tensors
+ )
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")