blob: ede32a5e6592095c32e041371617d74c431c5624 [file]
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional
import torch
from torch._export.utils import (
get_buffer,
get_lifted_tensor_constant,
get_param,
is_buffer,
is_lifted_tensor_constant,
is_param,
)
def is_parameter(
node: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
return (
is_param(edge_program, node)
or is_buffer(edge_program, node)
or is_lifted_tensor_constant(edge_program, node)
)
def get_parameter(
node: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> torch.Tensor:
param = None
if is_param(edge_program, node):
param = get_param(edge_program, node)
if is_buffer(edge_program, node):
param = get_buffer(edge_program, node)
if is_lifted_tensor_constant(edge_program, node):
param = get_lifted_tensor_constant(edge_program, node)
if param is not None:
# update node.meta["val"] to qualified QNN datatype (e.g. i64 to i32)
assert isinstance(param, torch.Tensor), "Expect parameter to be tensor"
param = param.type(node.meta["val"].dtype)
return param
def set_parameter(
param: torch.Tensor, node: torch.fx.Node, edge_program: torch.export.ExportedProgram
):
status = False
if is_param(edge_program, node):
edge_program.state_dict[
edge_program.graph_signature.inputs_to_parameters[node.name]
] = param
status = True
if is_buffer(edge_program, node):
buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
if buffer_name in edge_program.graph_signature.non_persistent_buffers:
edge_program.constants[buffer_name] = param
else:
edge_program.state_dict[buffer_name] = param
status = True
assert status, "Failed to set parameter"
def is_graph_input(
    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
    """
    Check if the given tensor is a graph input

    Args:
        tensor: EdgeIR Tensor that is being checked for graph input
    """
    # Lifted state (params/buffers/constants) also appears as placeholders,
    # so placeholder alone is not sufficient.
    if tensor.op != "placeholder":
        return False
    return not is_parameter(tensor, edge_program)
def is_graph_output(tensor: torch.fx.Node) -> bool:
    """
    Check if the given tensor is used as a graph output

    Args:
        tensor: EdgeIR Tensor that is being checked for graph output
    """
    for user in tensor.users:
        # A node is an output if it feeds the output node directly, or
        # indirectly through a getitem node (getitem nodes are skipped
        # during lowering, check the op_skip_ops.py).
        if user.op == "output":
            return True
        # user.target can be a plain string (e.g. for call_method nodes),
        # which has no __name__; getattr avoids an AttributeError there.
        if getattr(user.target, "__name__", None) == "getitem" and is_graph_output(
            user
        ):
            return True
    return False
def is_constant(
    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
) -> bool:
    """
    Check if the given tensor is a constant

    Args:
        tensor: EdgeIR Tensor that is being checked for graph input
    """
    # constants should not be treated as input placeholder
    # pay attention to the pytorch design, change this if
    # breakage happened:
    # pytorch/torch/_export/passes/lift_constant_tensor_pass.py
    if not is_parameter(tensor, edge_program):
        return False
    return tensor.meta["val"].constant is not None
def deduce_dtype(
tensor: torch.Tensor, quant_infos: Optional[Dict] = None
) -> torch.dtype:
if quant_infos:
quant_range = quant_infos["quant_max"] - quant_infos["quant_min"]
unsigned = quant_infos["quant_min"] >= 0
if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
return torch.uint8 if unsigned else torch.int8
elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min:
return torch.uint16 if unsigned else torch.int16
return quant_infos["dtype"]
return tensor.dtype