[FX] Add type annotation to `getitem` node before `split_module` (#88510)

Summary: Some nodes lost the type annotation during `split_module`, causing the submodels to be un-scriptable. This is because compiler always infer Tensor type, which is wrong for non-Tensor types. We attempt to infer type annotation for `getitem` node to improve scriptability.

Test Plan:
```
buck2 test //caffe2/test:fx_experimental
```

Differential Revision: D41037819

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88510
Approved by: https://github.com/xush6528
diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py
index c6954c2..0343bae 100644
--- a/torch/fx/passes/split_module.py
+++ b/torch/fx/passes/split_module.py
@@ -1,4 +1,5 @@
 import inspect
+import operator
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
@@ -159,6 +160,25 @@
 
     # split nodes into parititons
     for node in m.graph.nodes:
+        # Annotations on local names within function are lost during FX transforms.
+        # Adding back known type annotation for getitem nodes for jit scriptability.
+        if node.target == operator.getitem:
+            sequence_node, index_node = node.args
+            # only support type Tuple for now
+            if (
+                hasattr(sequence_node.type, "_name")
+                and sequence_node.type._name == "Tuple"
+            ):
+                parameterized_types = sequence_node.type.__args__
+                if len(parameterized_types) == 2 and isinstance(
+                    parameterized_types[1], type(...)
+                ):
+                    node.type = parameterized_types[0]
+                else:
+                    assert len(parameterized_types) > index_node
+                    node_type = parameterized_types[index_node]
+                    node.type = node_type
+
         orig_nodes[node.name] = node
 
         # TODO currently placeholders/parameters aren't put into random partitions,
@@ -210,7 +230,10 @@
     for partition_name in sorted_partitions:
         partition = partitions[partition_name]
         for input in partition.inputs:
-            placeholder = partition.graph.placeholder(input)
+            placeholder = partition.graph.placeholder(
+                input,
+                type_expr=orig_nodes[input].type,
+            )
             placeholder.meta = orig_nodes[input].meta.copy()
             partition.environment[orig_nodes[input]] = placeholder
 
@@ -248,7 +271,11 @@
             assert isinstance(gathered_args, tuple)
             assert isinstance(gathered_kwargs, dict)
             new_node = partition.graph.create_node(
-                op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs
+                op=node.op,
+                target=target,
+                args=gathered_args,
+                kwargs=gathered_kwargs,
+                type_expr=node.type,
             )
             new_node.meta = node.meta.copy()
             partition.environment[node] = new_node