assign var for "not populated" str (#108844)
Minor cleanup: assign the 'not populated' sentinel string to a single variable, since it is referenced in several places in `vmapify_autograd_function`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108844
Approved by: https://github.com/zou3519
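
The change follows a standard Python sentinel idiom: bind the marker string to one name so every "is this populated yet?" check compares against the same definition instead of a repeated literal. A minimal sketch of the idiom outside PyTorch, assuming hypothetical names (`make_counter`, `start`, and `bump` are illustrative, not from the PR):

```python
def make_counter():
    init_val = "not populated"  # single definition of the sentinel
    total = init_val

    def start(n):
        nonlocal total
        total = n

    def bump():
        nonlocal total
        # Checking against the shared name, not a repeated string literal,
        # keeps every assert in sync with the initializer above.
        assert total != init_val, "call start() first"
        total += 1
        return total

    return start, bump

start, bump = make_counter()
start(0)
assert bump() == 1
```

The `nonlocal` closure structure mirrors how `vmapify_autograd_function` threads `out_dims`, `input_shapes`, and `saved_tensors_bdims` between `forward`, `jvp`, and `backward` in the diff below.
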
diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py
index 92fd138..bb628ea 100644
--- a/torch/_functorch/autograd_function.py
+++ b/torch/_functorch/autograd_function.py
@@ -349,9 +349,10 @@
# vmap(vmap( but not completely sure if it is a problem. If we
# assigned those fields to the ctx object, the worry is that they
# get overwritten.
- out_dims = "not populated"
- input_shapes: Any = "not populated"
- saved_tensors_bdims: Any = "not populated"
+ init_val = "not populated"
+ out_dims = init_val
+ input_shapes: Any = init_val
+ saved_tensors_bdims: Any = init_val
def forward(*operands):
nonlocal out_dims
@@ -395,8 +396,8 @@
saved_tensors_bdims = saved_tensors_bdims_
def jvp(ctx, *tangents):
- assert out_dims != "not populated"
- assert saved_tensors_bdims != "not populated"
+ assert out_dims != init_val
+ assert saved_tensors_bdims != init_val
def jvp_no_context(saved_tensors, tangents):
wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
@@ -411,9 +412,9 @@
return result
def backward(ctx, *grad_outputs):
- assert out_dims != "not populated"
- assert input_shapes != "not populated"
- assert saved_tensors_bdims != "not populated"
+ assert out_dims != init_val
+ assert input_shapes != init_val
+ assert saved_tensors_bdims != init_val
def backward_no_context(inputs):
saved_tensors, grad_outputs = inputs
@@ -440,7 +441,7 @@
)
def get_out_dims():
- assert out_dims != "not populated"
+ assert out_dims != init_val
return out_dims
return Generated, get_out_dims
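
As an aside on the design choice: the string sentinel kept by this PR gives a readable value in assertion tracebacks, while a common Python alternative is a unique `object()` compared by identity, which no legitimate value can ever equal. A sketch of that alternative (not what the PR does; `_NOT_POPULATED` is a hypothetical name):

```python
# Alternative sentinel style: a unique object checked with `is`,
# so even a real value equal to the string cannot slip past the assert.
_NOT_POPULATED = object()

out_dims = _NOT_POPULATED

def get_out_dims():
    assert out_dims is not _NOT_POPULATED, "out_dims was never set"
    return out_dims
```
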