[fx] python_code(verbose=True): show size/strides for all tensors (#132192)
python_code(verbose=True) (or print_readable()) generates a string containing the code that represents the fx graph, with extra annotations indicating the sizes and strides of each tensor. Currently, it only shows sizes/strides for FakeTensors provided in the node metadata. For tensor subclasses like NestedTensor, the outer tensor (the one stored in the node metadata) is a non-FakeTensor while the inner tensors are fake, so no annotation is emitted. This PR expands the conditional to show sizes/strides for all tensors, not just FakeTensors.
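The change itself is small: roughly, the check on node metadata goes from "is this a FakeTensor" to "is this any Tensor". A minimal sketch of that contrast with hypothetical helper names (the actual code lives in torch/fx/graph.py and is not reproduced here):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor

# Hypothetical helpers (names and placement illustrative, not the actual
# torch/fx/graph.py code) contrasting the old and new checks: any Tensor in
# node.meta now gets a size/stride annotation, so an outer NestedTensor
# subclass is no longer skipped.
def should_annotate_before(meta_val) -> bool:
    return isinstance(meta_val, FakeTensor)    # old: FakeTensors only

def should_annotate_after(meta_val) -> bool:
    return isinstance(meta_val, torch.Tensor)  # new: any tensor, fake or not
```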
Testing: I ran the test script below with `TORCH_LOGS=+dynamo`; the logs excerpt further down shows the resulting graph, where the input nested tensor now has sizes and strides associated with it. I also stacked a diff on top of this one that forces the readable graph to be generated whenever PT2 is in use in tests, which should surface any remaining issues: https://github.com/pytorch/pytorch/pull/132195 shows no significant failures apart from preexisting ones. (As an alternative to grepping the logs, a sketch after the logs excerpt prints the annotated graph directly.)
test script:
```python
import torch

def fn(x):
    return x.cos()

# NJT with offsets [0, 1, 3, 6, 10], i.e. jagged lengths 1, 2, 3, 4 over 10 rows
nt = torch.nested.nested_tensor_from_jagged(
    torch.randn(10, 10),
    torch.tensor([0, 1, 3, 6, 10]),
)
torch.compile(fn)(nt)
```
logs excerpt:
```
[0/0] [__graph_code] TRACED GRAPH
[0/0] [__graph_code] ===== __compiled_fn_1 =====
[0/0] [__graph_code] /data/users/dberard/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.M
[0/0] [__graph_code] def forward(self, L_x_: "f32[4, zf1, 10][10*zf1, 10, 1]cpu", zf1: "Sym(zf1)"):
[0/0] [__graph_code] l_x_ = L_x_
[0/0] [__graph_code]
[0/0] [__graph_code] # File: /data/users/dberard/scripts/nt_print_graph.py:4 in fn, code: return x.c
[0/0] [__graph_code] cos: "f32[4, zf1, 10][10*zf1, 10, 1]cpu" = l_x_.cos(); l_x_ = None
[0/0] [__graph_code] return (cos,)
```
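The annotated code can also be printed directly rather than pulled out of the logs. A minimal sketch, assuming an illustrative custom Dynamo backend (`print_graph_backend` is not part of this PR); `print_readable()` renders the graph through `python_code(verbose=True)`:

```python
import torch

def print_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    # print_readable() goes through python_code(verbose=True), so the
    # size/stride annotations added by this PR show up here as well.
    gm.print_readable()
    return gm.forward  # run the captured graph eagerly

def fn(x):
    return x.cos()

nt = torch.nested.nested_tensor_from_jagged(
    torch.randn(10, 10),
    torch.tensor([0, 1, 3, 6, 10]),
)
torch.compile(fn, backend=print_graph_backend)(nt)
```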
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132192
Approved by: https://github.com/Chillee
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 1b911bb..20b31f0 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -2837,7 +2837,7 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
- child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
+ child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None
@@ -2847,68 +2847,68 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- child_2 = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
+ child_2: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
- _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
+ _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- diff_primals = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None
+ diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- o = torch.sin(diff_primals)
+ o: "f32[4, 3]" = torch.sin(diff_primals)
- results = torch._C._functorch._unwrap_for_grad(o, 3)
+ results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 3)
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
- tensor_1 = torch.tensor((12,))
- cumsum_1 = tensor_1.cumsum(dim = 0); tensor_1 = None
- getitem_1 = cumsum_1[slice(None, -1, None)]; cumsum_1 = None
- neg_1 = getitem_1.neg(); getitem_1 = None
+ tensor_1: "i64[1]" = torch.tensor((12,))
+ cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0); tensor_1 = None
+ getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)]; cumsum_1 = None
+ neg_1: "i64[0]" = getitem_1.neg(); getitem_1 = None
unbind_1 = neg_1.unbind(); neg_1 = unbind_1 = None
- chunk_1 = results.new_zeros(12, 12); results = None
+ chunk_1: "f32[12, 12]" = results.new_zeros(12, 12); results = None
- diagonal_1 = chunk_1.diagonal(0)
- fill__1 = diagonal_1.fill_(1); diagonal_1 = fill__1 = None
+ diagonal_1: "f32[12]" = chunk_1.diagonal(0)
+ fill__1: "f32[12]" = diagonal_1.fill_(1); diagonal_1 = fill__1 = None
- basis = chunk_1.view(12, 4, 3); chunk_1 = None
+ basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None
- _add_batch_dim_1 = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None
+ _add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1); _vjp_treespec_compare = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim_1 = None
- batched_outputs = _autograd_grad[0]; _autograd_grad = None
+ batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None
- chunked_result = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None
+ chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
split = chunked_result.split((12,), dim = 0); chunked_result = None
- split_1 = split[0]; split = None
+ split_1: "f32[12, 4, 3]" = split[0]; split = None
- output_input = split_1.view((4, 3, 4, 3)); split_1 = None
+ output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None
_unpack_dual = torch._unpack_dual(output_input, level = 0); output_input = None
- primal = _unpack_dual[0]
- dual = _unpack_dual[1]; _unpack_dual = None
+ primal: "f32[4, 3, 4, 3]" = _unpack_dual[0]
+ dual: "f32[4, 3, 4, 3]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
- tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
+ tangents_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
@@ -2969,7 +2969,7 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
- child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
+ child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None
@@ -2979,67 +2979,67 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- child_3 = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
+ child_3: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
- child_2 = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None
- _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
+ child_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None
+ _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- _wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None
- child_4 = torch._C._functorch._wrap_for_grad(child_3, 3); child_3 = None
+ _wrap_for_grad_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None
+ child_4: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(child_3, 3); child_3 = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_4); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- o = _wrap_for_grad_2.sin(); _wrap_for_grad_2 = None
+ o: "f32[4, 3]" = _wrap_for_grad_2.sin(); _wrap_for_grad_2 = None
- results = torch._C._functorch._unwrap_for_grad(o, 3)
+ results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 3)
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
- tensor_1 = torch.tensor((12,))
- cumsum_1 = tensor_1.cumsum(dim = 0); tensor_1 = None
- getitem_1 = cumsum_1[slice(None, -1, None)]; cumsum_1 = None
- neg_1 = getitem_1.neg(); getitem_1 = None
+ tensor_1: "i64[1]" = torch.tensor((12,))
+ cumsum_1: "i64[1]" = tensor_1.cumsum(dim = 0); tensor_1 = None
+ getitem_1: "i64[0]" = cumsum_1[slice(None, -1, None)]; cumsum_1 = None
+ neg_1: "i64[0]" = getitem_1.neg(); getitem_1 = None
unbind_1 = neg_1.unbind(); neg_1 = unbind_1 = None
- chunk_1 = results.new_zeros(12, 12); results = None
+ chunk_1: "f32[12, 12]" = results.new_zeros(12, 12); results = None
- diagonal_1 = chunk_1.diagonal(0)
- fill__1 = diagonal_1.fill_(1); diagonal_1 = fill__1 = None
+ diagonal_1: "f32[12]" = chunk_1.diagonal(0)
+ fill__1: "f32[12]" = diagonal_1.fill_(1); diagonal_1 = fill__1 = None
- basis = chunk_1.view(12, 4, 3); chunk_1 = None
+ basis: "f32[12, 4, 3]" = chunk_1.view(12, 4, 3); chunk_1 = None
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting_1 = None
- _add_batch_dim_1 = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None
+ _add_batch_dim_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 3); basis = None
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1); _vjp_treespec_compare = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = child_4 = _add_batch_dim_1 = None
- child_5 = _autograd_grad[0]; _autograd_grad = None
+ child_5: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
- child_6 = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None
+ child_6: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
split = child_6.split((12,), dim = 0); child_6 = None
- split_1 = split[0]; split = None
+ split_1: "f32[12, 3, 4]" = split[0]; split = None
- child_7 = split_1.view((4, 3, 3, 4)); split_1 = None
+ child_7: "f32[4, 3, 3, 4]" = split_1.view((4, 3, 3, 4)); split_1 = None
_unpack_dual = torch._unpack_dual(child_7, level = 0); child_7 = None
- primal = _unpack_dual[0]; _unpack_dual = None
+ primal: "f32[4, 3, 3, 4]" = _unpack_dual[0]; _unpack_dual = None
- tangent = torch.zeros_like(primal)
+ tangent: "f32[4, 3, 3, 4]" = torch.zeros_like(primal)
child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_8 = None
@@ -3114,15 +3114,15 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- diff_primals = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ diff_primals: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[4, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- o = torch.sin(diff_primals)
+ o: "f32[4, 3]" = torch.sin(diff_primals)
results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(o, 1)
@@ -3146,12 +3146,12 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
+ _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
- batched_outputs = _autograd_grad[0]; _autograd_grad = None
+ batched_outputs: "f32[4, 3]" = _autograd_grad[0]; _autograd_grad = None
chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
@@ -3193,16 +3193,16 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = _wrap_for_grad = None
- diff_primals = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
+ _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = _wrap_for_grad = None
+ diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- o = diff_primals.sin()
+ o: "f32[3, 4]" = diff_primals.sin()
results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(o, 1)
@@ -3226,12 +3226,12 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
+ _add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
- batched_outputs = _autograd_grad[0]; _autograd_grad = None
+ batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
@@ -3273,16 +3273,16 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- aux = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
- diff_primals = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
+ aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ diff_primals: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 4]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_primals); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- o = diff_primals.sin()
+ o: "f32[3, 4]" = diff_primals.sin()
aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
@@ -3308,12 +3308,12 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
+ _add_batch_dim: "f32[3, 4]" = torch._C._functorch._add_batch_dim(basis, 0, 1); basis = None
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim); _vjp_treespec_compare = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
- batched_outputs = _autograd_grad[0]; _autograd_grad = None
+ batched_outputs: "f32[3, 4]" = _autograd_grad[0]; _autograd_grad = None
chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
@@ -3382,16 +3382,16 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- child_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None
+ child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- sin = child.sin(); child = None
- o = sin.sum(); sin = None
+ sin: "f32[5]" = child.sin(); child = None
+ o: "f32[]" = sin.sum(); sin = None
results: "f32[]" = torch._C._functorch._unwrap_for_grad(o, 1); o = None
@@ -3430,16 +3430,16 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- child_3 = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
+ child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- child_1 = child.sin()
- child_2 = child.cos(); child = None
+ child_1: "f32[5]" = child.sin()
+ child_2: "f32[5]" = child.cos(); child = None
_unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
_unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
@@ -3484,16 +3484,16 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- child_3 = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
+ child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- child_1 = child.sin()
- child_2 = child.cos(); child = None
+ child_1: "f32[5]" = child.sin()
+ child_2: "f32[5]" = child.cos(); child = None
_unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
_unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
@@ -3540,16 +3540,16 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- child_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None
+ child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- sin = child.sin()
- o = sin.sum(); sin = None
+ sin: "f32[5]" = child.sin()
+ o: "f32[]" = sin.sum(); sin = None
aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1); child = aux = None
@@ -3781,19 +3781,19 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- sin = diff_args.sin()
- output = sin.sum(); sin = None
+ sin: "f32[3, 3, 3]" = diff_args.sin()
+ output: "f32[]" = sin.sum(); sin = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
- grad_input = _autograd_grad[0]; _autograd_grad = None
+ grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
@@ -3848,20 +3848,20 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- sin = diff_args.sin()
- add = sin + 3; sin = None
- output = add.sum(); add = None
+ sin: "f32[3, 3, 3]" = diff_args.sin()
+ add: "f32[3, 3, 3]" = sin + 3; sin = None
+ output: "f32[]" = add.sum(); add = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
- grad_input = _autograd_grad[0]; _autograd_grad = None
+ grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
@@ -3905,20 +3905,20 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- sin = diff_args.sin()
- add = sin + y; sin = None
- output = add.sum(); add = None
+ sin: "f32[3, 3, 3]" = diff_args.sin()
+ add: "f32[3, 3, 3]" = sin + y; sin = None
+ output: "f32[]" = add.sum(); add = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
- grad_input = _autograd_grad[0]; _autograd_grad = None
+ grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
@@ -3962,20 +3962,20 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- sin = diff_args.sin()
- add = sin + 3.14; sin = None
- output = add.sum(); add = None
+ sin: "f32[3, 3, 3]" = diff_args.sin()
+ add: "f32[3, 3, 3]" = sin + 3.14; sin = None
+ output: "f32[]" = add.sum(); add = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
- grad_input = _autograd_grad[0]; _autograd_grad = None
+ grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
@@ -4016,21 +4016,21 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- sin = diff_args.sin()
- add = sin + 3.14; sin = None
- output = add.sum(); add = None
- aux = diff_args.cos()
+ sin: "f32[3, 3, 3]" = diff_args.sin()
+ add: "f32[3, 3, 3]" = sin + 3.14; sin = None
+ output: "f32[]" = add.sum(); add = None
+ aux: "f32[3, 3, 3]" = diff_args.cos()
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
- grad_input = _autograd_grad[0]; _autograd_grad = None
+ grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
@@ -4073,22 +4073,22 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
- _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
+ diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ _wrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- sin = diff_args.sin()
- add = sin + _wrap_for_grad_1; sin = _wrap_for_grad_1 = None
- output = add.sum(); add = None
- aux = diff_args.cos()
+ sin: "f32[3, 3, 3]" = diff_args.sin()
+ add: "f32[3, 3, 3]" = sin + _wrap_for_grad_1; sin = _wrap_for_grad_1 = None
+ output: "f32[]" = add.sum(); add = None
+ aux: "f32[3, 3, 3]" = diff_args.cos()
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
- grad_input = _autograd_grad[0]; _autograd_grad = None
+ grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
@@ -4142,28 +4142,28 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
- child_1 = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
+ child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None
- _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None
+ _set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None
- sin = child.sin()
- add = sin + child_1; sin = None
- output = add.sum(); add = None
- aux = child.cos()
+ sin: "f32[3, 3, 3]" = child.sin()
+ add: "f32[3, 3, 3]" = sin + child_1; sin = None
+ output: "f32[]" = add.sum(); add = None
+ aux: "f32[3, 3, 3]" = child.cos()
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None
- child_2 = _autograd_grad[0]
- child_3 = _autograd_grad[1]; _autograd_grad = None
+ child_2: "f32[3, 3, 3]" = _autograd_grad[0]
+ child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None
_unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None
_unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None
@@ -4188,28 +4188,28 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- child = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
- child_1 = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
+ child: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ child_1: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_y_, 1); l_y_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None
- _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None
+ _set_tensor_requires_grad_1: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child_1); _set_tensor_requires_grad_1 = None
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None
- sin = child.sin()
- add = sin + child_1; sin = None
- output = add.sum(); add = None
- aux = child.cos()
+ sin: "f32[3, 3, 3]" = child.sin()
+ add: "f32[3, 3, 3]" = sin + child_1; sin = None
+ output: "f32[]" = add.sum(); add = None
+ aux: "f32[3, 3, 3]" = child.cos()
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None
- child_2 = _autograd_grad[0]
- child_3 = _autograd_grad[1]; _autograd_grad = None
+ child_2: "f32[3, 3, 3]" = _autograd_grad[0]
+ child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None
_unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None
_unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None
@@ -4250,39 +4250,39 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ diff_args: "f32[]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable_1 = None
_grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting_1 = None
- diff_args_1 = torch._C._functorch._wrap_for_grad(diff_args, 2)
+ diff_args_1: "f32[]" = torch._C._functorch._wrap_for_grad(diff_args, 2)
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed_2 = None
- _set_tensor_requires_grad_1 = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1); _set_tensor_requires_grad_1 = None
+ _set_tensor_requires_grad_1: "f32[]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args_1); _set_tensor_requires_grad_1 = None
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_3 = None
- sin = diff_args_1.sin()
- output = sin.sum(); sin = None
+ sin: "f32[]" = diff_args_1.sin()
+ output: "f32[]" = sin.sum(); sin = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args_1], create_graph = True); diff_args_1 = None
- grad_input = _autograd_grad[0]; _autograd_grad = None
+ grad_input: "f32[]" = _autograd_grad[0]; _autograd_grad = None
- grad_input_1 = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None
+ grad_input_1: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None
- output_1 = torch._C._functorch._unwrap_for_grad(output, 2); output = output_1 = None
+ output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 2); output = output_1 = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
_saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable_2 = None
_autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((grad_input_1,), [diff_args], create_graph = True); diff_args = None
- grad_input_2 = _autograd_grad_1[0]; _autograd_grad_1 = None
+ grad_input_2: "f32[]" = _autograd_grad_1[0]; _autograd_grad_1 = None
grad_input_3: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None
@@ -4375,20 +4375,20 @@
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None
- diff_args = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
+ diff_args: "f32[3, 3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None
- _set_tensor_requires_grad = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
+ _set_tensor_requires_grad: "f32[3, 3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None
- sin = diff_args.sin()
- sum_1 = sin.sum(); sin = None
- output = sum_1 + 3.0; sum_1 = None
+ sin: "f32[3, 3, 3]" = diff_args.sin()
+ sum_1: "f32[]" = sin.sum(); sin = None
+ output: "f32[]" = sum_1 + 3.0; sum_1 = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
- grad_input = _autograd_grad[0]; _autograd_grad = None
+ grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None
grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
@@ -4478,7 +4478,7 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
- child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
+ child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None
@@ -4488,19 +4488,19 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- _make_dual = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
+ _make_dual: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
- _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
+ _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
- result_duals = torch.sin(_make_dual); _make_dual = None
+ result_duals: "f32[4, 3]" = torch.sin(_make_dual); _make_dual = None
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
- primal = _unpack_dual[0]
- dual = _unpack_dual[1]; _unpack_dual = None
+ primal: "f32[4, 3]" = _unpack_dual[0]
+ dual: "f32[4, 3]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
- tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
+ tangents_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
@@ -4561,7 +4561,7 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
- child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
+ child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None
@@ -4571,20 +4571,20 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- _make_dual = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
+ _make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
- _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
- _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
+ _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
+ _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
- result_duals = _make_dual.sin(); _make_dual = None
+ result_duals: "f32[3, 4]" = _make_dual.sin(); _make_dual = None
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
- primal = _unpack_dual[0]
- dual = _unpack_dual[1]; _unpack_dual = None
+ primal: "f32[3, 4]" = _unpack_dual[0]
+ dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
- tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
+ tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
@@ -4645,7 +4645,7 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'error'); _vmap_increment_nesting = None
- child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
+ child_1: "f32[3, 4]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_y_,), (child_1,)); _jvp_treespec_compare = None
@@ -4655,22 +4655,22 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- _make_dual = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
+ _make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None
- aux = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None
- _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
+ aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None
+ _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None
- result_duals = _make_dual.sin(); _make_dual = None
+ result_duals: "f32[3, 4]" = _make_dual.sin(); _make_dual = None
aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 2); aux = None
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
- primal = _unpack_dual[0]
- dual = _unpack_dual[1]; _unpack_dual = None
+ primal: "f32[3, 4]" = _unpack_dual[0]
+ dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None
- tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
+ tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
@@ -4734,7 +4734,7 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(12, 'same'); _vmap_increment_nesting = None
- child_1 = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
+ child_1: "f32[4, 3]" = torch._C._functorch._add_batch_dim(child, 0, 1); child = None
_jvp_treespec_compare = torch._functorch.eager_transforms._jvp_treespec_compare((l_x_,), (child_1,)); _jvp_treespec_compare = None
@@ -4744,27 +4744,27 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- child_3 = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
+ child_3: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None
- _wrap_for_grad = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
- _wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = None
+ _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None
+ _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = None
- child_2 = _wrap_for_grad_1.sin(); _wrap_for_grad_1 = None
+ child_2: "f32[3, 4]" = _wrap_for_grad_1.sin(); _wrap_for_grad_1 = None
_unpack_dual = torch._unpack_dual(child_2, level = 0); child_2 = None
- primal = _unpack_dual[0]; _unpack_dual = None
+ primal: "f32[3, 4]" = _unpack_dual[0]; _unpack_dual = None
- tangent = torch.zeros_like(primal)
+ tangent: "f32[3, 4]" = torch.zeros_like(primal)
_unpack_dual_1 = torch._unpack_dual(child_3, level = 0); child_3 = None
- primal_1 = _unpack_dual_1[0]
- dual = _unpack_dual_1[1]; _unpack_dual_1 = None
+ primal_1: "f32[4, 3]" = _unpack_dual_1[0]
+ dual: "f32[4, 3]" = _unpack_dual_1[1]; _unpack_dual_1 = None
child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_4 = None
child_5: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 2); primal_1 = child_5 = None
child_6: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None
- child_7 = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
+ child_7: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
_set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None
@@ -4850,14 +4850,14 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- _make_dual = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
+ _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
- sin = _make_dual.sin(); _make_dual = None
- result_duals = sin.sum(); sin = None
+ sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None
+ result_duals: "f32[]" = sin.sum(); sin = None
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
- primal = _unpack_dual[0]
- dual = _unpack_dual[1]; _unpack_dual = None
+ primal: "f32[]" = _unpack_dual[0]
+ dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
@@ -4904,16 +4904,16 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- aux = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
+ aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
- sin = aux.sin()
- result_duals = sin.sum(); sin = None
+ sin: "f32[3, 3]" = aux.sin()
+ result_duals: "f32[]" = sin.sum(); sin = None
aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
- primal = _unpack_dual[0]
- dual = _unpack_dual[1]; _unpack_dual = None
+ primal: "f32[]" = _unpack_dual[0]
+ dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
@@ -4962,22 +4962,22 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- aux = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = None
+ aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = None
_maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None
- _make_dual_1 = torch._make_dual(l_y_, l_v_, level = 0); l_y_ = l_v_ = None
+ _make_dual_1: "f32[3, 3]" = torch._make_dual(l_y_, l_v_, level = 0); l_y_ = l_v_ = None
- sin = aux.sin()
- sum_1 = sin.sum(); sin = None
- cos = _make_dual_1.cos(); _make_dual_1 = None
- result_duals = sum_1 + cos; sum_1 = cos = None
+ sin: "f32[3, 3]" = aux.sin()
+ sum_1: "f32[]" = sin.sum(); sin = None
+ cos: "f32[3, 3]" = _make_dual_1.cos(); _make_dual_1 = None
+ result_duals: "f32[3, 3]" = sum_1 + cos; sum_1 = cos = None
aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
- primal = _unpack_dual[0]
- dual = _unpack_dual[1]; _unpack_dual = None
+ primal: "f32[3, 3]" = _unpack_dual[0]
+ dual: "f32[3, 3]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
@@ -5027,14 +5027,14 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- _make_dual = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
+ _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
- sin = _make_dual.sin(); _make_dual = None
- result_duals = sin.sum(); sin = None
+ sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None
+ result_duals: "f32[]" = sin.sum(); sin = None
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
- primal = _unpack_dual[0]
- dual = _unpack_dual[1]; _unpack_dual = None
+ primal: "f32[]" = _unpack_dual[0]
+ dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
@@ -5098,14 +5098,14 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- _make_dual = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
+ _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None
- sin = _make_dual.sin(); _make_dual = None
- result_duals = sin.sum(); sin = None
+ sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None
+ result_duals: "f32[]" = sin.sum(); sin = None
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
- primal = _unpack_dual[0]
- dual = _unpack_dual[1]; _unpack_dual = None
+ primal: "f32[]" = _unpack_dual[0]
+ dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None
primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None
@@ -5171,7 +5171,7 @@
_maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None
- child = torch._make_dual(l_x_, l_x_, level = 0); l_x_ = None
+ child: "f32[3, 3, 3]" = torch._make_dual(l_x_, l_x_, level = 0); l_x_ = None
_jvp_treespec_compare_1 = torch._functorch.eager_transforms._jvp_treespec_compare((child,), (child,)); _jvp_treespec_compare_1 = None
@@ -5180,27 +5180,27 @@
_maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None
- _make_dual_1 = torch._make_dual(child, child, level = 0); child = None
+ _make_dual_1: "f32[3, 3, 3]" = torch._make_dual(child, child, level = 0); child = None
- result_duals = torch.sin(_make_dual_1); _make_dual_1 = None
+ result_duals: "f32[3, 3, 3]" = torch.sin(_make_dual_1); _make_dual_1 = None
_unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None
- primal = _unpack_dual[0]
- dual = _unpack_dual[1]; _unpack_dual = None
+ primal: "f32[3, 3, 3]" = _unpack_dual[0]
+ dual: "f32[3, 3, 3]" = _unpack_dual[1]; _unpack_dual = None
- primals_out_unflatten = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None
+ primals_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None
- tangents_out_unflatten = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
+ tangents_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None
_set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_2 = None
_jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None
_unpack_dual_1 = torch._unpack_dual(primals_out_unflatten, level = 0); primals_out_unflatten = None
- primal_1 = _unpack_dual_1[0]
- dual_1 = _unpack_dual_1[1]; _unpack_dual_1 = None
+ primal_1: "f32[3, 3, 3]" = _unpack_dual_1[0]
+ dual_1: "f32[3, 3, 3]" = _unpack_dual_1[1]; _unpack_dual_1 = None
_unpack_dual_2 = torch._unpack_dual(tangents_out_unflatten, level = 0); tangents_out_unflatten = None
- primal_2 = _unpack_dual_2[0]
- dual_2 = _unpack_dual_2[1]; _unpack_dual_2 = None
+ primal_2: "f32[3, 3, 3]" = _unpack_dual_2[0]
+ dual_2: "f32[3, 3, 3]" = _unpack_dual_2[1]; _unpack_dual_2 = None
_unwrap_for_grad_2: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 1); primal_1 = None
_unwrap_for_grad_3: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_2, 1); primal_2 = None
@@ -5516,11 +5516,11 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+ _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
- sum_1 = _add_batch_dim.sum(0)
- sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
- batched_outputs = sum_1 + sum_2; sum_1 = sum_2 = None
+ sum_1: "f32[3]" = _add_batch_dim.sum(0)
+ sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
+ batched_outputs: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
@@ -5554,12 +5554,12 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+ _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
- sum_1 = _add_batch_dim.sum(0)
- sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
- add = sum_1 + sum_2; sum_1 = sum_2 = None
- batched_outputs = add + 3; add = None
+ sum_1: "f32[3]" = _add_batch_dim.sum(0)
+ sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
+ add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
+ batched_outputs: "f32[3]" = add + 3; add = None
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
@@ -5594,12 +5594,12 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+ _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
- sum_1 = _add_batch_dim.sum(0)
- sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
- add = sum_1 + sum_2; sum_1 = sum_2 = None
- batched_outputs = add + l_y_; add = l_y_ = None
+ sum_1: "f32[3]" = _add_batch_dim.sum(0)
+ sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
+ add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
+ batched_outputs: "f32[3, 3]" = add + l_y_; add = l_y_ = None
_remove_batch_dim: "f32[3, 3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
@@ -5635,13 +5635,13 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
- _add_batch_dim_1 = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None
+ _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+ _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None
- sum_1 = _add_batch_dim.sum(0)
- sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
- add = sum_1 + sum_2; sum_1 = sum_2 = None
- batched_outputs = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
+ sum_1: "f32[3]" = _add_batch_dim.sum(0)
+ sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
+ add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
+ batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
@@ -5679,13 +5679,13 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
- _add_batch_dim_1 = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None
+ _add_batch_dim: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+ _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 1, 1); l_y_ = None
- sum_1 = _add_batch_dim.sum(0)
- sum_2 = _add_batch_dim.sum(1); _add_batch_dim = None
- add = sum_1 + sum_2; sum_1 = sum_2 = None
- batched_outputs = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
+ sum_1: "f32[3]" = _add_batch_dim.sum(0)
+ sum_2: "f32[3]" = _add_batch_dim.sum(1); _add_batch_dim = None
+ add: "f32[3]" = sum_1 + sum_2; sum_1 = sum_2 = None
+ batched_outputs: "f32[3]" = add + _add_batch_dim_1; add = _add_batch_dim_1 = None
_remove_batch_dim: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 3, 0); batched_outputs = None
@@ -5719,19 +5719,19 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting = None
- child = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
- child_1 = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None
+ child: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+ child_1: "f32[3, 3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None
- _add_batch_dim_2 = torch._C._functorch._add_batch_dim(child, 1, 2); child = None
- _add_batch_dim_3 = torch._C._functorch._add_batch_dim(child_1, 1, 2); child_1 = None
+ _add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(child, 1, 2); child = None
+ _add_batch_dim_3: "f32[3]" = torch._C._functorch._add_batch_dim(child_1, 1, 2); child_1 = None
- batched_outputs = _add_batch_dim_2 + _add_batch_dim_3; _add_batch_dim_2 = _add_batch_dim_3 = None
+ batched_outputs: "f32[3]" = _add_batch_dim_2 + _add_batch_dim_3; _add_batch_dim_2 = _add_batch_dim_3 = None
- batched_outputs_1 = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None
+ batched_outputs_1: "f32[3, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
@@ -5768,17 +5768,17 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(5, 'error'); _vmap_increment_nesting = None
- child = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None
+ child: "f32[3]" = torch._C._functorch._add_batch_dim(l_y_, 0, 1); l_y_ = None
lazy_load_decompositions_1 = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions_1 = None
_vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(3, 'error'); _vmap_increment_nesting_1 = None
- _add_batch_dim_1 = torch._C._functorch._add_batch_dim(child, 0, 2); child = None
+ _add_batch_dim_1: "f32[]" = torch._C._functorch._add_batch_dim(child, 0, 2); child = None
- batched_outputs = l_x_ * _add_batch_dim_1; l_x_ = _add_batch_dim_1 = None
+ batched_outputs: "f32[2, 3]" = l_x_ * _add_batch_dim_1; l_x_ = _add_batch_dim_1 = None
- batched_outputs_1 = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None
+ batched_outputs_1: "f32[3, 2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 2, 3, 0); batched_outputs = None
_vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
@@ -5813,10 +5813,10 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+ _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
- child = _add_batch_dim.sum(0)
- child_1 = _add_batch_dim.sum(1); _add_batch_dim = None
+ child: "f32[3]" = _add_batch_dim.sum(0)
+ child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None
_remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 0); child = None
_remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None
@@ -5850,10 +5850,10 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+ _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
- child = _add_batch_dim.sum(0)
- child_1 = _add_batch_dim.sum(1); _add_batch_dim = None
+ child: "f32[3]" = _add_batch_dim.sum(0)
+ child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None
_remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None
_remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None
@@ -5888,10 +5888,10 @@
_vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
- _add_batch_dim = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+ _add_batch_dim: "f32[4, 3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
- child = _add_batch_dim.sum(0)
- child_1 = _add_batch_dim.sum(1); _add_batch_dim = None
+ child: "f32[3]" = _add_batch_dim.sum(0)
+ child_1: "f32[4]" = _add_batch_dim.sum(1); _add_batch_dim = None
_remove_batch_dim: "f32[3, 2]" = torch._C._functorch._remove_batch_dim(child, 1, 2, 1); child = None
_remove_batch_dim_1: "f32[2, 4]" = torch._C._functorch._remove_batch_dim(child_1, 1, 2, 0); child_1 = None
diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py
index 22b3bbf..01cdb4e 100644
--- a/test/dynamo/test_subclasses.py
+++ b/test/dynamo/test_subclasses.py
@@ -987,12 +987,12 @@
actual,
"""\
class GraphModule(torch.nn.Module):
- def forward(self, L_x_ : torch.Tensor):
+ def forward(self, L_x_: "f32[3, 4]"):
l_x_ = L_x_
- add_ = l_x_.add_(1.0)
- relu_ = torch.relu_(l_x_); l_x_ = None
- add = add_ + relu_; add_ = relu_ = None
+ add_: "f32[3, 4]" = l_x_.add_(1.0)
+ relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None
+ add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None
return (add,)
""",
)
@@ -1029,12 +1029,12 @@
actual,
"""\
class GraphModule(torch.nn.Module):
- def forward(self, L_x_ : torch.Tensor):
+ def forward(self, L_x_: "f32[3, 4]"):
l_x_ = L_x_
- add_ = l_x_.add_(1.0)
- relu_ = torch.relu_(l_x_); l_x_ = None
- add = add_ + relu_; add_ = relu_ = None
+ add_: "f32[3, 4]" = l_x_.add_(1.0)
+ relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None
+ add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None
return (add,)
""",
)
@@ -1106,17 +1106,17 @@
2,
"""\
class GraphModule(torch.nn.Module):
- def forward(self, L_x_ : torch.Tensor):
+ def forward(self, L_x_: "f32[3, 4]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
- getitem = wrap[0]; wrap = None
+ getitem: "f32[3, 4]" = wrap[0]; wrap = None
return (getitem,)
class GraphModule(torch.nn.Module):
- def forward(self, l_x_):
- add_ = l_x_.add_(1.0); l_x_ = None
+ def forward(self, l_x_: "f32[3, 4]"):
+ add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None
return (add_,)
""",
)
@@ -1136,17 +1136,17 @@
3,
"""\
class GraphModule(torch.nn.Module):
- def forward(self, L_x_ : torch.Tensor):
+ def forward(self, L_x_: "f32[3, 4]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
- getitem = wrap[0]; wrap = None
+ getitem: "f32[3, 4]" = wrap[0]; wrap = None
return (getitem,)
class GraphModule(torch.nn.Module):
- def forward(self, l_x_):
- add_ = l_x_.add_(1.0); l_x_ = None
+ def forward(self, l_x_: "f32[3, 4]"):
+ add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None
return (add_,)
""",
)
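The torch/fx/graph.py hunk below is the behavioral change itself: the verbose annotation path now accepts any `torch.Tensor` found in node metadata instead of only `FakeTensor`s. A minimal, illustrative sketch of why the broadened `isinstance` check matters (it just re-creates the old and new conditions outside of graph printing; `FakeTensorMode.from_tensor` is used here only to build a fake tensor for comparison):
```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

real = torch.randn(3, 4)                    # a plain (non-fake) tensor value
fake = FakeTensorMode().from_tensor(real)   # a FakeTensor built from it

for meta_val in (real, fake):
    old = isinstance(meta_val, FakeTensor)    # old condition: only the fake tensor qualifies
    new = isinstance(meta_val, torch.Tensor)  # new condition: both qualify, so both get size/stride annotations
    print(type(meta_val).__name__, old, new)
```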
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 0c5687b..720b808 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -573,14 +573,13 @@
if verbose:
# override annotation with more detailed information
- from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.passes.shape_prop import TensorMetadata
meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None)))
# use string as annotation, to make it valid python code
- if isinstance(meta_val, FakeTensor):
+ if isinstance(meta_val, torch.Tensor):
stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else ""
device_annotation = f"{meta_val.device}" if include_device else ""
maybe_type_annotation = \