remove transpose addmm weights hack (#358)
Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/358
### Background
A common pattern when encountering addmm is that the weights are permuted before being given to addmm. This is because, for torch.nn.Linear, the input shape and weight shape are:
```
input: (*, in_features)
weight: (out_features, in_features)
```
while the input shape and weight shape expected by addmm are:
```
input1 (input): (*, in_features)
input2 (weight): (in_features, out_features)
```
So when decomposing nn.Linear into addmm, the weights pass through a permute node to comply with addmm's expected shape.
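The shape mismatch above can be demonstrated directly in PyTorch (a minimal illustration, not ExecuTorch's decomposition code):

```python
import torch

x = torch.randn(4, 3)            # input: (*, in_features=3)
linear = torch.nn.Linear(3, 5)   # weight stored as (out_features=5, in_features=3)

# addmm(bias, input, mat2) computes bias + input @ mat2, so the stored
# weight must first be transposed to (in_features, out_features).
out = torch.addmm(linear.bias, x, linear.weight.t())

# The transposed-weight addmm matches nn.Linear exactly.
assert torch.allclose(out, linear(x), atol=1e-6)
```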
### XNNPACK Status
XNNPACK can handle both the transposed and the normal weight shape; however, it requires a flag indicating whether the weights are transposed. An easy optimization, then, is to skip the permute node and set the flag instead.
### Change and Motivation
Currently, some of this optimization logic is hardcoded directly into serialization. Serialization should not be aware of these optimizations, which is why this change removes that logic from serialization. Instead, the optimization should be performed entirely by the addmm --> linear pass, which recomposes permute + addmm into a single linear. We should no longer rely on serialization to perform this logic (right now it is erroneous and causing a bug).
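For intuition, a recomposition pass of this shape can be sketched with torch.fx. This is a hypothetical illustration of the permute + addmm --> linear rewrite, not the actual ExecuTorch pass (which operates on the Edge dialect):

```python
import torch
import torch.fx as fx

def recompose_addmm_to_linear(gm: fx.GraphModule) -> fx.GraphModule:
    # Hypothetical sketch: find addmm nodes whose second matrix input is a
    # transpose, and rewrite them into a single linear on the original weight.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch.addmm:
            bias, inp, mat2 = node.args
            if isinstance(mat2, fx.Node) and mat2.op == "call_method" and mat2.target == "t":
                with gm.graph.inserting_after(node):
                    linear = gm.graph.call_function(
                        torch.nn.functional.linear,
                        args=(inp, mat2.args[0], bias),  # linear transposes internally
                    )
                node.replace_all_uses_with(linear)
                gm.graph.erase_node(node)
    gm.graph.eliminate_dead_code()  # drops the now-unused transpose node
    gm.recompile()
    return gm

def f(x, w, b):
    return torch.addmm(b, x, w.t())

gm = recompose_addmm_to_linear(fx.symbolic_trace(f))
```

After the pass, the graph computes linear(x, w, b) directly, with no transpose node left for later stages (such as serialization) to reason about.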
Reviewed By: kirklandsign
Differential Revision: D49129704
fbshipit-source-id: 1134c33f76eb27ac05a90b29c6dc057c8c647b58
diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py
index 0d07f09..3589efb 100644
--- a/backends/xnnpack/operators/node_visitor.py
+++ b/backends/xnnpack/operators/node_visitor.py
@@ -289,12 +289,6 @@
# convert tensor shape must reflect memory format, default is contiguous, so
# only permute shape if we are converting the tensor to nhwc format
- if tensor.target in (
- exir_ops.edge.aten.permute_copy.default,
- exir_ops.edge.aten.t_copy.default,
- ):
- # We ignore transpose nodes and reverse the dims to before it
- dims = dims[::-1]
if swap_nc_for_depthwise_weights:
dims = [dims[1], dims[0]] + dims[2:]
if convert_to_nhwc:
diff --git a/backends/xnnpack/operators/op_addmm.py b/backends/xnnpack/operators/op_addmm.py
index bf8804e..dbc38d1 100644
--- a/backends/xnnpack/operators/op_addmm.py
+++ b/backends/xnnpack/operators/op_addmm.py
@@ -22,7 +22,6 @@
from executorch.backends.xnnpack.utils.xnnpack_constants import (
XNN_FLAG_TRANSPOSE_WEIGHTS,
)
-from executorch.exir.dialects._ops import ops as exir_ops
@register_node_visitor
@@ -56,15 +55,7 @@
# output
output_id = vals_to_ids[node]
- flag = (
- 0
- if get_input_node(node, 2).target
- in (
- exir_ops.edge.aten.permute_copy.default,
- exir_ops.edge.aten.t_copy.default,
- )
- else XNN_FLAG_TRANSPOSE_WEIGHTS
- )
+ flag = XNN_FLAG_TRANSPOSE_WEIGHTS
ser_node = XNode(
xnode_union=XNNFullyConnected(