Fix TRTOperatorSupport (#64873)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64873
Fix TRTOperatorSupport's key naming to match the keys generated by torch.fx.passes.tools_common.get_node_target. splitter_base uses get_node_target to decide, by name, whether an operator is supported.
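For context, a minimal sketch of that name-based lookup (the is_node_supported helper here is illustrative; only get_node_target is the real torch.fx API):
```python
# Illustrative sketch of splitter_base's name-based support check; the
# helper below is hypothetical, but get_node_target is the real API.
from torch.fx.passes.tools_common import get_node_target

def is_node_supported(support_dict, submodules, node):
    # get_node_target returns a string such as "acc_ops.linear", so the
    # support dict must be keyed by the same strings, not by callables.
    return get_node_target(submodules, node) in support_dict
```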
Test Plan:
Print out the supported-operator dict and check the names (a repro sketch follows the log below).
Run TRTSplitter with lrm_split_model_generator and verify that the split result is correct, with all supported operators printed.
Current split result:
```
Supported node types in the model:
acc_ops.size: ((), {'input': torch.float32})
acc_ops.getitem: ((), {'input': torch.float32})
acc_ops.getitem: ((), {'input': None})
acc_ops.reshape: ((), {'input': torch.float32})
acc_ops.unsqueeze: ((), {'input': torch.float32})
acc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32})
acc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})
acc_ops.mul: ((), {'input': torch.float32, 'other': torch.float32})
acc_ops.cat: ((), {})
acc_ops.add: ((), {'input': torch.float32, 'other': torch.float32})
acc_ops.add: ((), {'input': torch.float32})
acc_ops.tanh: ((), {'input': torch.float32})
acc_ops.transpose: ((), {'input': torch.float32})
acc_ops.matmul: ((), {'input': torch.float32, 'other': torch.float32})
acc_ops.div: ((), {'input': torch.float32, 'other': torch.float32})
acc_ops.squeeze: ((), {'input': torch.float32})
acc_ops.noop: ((), {'input': torch.float32})
acc_ops.layer_norm: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})
acc_ops.permute: ((), {'input': torch.float32})
acc_ops.sigmoid: ((), {'input': torch.float32})
acc_ops.flatten: ((), {'input': torch.float32})
acc_ops.softmax: ((), {'input': torch.float32})
acc_ops.sum: ((), {'input': torch.float32})
Unsupported node types in the model:
torch.ops.fb.pad_sequence_embeddings: ((), {'embeddings': torch.float32, 'offsets': torch.int32})
acc_ops.linalg_norm: ((), {'input': torch
```
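A hypothetical snippet for the first test-plan step, printing the support-dict keys (`_support_dict` is the attribute populated in the diff below):
```python
# Hypothetical repro: print the support-dict keys and confirm they are
# strings like "acc_ops.linear" rather than raw callables.
from torch.fx.experimental.fx2trt.tools.trt_splitter import TRTOperatorSupport

op_support = TRTOperatorSupport()
for name in sorted(op_support._support_dict):
    print(name)
```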
Reviewed By: yinghai
Differential Revision: D30884463
fbshipit-source-id: 22442aa6a69cd148ce9bc8be8f62157dd6d19954
diff --git a/torch/fx/experimental/fx2trt/tools/trt_splitter.py b/torch/fx/experimental/fx2trt/tools/trt_splitter.py
index 97af41a..94bea5e 100644
--- a/torch/fx/experimental/fx2trt/tools/trt_splitter.py
+++ b/torch/fx/experimental/fx2trt/tools/trt_splitter.py
@@ -17,7 +17,19 @@
     def __init__(self):
         self._support_dict = {}
         for k in CONVERTERS.keys():
-            self._support_dict[k] = None
+            name = self.get_op_name(k)
+            self._support_dict[name] = None
+
+    def get_op_name(self, k):
+        if isinstance(k, str):
+            return k
+        elif k.__module__ and "acc_ops" in k.__module__:
+            return f"acc_ops.{k.__name__}"
+        else:
+            module = k.__module__
+            return f"{module if module else ''}.{k.__name__}".replace('_', '')
+
+
 class TRTSplitter(splitter_base._SplitterBase):
@@ -32,7 +44,6 @@
         operator_support = TRTOperatorSupport()
         if not settings:
             settings = splitter_base._SplitterSettingBase()
-
         super().__init__(module, sample_input, operator_support, settings)

     def _lower_model_to_backend(
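For reference, a self-contained sketch of the three branches of the new get_op_name; `_op_name` copies the logic as a free function and `_FakeOp` is a hypothetical stand-in for a converter key, so nothing here touches the real CONVERTERS registry:
```python
# Self-contained sketch of the three naming branches in get_op_name;
# _op_name mirrors the new logic and _FakeOp is a hypothetical stand-in
# for a registered converter key.
class _FakeOp:
    def __init__(self, name, module):
        self.__name__ = name
        self.__module__ = module

def _op_name(k):
    if isinstance(k, str):
        return k
    elif k.__module__ and "acc_ops" in k.__module__:
        return f"acc_ops.{k.__name__}"
    else:
        module = k.__module__
        return f"{module if module else ''}.{k.__name__}".replace('_', '')

# String keys pass through unchanged.
assert _op_name("acc_ops.linear") == "acc_ops.linear"
# Callables from an acc_ops module keep their underscores.
assert _op_name(_FakeOp("layer_norm", "my_pkg.acc_ops")) == "acc_ops.layer_norm"
# Other callables become "<module>.<name>" with underscores stripped.
assert _op_name(_FakeOp("pad_sequence", "torch.ops.fb")) == "torch.ops.fb.padsequence"
```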