[fx] python_code(verbose=True): show size/strides for all tensors (#132192)

`python_code(verbose=True)` (or `print_readable()`) generates a string containing the code that represents the fx graph, with extra annotations indicating each tensor's sizes and strides. Currently, it only shows sizes/strides for FakeTensors provided in the node metadata. For tensor subclasses like NestedTensor, the outer tensor stored in the node metadata is not a FakeTensor (only the inner tensors are fake), so no annotation is emitted. This PR relaxes the conditional to show sizes/strides for all tensors, not just FakeTensors.
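
For example, one way to see the verbose printout directly (rather than through `TORCH_LOGS`) is to call `print_readable()` from a custom `torch.compile` backend. A minimal sketch, where `inspect_backend` and `fn` are illustrative names, not part of this PR:
```python
import torch

def inspect_backend(gm, example_inputs):
    # Dynamo hands the captured GraphModule to the backend; its nodes carry
    # FakeTensor metadata, so the readable printout includes the size/stride
    # annotations this PR extends to all tensors.
    gm.print_readable()
    return gm.forward

@torch.compile(backend=inspect_backend)
def fn(x):
    return x.cos()

fn(torch.randn(4, 3))
```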

Testing: I ran the test script below with `TORCH_LOGS=+dynamo` and found the graph shown in the logs excerpt; the input nested tensor now has sizes and strides associated with it. I also stacked a diff on top of this one that forces the readable graph to be generated whenever PT2 is in use in tests, which should surface any issues: https://github.com/pytorch/pytorch/pull/132195 shows no new failures beyond preexisting ones.

test script:
```python
import torch

def fn(x):
    return x.cos()

nt = torch.nested.nested_tensor_from_jagged(
    torch.randn(10, 10),
    torch.tensor([0, 1, 3, 6, 10]),
)

torch.compile(fn)(nt)
```

logs excerpt (an annotation like `f32[4, zf1, 10][10*zf1, 10, 1]cpu` encodes dtype, sizes, strides, and device):
```
[0/0] [__graph_code] TRACED GRAPH
[0/0] [__graph_code]  ===== __compiled_fn_1 =====
[0/0] [__graph_code]  /data/users/dberard/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.M

[0/0] [__graph_code]     def forward(self, L_x_: "f32[4, zf1, 10][10*zf1, 10, 1]cpu", zf1: "Sym(zf1)"):
[0/0] [__graph_code]         l_x_ = L_x_
[0/0] [__graph_code]
[0/0] [__graph_code]          # File: /data/users/dberard/scripts/nt_print_graph.py:4 in fn, code: return x.c

[0/0] [__graph_code]         cos: "f32[4, zf1, 10][10*zf1, 10, 1]cpu" = l_x_.cos();  l_x_ = None
[0/0] [__graph_code]         return (cos,)
```
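
The size/stride annotations are drawn from the tensor metadata attached to each fx node. As a rough sketch of inspecting that metadata directly (the exact meta keys depend on which tracer produced the graph; `dump_node_meta` is an illustrative helper, not part of this PR):
```python
import torch

def dump_node_meta(gm: torch.fx.GraphModule) -> None:
    # Print the shape/stride/dtype of the tensor metadata each node carries.
    # Dynamo-captured graphs typically store it under "example_value";
    # AOT autograd / export graphs use "val".
    for node in gm.graph.nodes:
        val = node.meta.get("example_value", node.meta.get("val"))
        if isinstance(val, torch.Tensor):
            print(node.name, tuple(val.shape), val.stride(), val.dtype)
```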

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132192
Approved by: https://github.com/Chillee
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 1b911bb..20b31f0 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -2837,7 +2837,7 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
 
-        child_1 = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
+        child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
 
         _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,));  _jvp_treespec_compare = None
 
@@ -2847,68 +2847,68 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        child_2 = torch._make_dual(l_x_, child_1, level = 0);  child_1 = None
+        child_2: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0);  child_1 = None
 
-        _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
+        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
 
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        diff_primals = torch._C._functorch._wrap_for_grad(child_2, 3);  child_2 = None
+        diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3);  child_2 = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        o = torch.sin(diff_primals)
+        o: "f32[4, 3]" = torch.sin(diff_primals)
 
-        results = torch._C._functorch._unwrap_for_grad(o, 3)
+        results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 3)
 
         _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
         _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
 
-        tensor_1 = torch.tensor((12,))
-        cumsum_1 = tensor_1.cumsum(dim = 0);  tensor_1 = None
-        getitem_1 = cumsum_1[slice(None, -1, None)];  cumsum_1 = None
-        neg_1 = getitem_1.neg();  getitem_1 = None
+        tensor_1: "i64[1]" = torch.tensor((12,))
+        cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0);  tensor_1 = None
+        getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)];  cumsum_1 = None
+        neg_1: "i64[0]" = getitem_1.neg();  getitem_1 = None
         unbind_1 = neg_1.unbind();  neg_1 = unbind_1 = None
 
-        chunk_1 = results.new_zeros(12, 12);  results = None
+        chunk_1: "f32[12, 12]" = results.new_zeros(12, 12);  results = None
 
-        diagonal_1 = chunk_1.diagonal(0)
-        fill__1 = diagonal_1.fill_(1);  diagonal_1 = fill__1 = None
+        diagonal_1: "f32[12]" = chunk_1.diagonal(0)
+        fill__1: "f32[12]" = diagonal_1.fill_(1);  diagonal_1 = fill__1 = None
 
-        basis = chunk_1.view(12, 4, 3);  chunk_1 = None
+        basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3);  chunk_1 = None
 
         lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions_1 = None
 
         _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting_1 = None
 
-        _add_batch_dim_1 = torch._C._functorch._add_batch_dim(basis, 0, 3);  basis = None
+        _add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3);  basis = None
 
         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1);  _vjp_treespec_compare = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim_1 = None
-        batched_outputs = _autograd_grad[0];  _autograd_grad = None
+        batched_outputs: "f32[4, 3]" = _autograd_grad[0];  _autograd_grad = None
 
-        chunked_result = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0);  batched_outputs = None
+        chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0);  batched_outputs = None
 
         _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
 
         split = chunked_result.split((12,), dim = 0);  chunked_result = None
-        split_1 = split[0];  split = None
+        split_1: "f32[12, 4, 3]" = split[0];  split = None
 
-        output_input = split_1.view((4, 3, 4, 3));  split_1 = None
+        output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3));  split_1 = None
 
         _unpack_dual = torch._unpack_dual(output_input, level = 0);  output_input = None
-        primal = _unpack_dual[0]
-        dual = _unpack_dual[1];  _unpack_dual = None
+        primal: "f32[4, 3, 4, 3]" = _unpack_dual[0]
+        dual: "f32[4, 3, 4, 3]" = _unpack_dual[1];  _unpack_dual = None
 
         primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = primals_out_unflatten = None
 
-        tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
+        tangents_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
 
         _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
         _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
@@ -2969,7 +2969,7 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
 
-        child_1 = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
+        child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
 
         _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,));  _jvp_treespec_compare = None
 
@@ -2979,67 +2979,67 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        child_3 = torch._make_dual(l_y_, child_1, level = 0);  child_1 = None
+        child_3: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0);  child_1 = None
 
-        child_2 = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = None
-        _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = _wrap_for_grad_1 = None
+        child_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = None
+        _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = _wrap_for_grad_1 = None
 
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3);  child_2 = None
-        child_4 = torch._C._functorch._wrap_for_grad(child_3, 3);  child_3 = None
+        _wrap_for_grad_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3);  child_2 = None
+        child_4: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(child_3, 3);  child_3 = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        o = _wrap_for_grad_2.sin();  _wrap_for_grad_2 = None
+        o: "f32[4, 3]" = _wrap_for_grad_2.sin();  _wrap_for_grad_2 = None
 
-        results = torch._C._functorch._unwrap_for_grad(o, 3)
+        results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 3)
 
         _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
         _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
 
-        tensor_1 = torch.tensor((12,))
-        cumsum_1 = tensor_1.cumsum(dim = 0);  tensor_1 = None
-        getitem_1 = cumsum_1[slice(None, -1, None)];  cumsum_1 = None
-        neg_1 = getitem_1.neg();  getitem_1 = None
+        tensor_1: "i64[1]" = torch.tensor((12,))
+        cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0);  tensor_1 = None
+        getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)];  cumsum_1 = None
+        neg_1: "i64[0]" = getitem_1.neg();  getitem_1 = None
         unbind_1 = neg_1.unbind();  neg_1 = unbind_1 = None
 
-        chunk_1 = results.new_zeros(12, 12);  results = None
+        chunk_1: "f32[12, 12]" = results.new_zeros(12, 12);  results = None
 
-        diagonal_1 = chunk_1.diagonal(0)
-        fill__1 = diagonal_1.fill_(1);  diagonal_1 = fill__1 = None
+        diagonal_1: "f32[12]" = chunk_1.diagonal(0)
+        fill__1: "f32[12]" = diagonal_1.fill_(1);  diagonal_1 = fill__1 = None
 
-        basis = chunk_1.view(12, 4, 3);  chunk_1 = None
+        basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3);  chunk_1 = None
 
         lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions_1 = None
 
         _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting_1 = None
 
-        _add_batch_dim_1 = torch._C._functorch._add_batch_dim(basis, 0, 3);  basis = None
+        _add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3);  basis = None
 
         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1);  _vjp_treespec_compare = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True);  o = child_4 = _add_batch_dim_1 = None
-        child_5 = _autograd_grad[0];  _autograd_grad = None
+        child_5: "f32[3, 4]" = _autograd_grad[0];  _autograd_grad = None
 
-        child_6 = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0);  child_5 = None
+        child_6: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0);  child_5 = None
 
         _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
 
         split = child_6.split((12,), dim = 0);  child_6 = None
-        split_1 = split[0];  split = None
+        split_1: "f32[12, 3, 4]" = split[0];  split = None
 
-        child_7 = split_1.view((4, 3, 3, 4));  split_1 = None
+        child_7: "f32[4, 3, 3, 4]" = split_1.view((4, 3, 3, 4));  split_1 = None
 
         _unpack_dual = torch._unpack_dual(child_7, level = 0);  child_7 = None
-        primal = _unpack_dual[0];  _unpack_dual = None
+        primal: "f32[4, 3, 3, 4]" = _unpack_dual[0];  _unpack_dual = None
 
-        tangent = torch.zeros_like(primal)
+        tangent: "f32[4, 3, 3, 4]" = torch.zeros_like(primal)
 
         child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = child_8 = None
 
@@ -3114,15 +3114,15 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        diff_primals = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        o = torch.sin(diff_primals)
+        o: "f32[4, 3]" = torch.sin(diff_primals)
 
         results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 1)
 
@@ -3146,12 +3146,12 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(basis, 0, 1);  basis = None
+        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 1);  basis = None
 
         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim);  _vjp_treespec_compare = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim = None
-        batched_outputs = _autograd_grad[0];  _autograd_grad = None
+        batched_outputs: "f32[4, 3]" = _autograd_grad[0];  _autograd_grad = None
 
         chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None
 
@@ -3193,16 +3193,16 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = _wrap_for_grad = None
-        diff_primals = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
+        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = _wrap_for_grad = None
+        diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        o = diff_primals.sin()
+        o: "f32[3, 4]" = diff_primals.sin()
 
         results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(o, 1)
 
@@ -3226,12 +3226,12 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(basis, 0, 1);  basis = None
+        _add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1);  basis = None
 
         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim);  _vjp_treespec_compare = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim = None
-        batched_outputs = _autograd_grad[0];  _autograd_grad = None
+        batched_outputs: "f32[3, 4]" = _autograd_grad[0];  _autograd_grad = None
 
         chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None
 
@@ -3273,16 +3273,16 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        aux = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
-        diff_primals = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
+        aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        o = diff_primals.sin()
+        o: "f32[3, 4]" = diff_primals.sin()
 
         aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1);  aux = None
 
@@ -3308,12 +3308,12 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(basis, 0, 1);  basis = None
+        _add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1);  basis = None
 
         _vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim);  _vjp_treespec_compare = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True);  o = diff_primals = _add_batch_dim = None
-        batched_outputs = _autograd_grad[0];  _autograd_grad = None
+        batched_outputs: "f32[3, 4]" = _autograd_grad[0];  _autograd_grad = None
 
         chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0);  batched_outputs = None
 
@@ -3382,16 +3382,16 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        child = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        child_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  child_1 = None
+        child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  child_1 = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        sin = child.sin();  child = None
-        o = sin.sum();  sin = None
+        sin: "f32[5]" = child.sin();  child = None
+        o: "f32[]" = sin.sum();  sin = None
 
         results: "f32[]" = torch._C._functorch._unwrap_for_grad(o, 1);  o = None
 
@@ -3430,16 +3430,16 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        child = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        child_3 = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
+        child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        child_1 = child.sin()
-        child_2 = child.cos();  child = None
+        child_1: "f32[5]" = child.sin()
+        child_2: "f32[5]" = child.cos();  child = None
 
         _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
         _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
@@ -3484,16 +3484,16 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        child = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        child_3 = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
+        child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        child_1 = child.sin()
-        child_2 = child.cos();  child = None
+        child_1: "f32[5]" = child.sin()
+        child_2: "f32[5]" = child.cos();  child = None
 
         _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
         _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
@@ -3540,16 +3540,16 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        child = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        child_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  child_1 = None
+        child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  child_1 = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        sin = child.sin()
-        o = sin.sum();  sin = None
+        sin: "f32[5]" = child.sin()
+        o: "f32[]" = sin.sum();  sin = None
 
         aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1);  child = aux = None
 
@@ -3781,19 +3781,19 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        sin = diff_args.sin()
-        output = sin.sum();  sin = None
+        sin: "f32[3, 3, 3]" = diff_args.sin()
+        output: "f32[]" = sin.sum();  sin = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
-        grad_input = _autograd_grad[0];  _autograd_grad = None
+        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
 
         grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
 
@@ -3848,20 +3848,20 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        sin = diff_args.sin()
-        add = sin + 3;  sin = None
-        output = add.sum();  add = None
+        sin: "f32[3, 3, 3]" = diff_args.sin()
+        add: "f32[3, 3, 3]" = sin + 3;  sin = None
+        output: "f32[]" = add.sum();  add = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
-        grad_input = _autograd_grad[0];  _autograd_grad = None
+        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
 
         grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
 
@@ -3905,20 +3905,20 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        sin = diff_args.sin()
-        add = sin + y;  sin = None
-        output = add.sum();  add = None
+        sin: "f32[3, 3, 3]" = diff_args.sin()
+        add: "f32[3, 3, 3]" = sin + y;  sin = None
+        output: "f32[]" = add.sum();  add = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
-        grad_input = _autograd_grad[0];  _autograd_grad = None
+        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
 
         grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
 
@@ -3962,20 +3962,20 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        sin = diff_args.sin()
-        add = sin + 3.14;  sin = None
-        output = add.sum();  add = None
+        sin: "f32[3, 3, 3]" = diff_args.sin()
+        add: "f32[3, 3, 3]" = sin + 3.14;  sin = None
+        output: "f32[]" = add.sum();  add = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
-        grad_input = _autograd_grad[0];  _autograd_grad = None
+        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
 
         grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
 
@@ -4016,21 +4016,21 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        sin = diff_args.sin()
-        add = sin + 3.14;  sin = None
-        output = add.sum();  add = None
-        aux = diff_args.cos()
+        sin: "f32[3, 3, 3]" = diff_args.sin()
+        add: "f32[3, 3, 3]" = sin + 3.14;  sin = None
+        output: "f32[]" = add.sum();  add = None
+        aux: "f32[3, 3, 3]" = diff_args.cos()
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
-        grad_input = _autograd_grad[0];  _autograd_grad = None
+        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
 
         grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
 
@@ -4073,22 +4073,22 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
-        _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
+        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        _wrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        sin = diff_args.sin()
-        add = sin + _wrap_for_grad_1;  sin = _wrap_for_grad_1 = None
-        output = add.sum();  add = None
-        aux = diff_args.cos()
+        sin: "f32[3, 3, 3]" = diff_args.sin()
+        add: "f32[3, 3, 3]" = sin + _wrap_for_grad_1;  sin = _wrap_for_grad_1 = None
+        output: "f32[]" = add.sum();  add = None
+        aux: "f32[3, 3, 3]" = diff_args.cos()
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
-        grad_input = _autograd_grad[0];  _autograd_grad = None
+        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
 
         grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
 
@@ -4142,28 +4142,28 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        child = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
-        child_1 = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
+        child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
         set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed_2 = None
 
-        _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1);  _set_tensor_requires_grad_1 = None
+        _set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1);  _set_tensor_requires_grad_1 = None
 
         set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_3 = None
 
-        sin = child.sin()
-        add = sin + child_1;  sin = None
-        output = add.sum();  add = None
-        aux = child.cos()
+        sin: "f32[3, 3, 3]" = child.sin()
+        add: "f32[3, 3, 3]" = sin + child_1;  sin = None
+        output: "f32[]" = add.sum();  add = None
+        aux: "f32[3, 3, 3]" = child.cos()
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True);  child = child_1 = None
-        child_2 = _autograd_grad[0]
-        child_3 = _autograd_grad[1];  _autograd_grad = None
+        child_2: "f32[3, 3, 3]" = _autograd_grad[0]
+        child_3: "f32[3, 3, 3]" = _autograd_grad[1];  _autograd_grad = None
 
         _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1);  child_2 = None
         _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1);  child_3 = None
@@ -4188,28 +4188,28 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        child = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
-        child_1 = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
+        child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1);  l_y_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
         set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed_2 = None
 
-        _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1);  _set_tensor_requires_grad_1 = None
+        _set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1);  _set_tensor_requires_grad_1 = None
 
         set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_3 = None
 
-        sin = child.sin()
-        add = sin + child_1;  sin = None
-        output = add.sum();  add = None
-        aux = child.cos()
+        sin: "f32[3, 3, 3]" = child.sin()
+        add: "f32[3, 3, 3]" = sin + child_1;  sin = None
+        output: "f32[]" = add.sum();  add = None
+        aux: "f32[3, 3, 3]" = child.cos()
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True);  child = child_1 = None
-        child_2 = _autograd_grad[0]
-        child_3 = _autograd_grad[1];  _autograd_grad = None
+        child_2: "f32[3, 3, 3]" = _autograd_grad[0]
+        child_3: "f32[3, 3, 3]" = _autograd_grad[1];  _autograd_grad = None
 
         _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1);  child_2 = None
         _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1);  child_3 = None
@@ -4250,39 +4250,39 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        diff_args: "f32[]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
         _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable_1 = None
         _grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting_1 = None
 
-        diff_args_1 = torch._C._functorch._wrap_for_grad(diff_args, 2)
+        diff_args_1: "f32[]" = torch._C._functorch._wrap_for_grad(diff_args, 2)
 
         set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed_2 = None
 
-        _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1);  _set_tensor_requires_grad_1 = None
+        _set_tensor_requires_grad_1: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1);  _set_tensor_requires_grad_1 = None
 
         set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_3 = None
 
-        sin = diff_args_1.sin()
-        output = sin.sum();  sin = None
+        sin: "f32[]" = diff_args_1.sin()
+        output: "f32[]" = sin.sum();  sin = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args_1], create_graph = True);  diff_args_1 = None
-        grad_input = _autograd_grad[0];  _autograd_grad = None
+        grad_input: "f32[]" = _autograd_grad[0];  _autograd_grad = None
 
-        grad_input_1 = torch._C._functorch._unwrap_for_grad(grad_input, 2);  grad_input = None
+        grad_input_1: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input, 2);  grad_input = None
 
-        output_1 = torch._C._functorch._unwrap_for_grad(output, 2);  output = output_1 = None
+        output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 2);  output = output_1 = None
 
         _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting();  _grad_decrement_nesting = None
         _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable_2 = None
 
         _autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((grad_input_1,), [diff_args], create_graph = True);  diff_args = None
-        grad_input_2 = _autograd_grad_1[0];  _autograd_grad_1 = None
+        grad_input_2: "f32[]" = _autograd_grad_1[0];  _autograd_grad_1 = None
 
         grad_input_3: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_2, 1);  grad_input_2 = None
 
@@ -4375,20 +4375,20 @@
         _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case.");  _saved_tensors_hooks_disable = None
         _grad_increment_nesting = torch._C._functorch._grad_increment_nesting();  _grad_increment_nesting = None
 
-        diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
+        diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1);  l_x_ = None
 
         set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True);  set_inplace_requires_grad_allowed = None
 
-        _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
+        _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args);  _set_tensor_requires_grad = None
 
         set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False);  set_inplace_requires_grad_allowed_1 = None
 
-        sin = diff_args.sin()
-        sum_1 = sin.sum();  sin = None
-        output = sum_1 + 3.0;  sum_1 = None
+        sin: "f32[3, 3, 3]" = diff_args.sin()
+        sum_1: "f32[]" = sin.sum();  sin = None
+        output: "f32[]" = sum_1 + 3.0;  sum_1 = None
 
         _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True);  diff_args = None
-        grad_input = _autograd_grad[0];  _autograd_grad = None
+        grad_input: "f32[3, 3, 3]" = _autograd_grad[0];  _autograd_grad = None
 
         grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1);  grad_input = None
 
@@ -4478,7 +4478,7 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
 
-        child_1 = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
+        child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
 
         _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,));  _jvp_treespec_compare = None
 
@@ -4488,19 +4488,19 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        _make_dual = torch._make_dual(l_x_, child_1, level = 0);  child_1 = None
+        _make_dual: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0);  child_1 = None
 
-        _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
+        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
 
-        result_duals = torch.sin(_make_dual);  _make_dual = None
+        result_duals: "f32[4, 3]" = torch.sin(_make_dual);  _make_dual = None
 
         _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
-        primal = _unpack_dual[0]
-        dual = _unpack_dual[1];  _unpack_dual = None
+        primal: "f32[4, 3]" = _unpack_dual[0]
+        dual: "f32[4, 3]" = _unpack_dual[1];  _unpack_dual = None
 
         primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = primals_out_unflatten = None
 
-        tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
+        tangents_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
 
         _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
         _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
@@ -4561,7 +4561,7 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
 
-        child_1 = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
+        child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
 
         _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,));  _jvp_treespec_compare = None
 
@@ -4571,20 +4571,20 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        _make_dual = torch._make_dual(l_y_, child_1, level = 0);  child_1 = None
+        _make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0);  child_1 = None
 
-        _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
-        _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = _wrap_for_grad_1 = None
+        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
+        _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = _wrap_for_grad_1 = None
 
-        result_duals = _make_dual.sin();  _make_dual = None
+        result_duals: "f32[3, 4]" = _make_dual.sin();  _make_dual = None
 
         _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
-        primal = _unpack_dual[0]
-        dual = _unpack_dual[1];  _unpack_dual = None
+        primal: "f32[3, 4]" = _unpack_dual[0]
+        dual: "f32[3, 4]" = _unpack_dual[1];  _unpack_dual = None
 
         primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = primals_out_unflatten = None
 
-        tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
+        tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
 
         _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
         _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
@@ -4645,7 +4645,7 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error');  _vmap_increment_nesting = None
 
-        child_1 = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
+        child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
 
         _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,));  _jvp_treespec_compare = None
 
@@ -4655,22 +4655,22 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        _make_dual = torch._make_dual(l_y_, child_1, level = 0);  child_1 = None
+        _make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0);  child_1 = None
 
-        aux = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = None
-        _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = _wrap_for_grad_1 = None
+        aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = None
+        _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = _wrap_for_grad_1 = None
 
-        result_duals = _make_dual.sin();  _make_dual = None
+        result_duals: "f32[3, 4]" = _make_dual.sin();  _make_dual = None
 
         aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 2);  aux = None
 
         _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
-        primal = _unpack_dual[0]
-        dual = _unpack_dual[1];  _unpack_dual = None
+        primal: "f32[3, 4]" = _unpack_dual[0]
+        dual: "f32[3, 4]" = _unpack_dual[1];  _unpack_dual = None
 
         primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = primals_out_unflatten = None
 
-        tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
+        tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
 
         _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
         _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
@@ -4734,7 +4734,7 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'same');  _vmap_increment_nesting = None
 
-        child_1 = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
+        child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1);  child = None
 
         _jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,));  _jvp_treespec_compare = None
 
@@ -4744,27 +4744,27 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        child_3 = torch._make_dual(l_x_, child_1, level = 0);  child_1 = None
+        child_3: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0);  child_1 = None
 
-        _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
-        _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = None
+        _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2);  l_x_ = _wrap_for_grad = None
+        _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2);  l_y_ = None
 
-        child_2 = _wrap_for_grad_1.sin();  _wrap_for_grad_1 = None
+        child_2: "f32[3, 4]" = _wrap_for_grad_1.sin();  _wrap_for_grad_1 = None
 
         _unpack_dual = torch._unpack_dual(child_2, level = 0);  child_2 = None
-        primal = _unpack_dual[0];  _unpack_dual = None
+        primal: "f32[3, 4]" = _unpack_dual[0];  _unpack_dual = None
 
-        tangent = torch.zeros_like(primal)
+        tangent: "f32[3, 4]" = torch.zeros_like(primal)
 
         _unpack_dual_1 = torch._unpack_dual(child_3, level = 0);  child_3 = None
-        primal_1 = _unpack_dual_1[0]
-        dual = _unpack_dual_1[1];  _unpack_dual_1 = None
+        primal_1: "f32[4, 3]" = _unpack_dual_1[0]
+        dual: "f32[4, 3]" = _unpack_dual_1[1];  _unpack_dual_1 = None
 
         child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = child_4 = None
         child_5: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 2);  primal_1 = child_5 = None
 
         child_6: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2);  tangent = None
-        child_7 = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
+        child_7: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
 
         _exit_dual_level = torch._C._exit_dual_level(0);  _exit_dual_level = None
         _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_1 = None
@@ -4850,14 +4850,14 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        _make_dual = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
+        _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
 
-        sin = _make_dual.sin();  _make_dual = None
-        result_duals = sin.sum();  sin = None
+        sin: "f32[3, 3]" = _make_dual.sin();  _make_dual = None
+        result_duals: "f32[]" = sin.sum();  sin = None
 
         _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
-        primal = _unpack_dual[0]
-        dual = _unpack_dual[1];  _unpack_dual = None
+        primal: "f32[]" = _unpack_dual[0]
+        dual: "f32[]" = _unpack_dual[1];  _unpack_dual = None
 
         primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1);  primal = None
 
@@ -4904,16 +4904,16 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        aux = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
+        aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
 
-        sin = aux.sin()
-        result_duals = sin.sum();  sin = None
+        sin: "f32[3, 3]" = aux.sin()
+        result_duals: "f32[]" = sin.sum();  sin = None
 
         aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1);  aux = None
 
         _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
-        primal = _unpack_dual[0]
-        dual = _unpack_dual[1];  _unpack_dual = None
+        primal: "f32[]" = _unpack_dual[0]
+        dual: "f32[]" = _unpack_dual[1];  _unpack_dual = None
 
         primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1);  primal = None
 
@@ -4962,22 +4962,22 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        aux = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = None
+        aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = None
 
         _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions_1 = None
 
-        _make_dual_1 = torch._make_dual(l_y_, l_v_, level = 0);  l_y_ = l_v_ = None
+        _make_dual_1: "f32[3, 3]" = torch._make_dual(l_y_, l_v_, level = 0);  l_y_ = l_v_ = None
 
-        sin = aux.sin()
-        sum_1 = sin.sum();  sin = None
-        cos = _make_dual_1.cos();  _make_dual_1 = None
-        result_duals = sum_1 + cos;  sum_1 = cos = None
+        sin: "f32[3, 3]" = aux.sin()
+        sum_1: "f32[]" = sin.sum();  sin = None
+        cos: "f32[3, 3]" = _make_dual_1.cos();  _make_dual_1 = None
+        result_duals: "f32[3, 3]" = sum_1 + cos;  sum_1 = cos = None
 
         aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1);  aux = None
 
         _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
-        primal = _unpack_dual[0]
-        dual = _unpack_dual[1];  _unpack_dual = None
+        primal: "f32[3, 3]" = _unpack_dual[0]
+        dual: "f32[3, 3]" = _unpack_dual[1];  _unpack_dual = None
 
         primals_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 1);  primal = None
 
@@ -5027,14 +5027,14 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        _make_dual = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
+        _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
 
-        sin = _make_dual.sin();  _make_dual = None
-        result_duals = sin.sum();  sin = None
+        sin: "f32[3, 3]" = _make_dual.sin();  _make_dual = None
+        result_duals: "f32[]" = sin.sum();  sin = None
 
         _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
-        primal = _unpack_dual[0]
-        dual = _unpack_dual[1];  _unpack_dual = None
+        primal: "f32[]" = _unpack_dual[0]
+        dual: "f32[]" = _unpack_dual[1];  _unpack_dual = None
 
         primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1);  primal = None
 
@@ -5098,14 +5098,14 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        _make_dual = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
+        _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0);  l_x_ = l_v_ = None
 
-        sin = _make_dual.sin();  _make_dual = None
-        result_duals = sin.sum();  sin = None
+        sin: "f32[3, 3]" = _make_dual.sin();  _make_dual = None
+        result_duals: "f32[]" = sin.sum();  sin = None
 
         _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
-        primal = _unpack_dual[0]
-        dual = _unpack_dual[1];  _unpack_dual = None
+        primal: "f32[]" = _unpack_dual[0]
+        dual: "f32[]" = _unpack_dual[1];  _unpack_dual = None
 
         primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1);  primal = None
 
@@ -5171,7 +5171,7 @@
 
         _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions = None
 
-        child = torch._make_dual(l_x_, l_x_, level = 0);  l_x_ = None
+        child: "f32[3, 3, 3]" = torch._make_dual(l_x_, l_x_, level = 0);  l_x_ = None
 
         _jvp_treespec_compare_1 = torch._functorch.eager_transforms._jvp_treespec_compare((child,), (child,));  _jvp_treespec_compare_1 = None
 
@@ -5180,27 +5180,27 @@
 
         _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions();  _maybe_load_decompositions_1 = None
 
-        _make_dual_1 = torch._make_dual(child, child, level = 0);  child = None
+        _make_dual_1: "f32[3, 3, 3]" = torch._make_dual(child, child, level = 0);  child = None
 
-        result_duals = torch.sin(_make_dual_1);  _make_dual_1 = None
+        result_duals: "f32[3, 3, 3]" = torch.sin(_make_dual_1);  _make_dual_1 = None
 
         _unpack_dual = torch._unpack_dual(result_duals, level = 0);  result_duals = None
-        primal = _unpack_dual[0]
-        dual = _unpack_dual[1];  _unpack_dual = None
+        primal: "f32[3, 3, 3]" = _unpack_dual[0]
+        dual: "f32[3, 3, 3]" = _unpack_dual[1];  _unpack_dual = None
 
-        primals_out_unflatten = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = None
+        primals_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2);  primal = None
 
-        tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
+        tangents_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2);  dual = None
 
         _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True);  _set_fwd_grad_enabled_2 = None
         _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting();  _jvp_decrement_nesting = None
 
         _unpack_dual_1 = torch._unpack_dual(primals_out_unflatten, level = 0);  primals_out_unflatten = None
-        primal_1 = _unpack_dual_1[0]
-        dual_1 = _unpack_dual_1[1];  _unpack_dual_1 = None
+        primal_1: "f32[3, 3, 3]" = _unpack_dual_1[0]
+        dual_1: "f32[3, 3, 3]" = _unpack_dual_1[1];  _unpack_dual_1 = None
         _unpack_dual_2 = torch._unpack_dual(tangents_out_unflatten, level = 0);  tangents_out_unflatten = None
-        primal_2 = _unpack_dual_2[0]
-        dual_2 = _unpack_dual_2[1];  _unpack_dual_2 = None
+        primal_2: "f32[3, 3, 3]" = _unpack_dual_2[0]
+        dual_2: "f32[3, 3, 3]" = _unpack_dual_2[1];  _unpack_dual_2 = None
 
         _unwrap_for_grad_2: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 1);  primal_1 = None
         _unwrap_for_grad_3: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_2, 1);  primal_2 = None
@@ -5516,11 +5516,11 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
+        _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
 
-        sum_1 = _add_batch_dim.sum(0)
-        sum_2 = _add_batch_dim.sum(1);  _add_batch_dim = None
-        batched_outputs = sum_1 + sum_2;  sum_1 = sum_2 = None
+        sum_1: "f32[3]" = _add_batch_dim.sum(0)
+        sum_2: "f32[3]" = _add_batch_dim.sum(1);  _add_batch_dim = None
+        batched_outputs: "f32[3]" = sum_1 + sum_2;  sum_1 = sum_2 = None
 
         _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0);  batched_outputs = None
 
@@ -5554,12 +5554,12 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
+        _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
 
-        sum_1 = _add_batch_dim.sum(0)
-        sum_2 = _add_batch_dim.sum(1);  _add_batch_dim = None
-        add = sum_1 + sum_2;  sum_1 = sum_2 = None
-        batched_outputs = add + 3;  add = None
+        sum_1: "f32[3]" = _add_batch_dim.sum(0)
+        sum_2: "f32[3]" = _add_batch_dim.sum(1);  _add_batch_dim = None
+        add: "f32[3]" = sum_1 + sum_2;  sum_1 = sum_2 = None
+        batched_outputs: "f32[3]" = add + 3;  add = None
 
         _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0);  batched_outputs = None
 
@@ -5594,12 +5594,12 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
+        _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
 
-        sum_1 = _add_batch_dim.sum(0)
-        sum_2 = _add_batch_dim.sum(1);  _add_batch_dim = None
-        add = sum_1 + sum_2;  sum_1 = sum_2 = None
-        batched_outputs = add + l_y_;  add = l_y_ = None
+        sum_1: "f32[3]" = _add_batch_dim.sum(0)
+        sum_2: "f32[3]" = _add_batch_dim.sum(1);  _add_batch_dim = None
+        add: "f32[3]" = sum_1 + sum_2;  sum_1 = sum_2 = None
+        batched_outputs: "f32[3, 3]" = add + l_y_;  add = l_y_ = None
 
         _remove_batch_dim: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0);  batched_outputs = None
 
@@ -5635,13 +5635,13 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
-        _add_batch_dim_1 = torch._C._functorch._add_batch_dim(l_y_, 1, 1);  l_y_ = None
+        _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
+        _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1);  l_y_ = None
 
-        sum_1 = _add_batch_dim.sum(0)
-        sum_2 = _add_batch_dim.sum(1);  _add_batch_dim = None
-        add = sum_1 + sum_2;  sum_1 = sum_2 = None
-        batched_outputs = add + _add_batch_dim_1;  add = _add_batch_dim_1 = None
+        sum_1: "f32[3]" = _add_batch_dim.sum(0)
+        sum_2: "f32[3]" = _add_batch_dim.sum(1);  _add_batch_dim = None
+        add: "f32[3]" = sum_1 + sum_2;  sum_1 = sum_2 = None
+        batched_outputs: "f32[3]" = add + _add_batch_dim_1;  add = _add_batch_dim_1 = None
 
         _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0);  batched_outputs = None
 
@@ -5679,13 +5679,13 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
-        _add_batch_dim_1 = torch._C._functorch._add_batch_dim(l_y_, 1, 1);  l_y_ = None
+        _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
+        _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1);  l_y_ = None
 
-        sum_1 = _add_batch_dim.sum(0)
-        sum_2 = _add_batch_dim.sum(1);  _add_batch_dim = None
-        add = sum_1 + sum_2;  sum_1 = sum_2 = None
-        batched_outputs = add + _add_batch_dim_1;  add = _add_batch_dim_1 = None
+        sum_1: "f32[3]" = _add_batch_dim.sum(0)
+        sum_2: "f32[3]" = _add_batch_dim.sum(1);  _add_batch_dim = None
+        add: "f32[3]" = sum_1 + sum_2;  sum_1 = sum_2 = None
+        batched_outputs: "f32[3]" = add + _add_batch_dim_1;  add = _add_batch_dim_1 = None
 
         _remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0);  batched_outputs = None
 
@@ -5719,19 +5719,19 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting = None
 
-        child = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
-        child_1 = torch._C._functorch._add_batch_dim(l_y_, 0, 1);  l_y_ = None
+        child: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
+        child_1: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1);  l_y_ = None
 
         lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions_1 = None
 
         _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting_1 = None
 
-        _add_batch_dim_2 = torch._C._functorch._add_batch_dim(child, 1, 2);  child = None
-        _add_batch_dim_3 = torch._C._functorch._add_batch_dim(child_1, 1, 2);  child_1 = None
+        _add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(child, 1, 2);  child = None
+        _add_batch_dim_3: "f32[3]" = torch._C._functorch._add_batch_dim(child_1, 1, 2);  child_1 = None
 
-        batched_outputs = _add_batch_dim_2 + _add_batch_dim_3;  _add_batch_dim_2 = _add_batch_dim_3 = None
+        batched_outputs: "f32[3]" = _add_batch_dim_2 + _add_batch_dim_3;  _add_batch_dim_2 = _add_batch_dim_3 = None
 
-        batched_outputs_1 = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0);  batched_outputs = None
+        batched_outputs_1: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0);  batched_outputs = None
 
         _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
 
@@ -5768,17 +5768,17 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error');  _vmap_increment_nesting = None
 
-        child = torch._C._functorch._add_batch_dim(l_y_, 0, 1);  l_y_ = None
+        child: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1);  l_y_ = None
 
         lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions();  lazy_load_decompositions_1 = None
 
         _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error');  _vmap_increment_nesting_1 = None
 
-        _add_batch_dim_1 = torch._C._functorch._add_batch_dim(child, 0, 2);  child = None
+        _add_batch_dim_1: "f32[]" = torch._C._functorch._add_batch_dim(child, 0, 2);  child = None
 
-        batched_outputs = l_x_ * _add_batch_dim_1;  l_x_ = _add_batch_dim_1 = None
+        batched_outputs: "f32[2, 3]" = l_x_ * _add_batch_dim_1;  l_x_ = _add_batch_dim_1 = None
 
-        batched_outputs_1 = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0);  batched_outputs = None
+        batched_outputs_1: "f32[3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0);  batched_outputs = None
 
         _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting();  _vmap_decrement_nesting = None
 
@@ -5813,10 +5813,10 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
+        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
 
-        child = _add_batch_dim.sum(0)
-        child_1 = _add_batch_dim.sum(1);  _add_batch_dim = None
+        child: "f32[3]" = _add_batch_dim.sum(0)
+        child_1: "f32[4]" = _add_batch_dim.sum(1);  _add_batch_dim = None
 
         _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 0);  child = None
         _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0);  child_1 = None
@@ -5850,10 +5850,10 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
+        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
 
-        child = _add_batch_dim.sum(0)
-        child_1 = _add_batch_dim.sum(1);  _add_batch_dim = None
+        child: "f32[3]" = _add_batch_dim.sum(0)
+        child_1: "f32[4]" = _add_batch_dim.sum(1);  _add_batch_dim = None
 
         _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1);  child = None
         _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0);  child_1 = None
@@ -5888,10 +5888,10 @@
 
         _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error');  _vmap_increment_nesting = None
 
-        _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
+        _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1);  l_x_ = None
 
-        child = _add_batch_dim.sum(0)
-        child_1 = _add_batch_dim.sum(1);  _add_batch_dim = None
+        child: "f32[3]" = _add_batch_dim.sum(0)
+        child_1: "f32[4]" = _add_batch_dim.sum(1);  _add_batch_dim = None
 
         _remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1);  child = None
         _remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0);  child_1 = None
diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py
index 22b3bbf..01cdb4e 100644
--- a/test/dynamo/test_subclasses.py
+++ b/test/dynamo/test_subclasses.py
@@ -987,12 +987,12 @@
             actual,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_x_ : torch.Tensor):
+    def forward(self, L_x_: "f32[3, 4]"):
         l_x_ = L_x_
 
-        add_ = l_x_.add_(1.0)
-        relu_ = torch.relu_(l_x_);  l_x_ = None
-        add = add_ + relu_;  add_ = relu_ = None
+        add_: "f32[3, 4]" = l_x_.add_(1.0)
+        relu_: "f32[3, 4]" = torch.relu_(l_x_);  l_x_ = None
+        add: "f32[3, 4]" = add_ + relu_;  add_ = relu_ = None
         return (add,)
 """,
         )
@@ -1029,12 +1029,12 @@
             actual,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_x_ : torch.Tensor):
+    def forward(self, L_x_: "f32[3, 4]"):
         l_x_ = L_x_
 
-        add_ = l_x_.add_(1.0)
-        relu_ = torch.relu_(l_x_);  l_x_ = None
-        add = add_ + relu_;  add_ = relu_ = None
+        add_: "f32[3, 4]" = l_x_.add_(1.0)
+        relu_: "f32[3, 4]" = torch.relu_(l_x_);  l_x_ = None
+        add: "f32[3, 4]" = add_ + relu_;  add_ = relu_ = None
         return (add,)
 """,
         )
@@ -1106,17 +1106,17 @@
             2,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_x_ : torch.Tensor):
+    def forward(self, L_x_: "f32[3, 4]"):
         l_x_ = L_x_
 
         wrap_body_0 = self.wrap_body_0
         wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
-        getitem = wrap[0];  wrap = None
+        getitem: "f32[3, 4]" = wrap[0];  wrap = None
         return (getitem,)
 
     class GraphModule(torch.nn.Module):
-        def forward(self, l_x_):
-            add_ = l_x_.add_(1.0);  l_x_ = None
+        def forward(self, l_x_: "f32[3, 4]"):
+            add_: "f32[3, 4]" = l_x_.add_(1.0);  l_x_ = None
             return (add_,)
 """,
         )
@@ -1136,17 +1136,17 @@
             3,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_x_ : torch.Tensor):
+    def forward(self, L_x_: "f32[3, 4]"):
         l_x_ = L_x_
 
         wrap_body_0 = self.wrap_body_0
         wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_);  wrap_body_0 = l_x_ = None
-        getitem = wrap[0];  wrap = None
+        getitem: "f32[3, 4]" = wrap[0];  wrap = None
         return (getitem,)
 
     class GraphModule(torch.nn.Module):
-        def forward(self, l_x_):
-            add_ = l_x_.add_(1.0);  l_x_ = None
+        def forward(self, l_x_: "f32[3, 4]"):
+            add_: "f32[3, 4]" = l_x_.add_(1.0);  l_x_ = None
             return (add_,)
 """,
         )
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 0c5687b..720b808 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -573,14 +573,13 @@
 
             if verbose:
                 # override annotation with more detailed information
-                from torch._subclasses.fake_tensor import FakeTensor
                 from torch.fx.experimental.proxy_tensor import py_sym_types
                 from torch.fx.passes.shape_prop import TensorMetadata
 
                 meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None)))
                 # use string as annotation, to make it valid python code
 
-                if isinstance(meta_val, FakeTensor):
+                if isinstance(meta_val, torch.Tensor):
                     stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else ""
                     device_annotation = f"{meta_val.device}" if include_device else ""
                     maybe_type_annotation = \