[FX] Add type annotation to `getitem` node before `split_module` (#88510)
Summary: Some nodes lost the type annotation during `split_module`, causing the submodels to be un-scriptable. This is because compiler always infer Tensor type, which is wrong for non-Tensor types. We attempt to infer type annotation for `getitem` node to improve scriptability.
Test Plan:
```
buck2 test //caffe2/test:fx_experimental
```
Differential Revision: D41037819
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88510
Approved by: https://github.com/xush6528
diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py
index c6954c2..0343bae 100644
--- a/torch/fx/passes/split_module.py
+++ b/torch/fx/passes/split_module.py
@@ -1,4 +1,5 @@
import inspect
+import operator
from typing import Any, Callable, Dict, List, Optional
import torch
@@ -159,6 +160,25 @@
# split nodes into parititons
for node in m.graph.nodes:
+ # Annotations on local names within function are lost during FX transforms.
+ # Adding back known type annotation for getitem nodes for jit scriptability.
+ if node.target == operator.getitem:
+ sequence_node, index_node = node.args
+ # only support type Tuple for now
+ if (
+ hasattr(sequence_node.type, "_name")
+ and sequence_node.type._name == "Tuple"
+ ):
+ parameterized_types = sequence_node.type.__args__
+ if len(parameterized_types) == 2 and isinstance(
+ parameterized_types[1], type(...)
+ ):
+ node.type = parameterized_types[0]
+ else:
+ assert len(parameterized_types) > index_node
+ node_type = parameterized_types[index_node]
+ node.type = node_type
+
orig_nodes[node.name] = node
# TODO currently placeholders/parameters aren't put into random partitions,
@@ -210,7 +230,10 @@
for partition_name in sorted_partitions:
partition = partitions[partition_name]
for input in partition.inputs:
- placeholder = partition.graph.placeholder(input)
+ placeholder = partition.graph.placeholder(
+ input,
+ type_expr=orig_nodes[input].type,
+ )
placeholder.meta = orig_nodes[input].meta.copy()
partition.environment[orig_nodes[input]] = placeholder
@@ -248,7 +271,11 @@
assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(
- op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs
+ op=node.op,
+ target=target,
+ args=gathered_args,
+ kwargs=gathered_kwargs,
+ type_expr=node.type,
)
new_node.meta = node.meta.copy()
partition.environment[node] = new_node