| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| from dataclasses import dataclass |
| |
| import torch |
| from executorch.backends.example.example_operators.op_base import OpBase |
| from executorch.backends.example.example_operators.utils import ( |
| _annotate_nodes, |
| _nodes_are_annotated, |
| ) |
| |
| |
def _annotate_linear(partitions, quant_config):
    """
    Annotate a linear partition with input, weight and output quant specs.

    The node that gets annotated is ``partitions[0].output_nodes[0]``. Its
    first positional arg is treated as the activation input and its second
    as the weight; any bias operand is not annotated by this handler.

    This is what the graph of a simple (decomposed) linear op looks like:
    fn_weight = self.fn_weight
    fn_bias = self.fn_bias
    permute_copy = torch.ops.aten.permute_copy.default(fn_weight, [1, 0]); fn_weight = None
    addmm = torch.ops.aten.addmm.default(fn_bias, arg2_1, permute_copy); fn_bias = arg2_1 = permute_copy = None

    NOTE(review): in the addmm form above the bias is args[0] and the input
    is args[1], which does not match the indexing used below (args[0] as
    input, args[1] as weight). The indexing looks like it targets a
    non-decomposed ``linear(input, weight, bias)`` node -- confirm which
    form the partitioner actually hands to this function.

    Args:
        partitions: sequence of matched partitions; only the first one's
            first output node is used.
        quant_config: object exposing ``input_quant_spec``,
            ``weight_quant_spec`` and ``output_quant_spec``.

    Returns:
        None. The annotation is applied through ``_annotate_nodes``.
    """
    linear_node = partitions[0].output_nodes[0]
    # Nothing to do when the node already carries an annotation.
    if _nodes_are_annotated([linear_node]):
        return

    input_node = linear_node.args[0]
    weight_node = linear_node.args[1]

    # Input-side edges are annotated as (consumer, producer) pairs.
    _annotate_nodes(
        [(linear_node, input_node)], quant_config.input_quant_spec, input_node=True
    )
    _annotate_nodes(
        [(linear_node, weight_node)], quant_config.weight_quant_spec, input_node=True
    )
    # The output annotation only needs the linear node itself.
    _annotate_nodes([(linear_node,)], quant_config.output_quant_spec)
| |
| |
@dataclass
class LinearNode(OpBase):
    """Operator entry pairing the ``torch.nn.Linear`` pattern with ``_annotate_linear``."""

    def __init__(self):
        # Single-module pattern; annotation work is delegated to _annotate_linear.
        linear_pattern = (torch.nn.Linear,)
        super().__init__(pattern=linear_pattern, annotate_handle=_annotate_linear)