# Source blob: c17f9332e0c5398ce9e6b6646cfee8d5cc73adcf
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
VkDataType,
VkMemoryLayout,
VkStorageType,
)
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
format_target_name,
)
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.tensor import TensorSpec
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter
from torch.export import ExportedProgram
from torch.export.exported_program import InputKind
from torch.export.graph_signature import TensorArgument
# The forms an operator identifier may take in this file: an edge-dialect
# overload, a regular torch overload, or a plain string name.
TorchOpType = Union[EdgeOpOverload, torch._ops.OpOverload, str]
# Formatted target names treated as dequantize ops; is_dequant_node() tests
# membership of format_target_name(node.target.__name__) in this set.
_DQ_OPS = {
"dequantize_per_tensor.tensor",
"dequantize_per_tensor.default",
"dequantize_per_channel.default",
"dequantize_per_channel_group.default",
"dequantize_per_token.default",
"dequantize_affine.default",
}
# Formatted target names treated as quantize ops; is_quant_node() tests
# membership in this set the same way.
_Q_OPS = {
"quantize_per_tensor.tensor",
"quantize_per_tensor.default",
"quantize_per_channel.default",
"quantize_per_token.default",
"quantize_affine.default",
}
# Mapping from torch dtype to the serialized Vulkan dtype. A dtype is considered
# supported by the helpers in this file iff it is a key of this dict.
_VULKAN_DTYPES: Dict[torch.dtype, VkDataType] = {
torch.bool: VkDataType.BOOL,
torch.uint8: VkDataType.UINT8,
torch.int8: VkDataType.INT8,
torch.int32: VkDataType.INT32,
torch.int64: VkDataType.INT64,
torch.float16: VkDataType.FLOAT16,
torch.float32: VkDataType.FLOAT32,
torch.float64: VkDataType.FLOAT64,
}
##
## Dtype sets for per-operator dtype constraints
##
DtypeSet = Set[torch.dtype]

# Primitive dtype groups
FP_T: DtypeSet = {torch.float16, torch.float32}
INT_T: DtypeSet = {torch.int32, torch.int64}
QINT8_T: DtypeSet = {torch.int8}
BOOL_T: DtypeSet = {torch.bool}
# Every dtype that has a Vulkan mapping
ALL_T: DtypeSet = set(_VULKAN_DTYPES)
# Marker for non-tensor args (skip validation)
NONE_T: DtypeSet = set()

# Composite dtype sets for specific operator requirements
FP_INT_T: DtypeSet = FP_T.union(INT_T)
FP_INT_BOOL_T: DtypeSet = FP_T.union(INT_T, BOOL_T)
class DtypeSetList:
    """
    Positional list of DtypeSet constraints with broadcasting.

    When constructed from a single DtypeSet (or a one-element list), that set is
    returned for every positive index. Indexing past the end of a longer list
    yields an empty set, which callers treat as "unconstrained".
    """

    def __init__(self, dtype_sets: Union[DtypeSet, List[DtypeSet]]):
        # A bare set is wrapped in a list; a list is stored as-is (not copied).
        if isinstance(dtype_sets, list):
            self.vals: List[DtypeSet] = dtype_sets
        else:
            self.vals = [dtype_sets]

    def __len__(self) -> int:
        return len(self.vals)

    def __getitem__(self, idx: int) -> DtypeSet:
        count = len(self.vals)
        # Broadcasting: a lone set stands in for every later position
        if count == 1 and idx > 0:
            return self.vals[0]
        # Out-of-range positions carry no constraint
        if idx >= count:
            return set()
        return self.vals[idx]

    def is_empty(self) -> bool:
        """Returns True if no dtype sets are stored at all."""
        return len(self.vals) == 0

    def any_constrained(self) -> bool:
        """Returns True if at least one stored set is non-empty."""
        for dtype_set in self.vals:
            if len(dtype_set) > 0:
                return True
        return False
##
## Node type determination
##
# Convenience type: a single node, or a list/tuple holding any number of nodes.
# `Tuple[X, ...]` is required for a variable-length tuple; `Tuple[X]` would only
# describe a tuple of exactly one element.
MaybeNodeList = Union[torch.fx.Node, List[torch.fx.Node], Tuple[torch.fx.Node, ...]]
def is_torch_op_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the node is a call_function node whose target is an
    EdgeOpOverload or a torch OpOverload.
    """
    return node.op == "call_function" and isinstance(
        node.target, (EdgeOpOverload, torch._ops.OpOverload)
    )
def is_dequant_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the node is a call_function whose formatted target name is
    one of the known dequantize ops (_DQ_OPS).
    """
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return target_name in _DQ_OPS
    return False
def is_quant_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the node is a call_function whose formatted target name is
    one of the known quantize ops (_Q_OPS).
    """
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return target_name in _Q_OPS
    return False
def is_choose_qparams_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the node is a call_function whose formatted target name
    contains "choose_qparams".
    """
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return "choose_qparams" in target_name
    return False
def is_dynamic_qscale(node: Any) -> bool:
    """
    Check if a scale node is dynamically computed via a choose_qparams op, i.e.
    it is a getitem node whose first argument is a choose_qparams node.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    if not node.target == operator.getitem:
        return False
    return is_choose_qparams_node(node.args[0])
def is_dequant_per_channel_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the node is a call_function whose formatted target name is
    exactly "dequantize_per_channel.default".
    """
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return target_name == "dequantize_per_channel.default"
    return False
def is_view_copy_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the node is a call_function whose formatted target name
    contains "view_copy".
    """
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return "view_copy" in target_name
    return False
def is_linear_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the node is a call_function whose formatted target name is
    exactly "linear.default".
    """
    if node.op == "call_function":
        target_name = format_target_name(node.target.__name__)  # pyre-ignore
        return target_name == "linear.default"
    return False
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the argument is an fx Node with op "get_attr". Non-Node
    arguments yield False.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check if the given node is a parameter within the exported program. A node
    qualifies if it is a get_attr node, a parameter, a buffer, or a lifted
    tensor constant; the checks are tried in that order.
    """
    if is_get_attr_node(node):
        return True
    if is_param(program, node):
        return True
    if is_buffer(program, node):
        return True
    return is_lifted_tensor_constant(program, node)
def is_mutable_buffer_node(
node: torch.fx.Node, exported_program: ExportedProgram
) -> bool:
if node.target not in exported_program.graph_signature.inputs_to_buffers:
return False
buf = exported_program.graph_signature.inputs_to_buffers[node.target]
return buf in exported_program.graph_signature.buffers_to_mutate.values()
def is_symint_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node produces a SymInt value, i.e. its "val"
    metadata exists and is a torch.SymInt.
    """
    return "val" in node.meta and isinstance(node.meta["val"], torch.SymInt)
def is_single_tensor_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node produces a single tensor value, i.e. its
    "val" metadata exists and is a FakeTensor.
    """
    return "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
def is_tensor_collection_node(node: Any) -> bool:
    """
    Returns true if the given node produces a collection of tensor values: it
    must be an fx Node whose "val" metadata is a list or tuple containing only
    FakeTensors.
    """
    if not isinstance(node, torch.fx.Node) or "val" not in node.meta:
        return False
    val = node.meta["val"]
    if not isinstance(val, (list, tuple)):
        return False
    return all(isinstance(item, FakeTensor) for item in val)
def is_tensor_node(node: Any) -> bool:
    """
    Returns true if the given node produces a tensor value, or a collection of
    tensor values. The argument must be an fx Node whose "val" metadata is a
    FakeTensor, or a list/tuple containing only FakeTensors.
    """
    if not isinstance(node, torch.fx.Node) or "val" not in node.meta:
        return False
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return True
    if isinstance(val, (list, tuple)):
        return all(isinstance(item, FakeTensor) for item in val)
    return False
def is_tensor_arg_node(node: Any) -> bool:
    """
    Returns true if the argument is a tensor node, or a non-empty list/tuple
    whose elements are all tensor nodes.
    """
    if isinstance(node, torch.fx.Node):
        return is_tensor_node(node)
    if isinstance(node, (list, tuple)):
        # An empty collection is not considered a tensor argument
        return len(node) > 0 and all(is_tensor_node(n) for n in node)
    return False
def num_tensor_arg_nodes(node: torch.fx.Node) -> int:
    """
    For a given node, return the number of direct argument nodes that are
    associated with tensors. Arguments nested inside lists are not counted.
    """
    return sum(
        1
        for arg_node in node.args
        if isinstance(arg_node, torch.fx.Node) and is_tensor_node(arg_node)
    )
def num_tensors_in_node(node: torch.fx.Node) -> int:
    """
    Returns the number of tensors associated with a given node: 1 for a single
    FakeTensor, the collection length for a list/tuple made up only of
    FakeTensors, and 0 otherwise.
    """
    if "val" not in node.meta:
        return 0
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return 1
    holds_only_tensors = isinstance(val, (list, tuple)) and all(
        isinstance(item, FakeTensor) for item in val
    )
    return len(val) if holds_only_tensors else 0
def get_vk_datatype(torch_dtype: torch.dtype) -> VkDataType:
    """
    Returns the Vulkan dtype corresponding to a torch dtype.

    Raises AssertionError if the torch dtype has no Vulkan mapping.
    """
    vk_dtype = _VULKAN_DTYPES.get(torch_dtype)
    if vk_dtype is None:
        raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
    return vk_dtype
def output_dtypes_are_supported(node: torch.fx.Node) -> bool:
    """
    Returns true if every tensor produced by the given node has a dtype that is
    supported by the Vulkan backend. Nodes that do not produce tensors are
    trivially accepted.
    """
    if not is_tensor_node(node):
        return True

    # The val metadata must exist after previous check
    node_val = node.meta.get("val", None)
    assert node_val is not None

    if isinstance(node_val, FakeTensor):
        return node_val.dtype in _VULKAN_DTYPES
    if isinstance(node_val, (list, tuple)):
        return all(tensor.dtype in _VULKAN_DTYPES for tensor in node_val)
    # No tensor dtypes to validate
    return True
def input_dtypes_are_supported(node: torch.fx.Node) -> bool:
    """
    Returns true if all the inputs to the given tensor node have dtypes that are
    supported by the Vulkan backend. Nodes that do not produce tensors are
    trivially accepted.
    """
    if not is_tensor_node(node):
        return True

    for arg_node in node.args:
        # The arg could be a single node, or a list (e.g., first arg of cat);
        # anything else carries no tensor dtype to check.
        if isinstance(arg_node, torch.fx.Node):
            producers = [arg_node]
        elif isinstance(arg_node, (list, tuple)):
            producers = arg_node
        else:
            continue
        if not all(output_dtypes_are_supported(p) for p in producers):
            return False
    return True
def io_dtypes_are_supported(node: torch.fx.Node) -> bool:
    """
    Returns true if all the inputs and outputs of the given tensor node have
    dtypes that are supported by the Vulkan backend. Outputs are checked first.
    """
    return output_dtypes_are_supported(node) and input_dtypes_are_supported(node)
def _first_tensor_dtype(tensor_node: torch.fx.Node) -> Optional[torch.dtype]:
    """
    Return the dtype used to validate a tensor node: the tensor's dtype for a
    single tensor, the first element's dtype for a tensor collection, or None
    when the collection is empty (nothing to validate).
    """
    val = tensor_node.meta["val"]
    if isinstance(val, (list, tuple)):
        return val[0].dtype if len(val) > 0 else None
    return val.dtype


def check_node_dtypes(  # noqa: C901
    node: torch.fx.Node,
    inputs_dtypes: DtypeSetList,
    outputs_dtypes: DtypeSetList,
) -> Tuple[bool, str]:
    """
    Check if all tensor inputs/outputs have dtypes in the allowed sets.

    Args:
        node: the node whose positional args and output "val" are inspected.
        inputs_dtypes: allowed dtypes per positional arg; an empty set at a
            position means that position is not validated.
        outputs_dtypes: allowed dtypes per output; an empty set means the
            output is not validated.

    Returns:
        (is_valid, reason_string) for better error reporting.
    """
    # Check input tensor dtypes
    for i, arg in enumerate(node.args):
        allowed_dtypes = inputs_dtypes[i]
        # Skip non-constrained positions (NONE_T = empty set)
        if len(allowed_dtypes) == 0:
            continue

        if is_tensor_node(arg):
            # is_tensor_node() also accepts an empty collection, which has no
            # dtype to check; _first_tensor_dtype() reports that as None.
            arg_dtype = _first_tensor_dtype(arg)
            if arg_dtype is not None and arg_dtype not in allowed_dtypes:
                return False, f"input[{i}] dtype {arg_dtype} not in {allowed_dtypes}"
        elif isinstance(arg, (list, tuple)):
            # Handle tensor list inputs (e.g., cat)
            for j, sub_arg in enumerate(arg):
                if not is_tensor_node(sub_arg):
                    continue
                # A list element may itself produce a tensor collection, so it
                # goes through the same dtype extraction as a direct arg.
                sub_dtype = _first_tensor_dtype(sub_arg)
                if sub_dtype is not None and sub_dtype not in allowed_dtypes:
                    return (
                        False,
                        f"input[{i}][{j}] dtype {sub_dtype} not in {allowed_dtypes}",
                    )

    # Check output tensor dtypes
    out_val = node.meta.get("val")
    if isinstance(out_val, FakeTensor):
        allowed_dtypes = outputs_dtypes[0]
        if len(allowed_dtypes) > 0 and out_val.dtype not in allowed_dtypes:
            return False, f"output dtype {out_val.dtype} not in {allowed_dtypes}"
    elif isinstance(out_val, (list, tuple)):
        for i, t in enumerate(out_val):
            # Non-tensor outputs are not validated
            if isinstance(t, FakeTensor):
                allowed_dtypes = outputs_dtypes[i]
                if len(allowed_dtypes) > 0 and t.dtype not in allowed_dtypes:
                    return False, f"output[{i}] dtype {t.dtype} not in {allowed_dtypes}"

    return True, "dtypes valid"
def tensor_node_is_bool(node: torch.fx.Node) -> bool:
    """
    Returns true if a given node contains a tensor with bool dtype. For a
    collection, a single bool FakeTensor among the elements is enough. The
    node's "val" metadata must be present.
    """
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return val.dtype == torch.bool
    if isinstance(val, (list, tuple)):
        return any(
            isinstance(item, FakeTensor) and item.dtype == torch.bool for item in val
        )
    return False
def ndim_of(node: Any) -> Optional[int]:
    """
    Returns the number of dimensions of the tensor produced by the given node,
    or None if the node does not produce a single tensor.
    """
    return node.meta["val"].ndim if is_single_tensor_node(node) else None
def is_unsqueezed_vector(node: torch.fx.Node) -> bool:
    """
    Returns True if the node produces a single tensor with at least one
    dimension, where every dimension except the last has size 1. The last
    dimension may have any size.
    """
    if not is_single_tensor_node(node):
        return False

    tensor = node.meta["val"]
    assert isinstance(tensor, FakeTensor)

    shape = tensor.shape
    # Zero-dim tensors are not vectors
    if len(shape) < 1:
        return False
    leading_dims = shape[:-1]
    return all(dim == 1 for dim in leading_dims)
def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
    """
    Returns true if the operator used to compute the given node contains a bool
    tensor, either as its output or as one of its direct arguments.
    """
    # Output first, then the positional args, matching the evaluation order of
    # a sequential scan.
    candidates = (node, *node.args)
    return any(
        # pyre-ignore[6]
        is_tensor_node(candidate) and tensor_node_is_bool(candidate)
        for candidate in candidates
    )
def op_contains_high_dim_tensor(node: torch.fx.Node) -> bool:
    """
    Returns true if the operator used to compute the given node contains a
    tensor with more than 4 dimensions, either as its output or as one of its
    direct arguments.
    """
    # Output first, then the positional args
    candidates = (node, *node.args)
    return any(
        # pyre-ignore[6]
        is_tensor_node(candidate) and tensor_node_is_high_dim(candidate)
        for candidate in candidates
    )
def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]:
    """
    Return the index of the first positional arg of `node` for which
    `self.is_non_constant_tensor_node(arg)` is truthy, or None if there is none.

    NOTE(review): this takes `self` although it sits among module-level
    helpers; the caller must supply an object providing
    `is_non_constant_tensor_node` - confirm intended placement.
    """
    matching_indices = (
        idx
        for idx, arg_node in enumerate(node.args)
        if self.is_non_constant_tensor_node(arg_node)
    )
    return next(matching_indices, None)
def node_comes_from_any_nn_module_in_set(
    node,
    nn_module_typenames: Set[str],
) -> bool:
    """
    Returns true if any entry of the node's "nn_module_stack" metadata has a
    type name containing one of the given partial names. For a list/tuple, every
    element must satisfy the check (an empty collection yields True). Non-Node
    values and nodes without an "nn_module_stack" yield False.
    """
    if isinstance(node, (list, tuple)):
        return all(
            node_comes_from_any_nn_module_in_set(n, nn_module_typenames) for n in node
        )

    if not isinstance(node, torch.fx.Node):
        return False

    module_stack = node.meta.get("nn_module_stack", None)
    if module_stack is None:
        return False

    for packed in module_stack.values():
        # Each stack entry is a pair whose second element is the type name
        _, typename = packed
        if any(partial_name in typename for partial_name in nn_module_typenames):
            return True

    return False
def get_tensor_name(exp_prog: ExportedProgram, node: torch.fx.Node) -> str:
if node is None:
return ""
if is_param(exp_prog, node):
return exp_prog.graph_signature.inputs_to_parameters[node.name]
elif is_buffer(exp_prog, node):
return exp_prog.graph_signature.inputs_to_buffers[node.name]
elif is_lifted_tensor_constant(exp_prog, node):
return exp_prog.graph_signature.inputs_to_lifted_tensor_constants[node.name]
else:
assert isinstance(node.target, str)
return node.target
return ""
def find_dequant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Search the direct users of the given node and return the first one that is a
    dequantization op. Returns None if no dequantization op is found.
    """
    return next((user for user in node.users if is_dequant_node(user)), None)
def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Search the direct users of the given node and return the first one that is a
    quantization op. Returns None if no quantization op is found.
    """
    return next((user for user in node.users if is_quant_node(user)), None)
def node_has_target(node: Any, target: str):
    """
    Returns true if the object's `target` matches the given string. A string
    target is compared directly; otherwise, if the target exposes a `name`
    attribute, the result of calling `target.name()` is compared. Objects
    without a `target` attribute never match.
    """
    if not hasattr(node, "target"):
        return False
    node_target = node.target
    if isinstance(node_target, str):
        return node_target == target
    if hasattr(node_target, "name"):
        return node_target.name() == target
    return False
##
## Memory Layout, Storage Type Determination
##
# Three integer extents of an image texture. required_image_extents() returns
# them in the order (width, height, channels * batch).
ImageExtents = Tuple[int, int, int]
# Default per-axis texture size limits, in the same order as ImageExtents.
DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)
# Default buffer limit (128 * 1024 * 1024).
# NOTE(review): presumably an element count, matching the numel() comparison in
# within_buffer_limit() - confirm against callers.
DEFAULT_BUFFER_LIMIT = 128 * (1024 * 1024)
# Storage types considered by this file.
all_storage_types: Set[VkStorageType] = {
VkStorageType.BUFFER,
VkStorageType.TEXTURE_3D,
}
# Memory layouts available to non-quantized tensors
all_memory_layouts: Set[VkMemoryLayout] = {
VkMemoryLayout.TENSOR_WIDTH_PACKED,
VkMemoryLayout.TENSOR_HEIGHT_PACKED,
VkMemoryLayout.TENSOR_CHANNELS_PACKED,
}
# Memory layouts available to quantized tensors
all_quantized_memory_layouts: Set[VkMemoryLayout] = {
VkMemoryLayout.PACKED_INT8_4W4C,
VkMemoryLayout.PACKED_INT8_4H4W,
VkMemoryLayout.PACKED_INT8_4W,
VkMemoryLayout.PACKED_INT8_4C1W,
VkMemoryLayout.PACKED_INT8_CONV2D,
}
# Union of the non-quantized and quantized layout sets (a new set object).
universal_memory_layout_set: Set[VkMemoryLayout] = (
all_memory_layouts | all_quantized_memory_layouts
)
# Aliases for a set of layouts, and for "one set or one set per position".
MemoryLayoutSet = Set[VkMemoryLayout]
MemoryLayoutSetList = Union[MemoryLayoutSet, List[MemoryLayoutSet]]
# Packed dimension index for every memory layout. The values agree with the
# `packed_dim` that PackedDimInfo.from_repr() assigns to the same layout, and
# every layout handled there has an entry here (including PACKED_INT8_4W) so
# that packed_dim_of() does not fail on a known layout.
_LAYOUT_TO_PACKED_DIM: Dict[VkMemoryLayout, int] = {
    VkMemoryLayout.TENSOR_WIDTH_PACKED: 0,
    VkMemoryLayout.TENSOR_HEIGHT_PACKED: 1,
    VkMemoryLayout.TENSOR_CHANNELS_PACKED: 2,
    VkMemoryLayout.PACKED_INT8_4W: 0,
    VkMemoryLayout.PACKED_INT8_4W4C: 2,
    VkMemoryLayout.PACKED_INT8_4H4W: 0,
    VkMemoryLayout.PACKED_INT8_4C1W: 2,
    VkMemoryLayout.PACKED_INT8_CONV2D: 2,
}


def packed_dim_of(layout: VkMemoryLayout) -> int:
    """
    Returns the packed dimension index of the given memory layout.

    Raises KeyError if the layout has no entry in _LAYOUT_TO_PACKED_DIM.
    """
    return _LAYOUT_TO_PACKED_DIM[layout]
@dataclass(frozen=True)
class PackedDimInfo:
    """
    Describes how tensor data is organized in physical memory, mirroring the
    C++ PackedDimInfo struct in runtime/api/containers/Tensor.h.
    """

    # Index of the packed dimension
    packed_dim: int
    # Block size along the packed dimension
    packed_dim_block_size: int

    @classmethod
    def from_repr(
        cls,
        memory_layout: VkMemoryLayout,
        storage_type: VkStorageType = VkStorageType.BUFFER,
    ) -> "PackedDimInfo":
        """
        Construct a PackedDimInfo based on a memory layout and storage type,
        mirroring calculate_packed_dim_info in runtime/api/containers/Tensor.cpp.

        Raises ValueError if the memory layout is not one of the known layouts.
        """
        uses_buffer = storage_type == VkStorageType.BUFFER
        # Non-quantized layouts: block size 1 for buffers, 4 otherwise
        plain_block = 1 if uses_buffer else 4

        # (layout, packed_dim, packed_dim_block_size), checked in order
        layout_table = (
            (VkMemoryLayout.TENSOR_WIDTH_PACKED, 0, plain_block),
            (VkMemoryLayout.TENSOR_HEIGHT_PACKED, 1, plain_block),
            (VkMemoryLayout.TENSOR_CHANNELS_PACKED, 2, plain_block),
            (VkMemoryLayout.PACKED_INT8_4W, 0, 4),
            (VkMemoryLayout.PACKED_INT8_4W4C, 2, 4),
            (VkMemoryLayout.PACKED_INT8_4H4W, 0, 4),
            (VkMemoryLayout.PACKED_INT8_4C1W, 2, 4 if uses_buffer else 16),
            (VkMemoryLayout.PACKED_INT8_CONV2D, 2, 4),
        )
        for known_layout, dim, block_size in layout_table:
            if memory_layout == known_layout:
                return cls(packed_dim=dim, packed_dim_block_size=block_size)

        raise ValueError(f"Unknown memory layout: {memory_layout}")
def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
    """
    Checks whether the tensors produced by the given node can fit within the device's
    GPU buffer limit, which represents the maximum number of elements that can be stored
    in a GPU buffer.

    Every tensor produced by the node must have strictly fewer elements than
    `buffer_limit`. The node must be a tensor node.

    Raises RuntimeError if the node's "val" is neither a FakeTensor nor a
    list/tuple.
    """
    assert is_tensor_node(node)

    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return val.numel() < buffer_limit
    if isinstance(val, (list, tuple)):
        return all(x.numel() < buffer_limit for x in val)
    raise RuntimeError(f"Cannot get numel for val of type {type(val)}")
def tensor_node_is_high_dim(node: torch.fx.Node) -> bool:
    """
    Returns true if a given node contains a tensor with more than 4 dimensions.
    For a collection, a single such FakeTensor among the elements is enough. The
    node's "val" metadata must be present.
    """
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return len(val.shape) > 4
    if isinstance(val, (list, tuple)):
        return any(
            isinstance(item, FakeTensor) and len(item.shape) > 4 for item in val
        )
    return False
def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
    """
    Calculate the image extents that will be used to represent a tensor with the given sizes
    and memory layout in the Vulkan Delegate.

    Width, height and channels are read from the last three sizes (missing ones
    default to 1); batch is `sizes[0]` when there are at least four sizes. Each
    axis packed by the layout is divided by 4, rounding up. The result is
    (width, height, channels * batch).

    Raises RuntimeError for layouts without an extents rule here (PACKED_INT8_4W
    and PACKED_INT8_4C1W are among those not handled).
    """
    rank = len(sizes)
    width = sizes[-1] if rank >= 1 else 1
    height = sizes[-2] if rank >= 2 else 1
    channels = sizes[-3] if rank >= 3 else 1
    batch = sizes[0] if rank >= 4 else 1

    # Layouts grouped by the axes they pack in blocks of 4.
    # PACKED_INT8_CONV2D uses conservative extents (same as 4W4C) since it is
    # buffer-only.
    packs_width = (
        VkMemoryLayout.TENSOR_WIDTH_PACKED,
        VkMemoryLayout.PACKED_INT8_4W4C,
        VkMemoryLayout.PACKED_INT8_4H4W,
        VkMemoryLayout.PACKED_INT8_CONV2D,
    )
    packs_height = (
        VkMemoryLayout.TENSOR_HEIGHT_PACKED,
        VkMemoryLayout.PACKED_INT8_4H4W,
    )
    packs_channels = (
        VkMemoryLayout.TENSOR_CHANNELS_PACKED,
        VkMemoryLayout.PACKED_INT8_4W4C,
        VkMemoryLayout.PACKED_INT8_CONV2D,
    )

    if layout not in packs_width + packs_height + packs_channels:
        raise RuntimeError(f"Unsupported memory layout {layout}")

    if layout in packs_width:
        width = (width + 3) // 4
    if layout in packs_height:
        height = (height + 3) // 4
    if layout in packs_channels:
        channels = (channels + 3) // 4

    return width, height, channels * batch
def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool:
    """
    Returns true if every extent is less than or equal to the limit at the same
    position.
    """
    return all(extent <= limits[axis] for axis, extent in enumerate(extents))
def valid_texture_memory_layouts(
    tensor_sizes: torch.Size, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given tensor sizes, determine the set of memory layouts which will produce a texture
    that can fit within the specified device limits. Only the non-quantized
    layouts in `all_memory_layouts` are considered.
    """
    return {
        layout
        for layout in all_memory_layouts
        if extents_are_valid(
            required_image_extents(tensor_sizes, layout), texture_limits
        )
    }
class TensorRepr:
    """
    This class is a wrapper around a pair of VkStorageType and VkMemoryLayout which
    describes how a tensor should be represented in the Vulkan Delegate.
    """

    def __init__(self, storage_type: VkStorageType, memory_layout: VkMemoryLayout):
        self.storage_type = storage_type
        self.memory_layout = memory_layout

    def __str__(self) -> str:
        return f"TensorRepr({self.storage_type}, {self.memory_layout})"

    def __eq__(self, other: object) -> bool:
        """Two TensorReprs are equal when both storage type and layout match."""
        if not isinstance(other, TensorRepr):
            return NotImplemented
        return (
            self.storage_type == other.storage_type
            and self.memory_layout == other.memory_layout
        )

    def __ne__(self, other: object) -> bool:
        # Propagate NotImplemented instead of negating it: `not NotImplemented`
        # evaluates to False (and is deprecated), which would make
        # `repr != <foreign object>` wrongly report the operands as equal.
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result
class TensorReprList:
    """
    This class is a wrapper around a list of TensorRepr instances that automatically
    applies a "broadcasting" mechanism. The broadcasting mechanism allows for a single
    underlying TensorRepr to be used to represent multiple tensors: when exactly
    one TensorRepr is stored, any positive index resolves to it.
    """

    def __init__(self, tensor_reprs: Union[TensorRepr, List[TensorRepr]]):
        # A bare TensorRepr is wrapped in a list; a list is stored as-is.
        self.vals: List[TensorRepr] = (
            tensor_reprs if isinstance(tensor_reprs, list) else [tensor_reprs]
        )

    def __len__(self):
        return len(self.vals)

    def __getitem__(self, idx: int) -> TensorRepr:
        # Broadcasting: a lone entry stands in for every later position
        if idx > 0 and len(self) == 1:
            return self.vals[0]
        else:
            return self.vals[idx]

    def __setitem__(self, idx: int, val: TensorRepr) -> None:
        # Writing to a broadcast position replaces the single shared entry
        if idx > 0 and len(self) == 1:
            self.vals[0] = val
        else:
            self.vals[idx] = val

    def __str__(self) -> str:
        return f"[{', '.join(str(ts) for ts in self.vals)}]"

    def __eq__(self, other: object) -> bool:
        """Lists are equal when they have the same length and equal entries."""
        if not isinstance(other, TensorReprList):
            return NotImplemented

        if len(self) == len(other):
            for self_val, other_val in zip(self.vals, other.vals):
                if self_val != other_val:
                    return False
            return True

        return False

    def __ne__(self, other: object) -> bool:
        # Propagate NotImplemented instead of negating it: `not NotImplemented`
        # evaluates to False (and is deprecated), which would make
        # `lst != <foreign object>` wrongly report the operands as equal.
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result

    def append(self, val: TensorRepr) -> None:
        self.vals.append(val)

    def storage_type(self, idx: int = 0) -> VkStorageType:
        """Storage type at `idx`, with the same broadcasting as indexing."""
        # Go through __getitem__ so a single stored entry serves every index
        return self[idx].storage_type

    def memory_layout(self, idx: int = 0) -> VkMemoryLayout:
        """Memory layout at `idx`, with the same broadcasting as indexing."""
        return self[idx].memory_layout
class TensorRepSet:
    """
    This class describes the possible set of representations (i.e. TensorRepr) that may
    be used to represent a tensor. This set is determined by the implementation of the
    operator that the tensor participates in as well as the texture extents of the GPU.

    The set is stored as two sets of memory layouts: those valid for buffer
    storage and those valid for texture storage.
    """

    def __init__(
        self,
        buffer_memory_layouts: Set[VkMemoryLayout],
        texture_memory_layouts: Set[VkMemoryLayout],
    ):
        # The sets are stored by reference, not copied; use copy() to obtain an
        # independent instance.
        self.valid_buffer_layouts = buffer_memory_layouts
        self.valid_texture_layouts = texture_memory_layouts

    def __str__(self) -> str:
        buffer_layouts = ", ".join(layout.name for layout in self.valid_buffer_layouts)
        texture_layouts = ", ".join(
            layout.name for layout in self.valid_texture_layouts
        )
        return f"TensorRepSet(Buffer Layouts: [{buffer_layouts}], Texture Layouts: [{texture_layouts}])"

    def __eq__(self, other: object) -> bool:
        """Equal when both the buffer and the texture layout sets are equal."""
        if not isinstance(other, TensorRepSet):
            return NotImplemented
        return (
            self.valid_buffer_layouts == other.valid_buffer_layouts
            and self.valid_texture_layouts == other.valid_texture_layouts
        )

    def __ne__(self, other: object) -> bool:
        # Propagate NotImplemented instead of negating it: `not NotImplemented`
        # evaluates to False (and is deprecated), which would make
        # `repset != <foreign object>` wrongly report the operands as equal.
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result

    def copy(self) -> "TensorRepSet":
        """Return a new TensorRepSet holding copies of both layout sets."""
        return TensorRepSet(
            set(self.valid_buffer_layouts), set(self.valid_texture_layouts)
        )

    def is_empty(self) -> bool:
        """
        A TensorRepSet is "empty" if there are no valid representations of the tensor.
        """
        return (
            len(self.valid_buffer_layouts) == 0 and len(self.valid_texture_layouts) == 0
        )

    def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet":
        """
        Merge this TensorRepSet with another TensorRepSet, returning a new
        TensorRepSet with the intersection of the two.
        """
        return TensorRepSet(
            self.valid_buffer_layouts & other.valid_buffer_layouts,
            self.valid_texture_layouts & other.valid_texture_layouts,
        )

    def make_union(self, other: "TensorRepSet") -> "TensorRepSet":
        """
        Merge this TensorRepSet with another TensorRepSet, returning a new TensorRepSet
        with the union of the two.
        """
        return TensorRepSet(
            self.valid_buffer_layouts | other.valid_buffer_layouts,
            self.valid_texture_layouts | other.valid_texture_layouts,
        )

    def is_compatible(self, storage: TensorRepr) -> bool:
        """
        Check if the given TensorRepr is one of the representations in this
        TensorRepSet.

        Raises RuntimeError if the storage type is neither BUFFER nor TEXTURE_3D.
        """
        if storage.storage_type == VkStorageType.BUFFER:
            return storage.memory_layout in self.valid_buffer_layouts
        elif storage.storage_type == VkStorageType.TEXTURE_3D:
            return storage.memory_layout in self.valid_texture_layouts
        else:
            raise RuntimeError(f"Unsupported storage type {storage.storage_type}")

    def any_in_common(self, other: "TensorRepSet") -> bool:
        """
        Check if this TensorRepSet has any representations in common with another
        TensorRepSet.
        """
        return (
            len(self.valid_buffer_layouts & other.valid_buffer_layouts) > 0
            or len(self.valid_texture_layouts & other.valid_texture_layouts) > 0
        )

    def texture_is_valid(self):
        """Returns True if at least one texture layout is valid."""
        return len(self.valid_texture_layouts) > 0

    def buffer_is_valid(self):
        """Returns True if at least one buffer layout is valid."""
        return len(self.valid_buffer_layouts) > 0

    def first_valid_buffer_layout(self):
        """Return the first buffer layout in the set's iteration order."""
        return list(self.valid_buffer_layouts)[0]

    def first_valid_texture_layout(self):
        """Return the first texture layout in the set's iteration order."""
        return list(self.valid_texture_layouts)[0]

    def make_tensor_repr(self) -> TensorRepr:
        """
        Pick a representation (i.e. TensorRepr) from the set of possible representations.
        If there are multiple valid representations, then:
        1. Prefer texture storage over buffer storage
        2. Pick the first available memory layout.
        """
        if self.is_empty():
            # An empty repset typically means that it is associated with a weight tensor
            # or non tensor argument. In this case, just return default storage and
            # layout as placeholder.
            return TensorRepr(
                VkStorageType.DEFAULT_STORAGE, VkMemoryLayout.DEFAULT_LAYOUT
            )

        if self.texture_is_valid():
            return TensorRepr(
                VkStorageType.TEXTURE_3D, self.first_valid_texture_layout()
            )

        else:
            return TensorRepr(VkStorageType.BUFFER, self.first_valid_buffer_layout())

    def is_constrained(self) -> bool:
        """
        A "constrained" RepSet is one that has either:
        1. A single valid texture memory layout, and no valid buffer memory layouts
        2. No valid texture memory layouts, and a single valid buffer memory layout
        3. Is empty
        In this case, it is unambiguous which representation should be used for the
        tensor.
        """
        if self.is_empty():
            return True
        elif (
            len(self.valid_texture_layouts) == 1 and len(self.valid_buffer_layouts) == 0
        ):
            return True
        elif (
            len(self.valid_texture_layouts) == 0 and len(self.valid_buffer_layouts) == 1
        ):
            return True
        else:
            return False

    def is_ambiguous(self) -> bool:
        """
        An "ambiguous" RepSet is one that is not constrained.
        """
        return not self.is_constrained()

    def _possible_pdis(self) -> Tuple[Set[PackedDimInfo], Set[PackedDimInfo]]:
        """
        Return a pair (buffer_set, texture_set) holding the PackedDimInfo of
        every valid buffer layout and every valid texture layout respectively.
        """
        buffer_set = set()
        texture_set = set()
        for layout in self.valid_buffer_layouts:
            buffer_set.add(PackedDimInfo.from_repr(layout, VkStorageType.BUFFER))
        for layout in self.valid_texture_layouts:
            texture_set.add(PackedDimInfo.from_repr(layout, VkStorageType.TEXTURE_3D))
        return buffer_set, texture_set

    def has_same_packed_dim_info_set(self, other: "TensorRepSet") -> bool:
        """
        Check if self and other produce the exact same sets of PackedDimInfo
        for both buffer and texture storage types. Completely empty repsets
        (no layouts for any storage type) are treated as matching any other
        repset.
        """
        other_buf_set, other_tex_set = other._possible_pdis()
        buf_set, tex_set = self._possible_pdis()

        # A completely empty repset is compatible with anything
        if not buf_set and not tex_set:
            return True
        if not other_buf_set and not other_tex_set:
            return True

        return other_buf_set == buf_set and other_tex_set == tex_set

    def has_compatible_packed_dim_info_set(self, other: "TensorRepSet") -> bool:
        """
        Check if all PackedDimInfos from other are contained within self's
        PackedDimInfo sets, i.e. self is a superset of other for both buffer
        and texture PDI sets.
        """
        other_buf_set, other_tex_set = other._possible_pdis()
        buf_set, tex_set = self._possible_pdis()

        for pdi in other_buf_set:
            if pdi not in buf_set:
                return False

        for pdi in other_tex_set:
            if pdi not in tex_set:
                return False

        return True

    def constrain_to_compatible_packed_dim(
        self, other: "TensorRepSet"
    ) -> "TensorRepSet":
        """
        Return a new TensorRepSet containing only layouts from self whose
        PackedDimInfo is present in other's PackedDimInfo sets. If other is
        completely empty, return a copy of self unchanged. If other has layouts
        for only one storage type, layouts for the missing storage type are
        also removed.
        """
        other_buf_set, other_tex_set = other._possible_pdis()

        # Completely empty other means no constraint
        if not other_buf_set and not other_tex_set:
            return self.copy()

        # An empty PDI set on other's side rejects every layout of that storage
        new_buf = {
            layout
            for layout in self.valid_buffer_layouts
            if other_buf_set
            and PackedDimInfo.from_repr(layout, VkStorageType.BUFFER) in other_buf_set
        }
        new_tex = {
            layout
            for layout in self.valid_texture_layouts
            if other_tex_set
            and PackedDimInfo.from_repr(layout, VkStorageType.TEXTURE_3D)
            in other_tex_set
        }

        return TensorRepSet(new_buf, new_tex)
def make_tensor_repset(tensor_repr: TensorRepr) -> TensorRepSet:
    """
    Given a TensorRepr, return a TensorRepSet that contains only that TensorRepr.

    Raises RuntimeError if the storage type is neither BUFFER nor TEXTURE_3D.
    """
    storage = tensor_repr.storage_type
    if storage == VkStorageType.BUFFER:
        return TensorRepSet({tensor_repr.memory_layout}, set())
    if storage == VkStorageType.TEXTURE_3D:
        return TensorRepSet(set(), {tensor_repr.memory_layout})
    raise RuntimeError(f"Unsupported storage type {tensor_repr.storage_type}")
def filter_invalid_reprs(
    tensor_val: FakeTensor,
    tensor_repset: TensorRepSet,
    texture_limits: ImageExtents,
) -> TensorRepSet:
    """
    `tensor_val` represents an actual tensor participating in some operator computation.

    `tensor_repset` represents the set of valid tensor representations that may be used
    for that tensor that is supported by the op implementation.

    `texture_limits` represents the maximum texture sizes that is supported by the GPU.

    Given the above, return a new TensorRepSet that contains only texture layouts that
    can be used to produce a valid image texture for the given tensor (i.e. fits within
    texture limits). The buffer layouts are carried over unchanged (the same set
    object is reused, not copied).
    """
    # High dimensional tensors require buffer storage. Decide this before
    # computing any image extents: the texture layouts are dropped regardless,
    # so the extents would be wasted work (and could raise for a layout that
    # has no extents rule).
    if len(tensor_val.shape) > 4:
        return TensorRepSet(tensor_repset.valid_buffer_layouts, set())

    valid_texture_layouts = {
        memory_layout
        for memory_layout in tensor_repset.valid_texture_layouts
        if extents_are_valid(
            required_image_extents(tensor_val.shape, memory_layout), texture_limits
        )
    }
    return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts)
def filter_invalid_reprs_for_node_list(
    arg_repsets: TensorRepSet,
    arg_node: List[torch.fx.Node],
    texture_limits: ImageExtents,
) -> TensorRepSet:
    """
    Apply filter_invalid_reprs across a list of nodes, e.g. the first argument of
    the cat operator, which is a list of nodes.

    A single repset is defined for the whole list and every node is assumed to
    need the same representation, so the result is narrowed node by node: the
    running repset is filtered against each node's tensor value and intersected
    with the outcome.
    """
    shared_repset = arg_repsets
    for node in arg_node:
        assert isinstance(node, torch.fx.Node)
        filtered = filter_invalid_reprs(node.meta["val"], shared_repset, texture_limits)
        shared_repset = shared_repset.make_intersect(filtered)
    return shared_repset
## Convenience TensorRepSet definitions
#
# Every constant below is built as TensorRepSet(buffer_layouts, texture_layouts):
# the first set holds the memory layouts allowed with buffer storage, the second the
# memory layouts allowed with texture storage. An empty set means that storage type
# is not allowed at all.

# Only includes memory layouts that can be used by non-quantized tensors

# Width packed layout, with either buffer or texture storage
CONTIGUOUS_ANY = TensorRepSet(
    {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_WIDTH_PACKED}
)
# Width packed layout, buffer storage only
CONTIGUOUS_BUFFER = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set())
# Texture storage only, restricted to a single packing
WIDTH_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_WIDTH_PACKED})
HEIGHT_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_HEIGHT_PACKED})
CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED})
# Channels packed layout, with either buffer or texture storage
CHANNELS_PACKED_ANY = TensorRepSet(
    {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, {VkMemoryLayout.TENSOR_CHANNELS_PACKED}
)
# Width packed when stored as a buffer, channels packed when stored as a texture
CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER = TensorRepSet(
    {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_CHANNELS_PACKED}
)
# Every layout in `all_memory_layouts` (defined elsewhere in this file), restricted
# to texture storage, buffer storage, or allowed with both
ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts)
ANY_BUFFER = TensorRepSet(all_memory_layouts, set())
ANY_STORAGE = TensorRepSet(all_memory_layouts, all_memory_layouts)
# Only includes memory layouts that can be used by quantized tensors
# (all of these are buffer-only: the texture layout set is empty)
PACKED_INT8_BUFFER = TensorRepSet(all_quantized_memory_layouts, set())
PACKED_INT8_4W4C_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W4C}, set())
PACKED_INT8_4H4W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4H4W}, set())
PACKED_INT8_4W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4W}, set())
PACKED_INT8_4C1W_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_4C1W}, set())
PACKED_INT8_CONV2D_BUFFER = TensorRepSet({VkMemoryLayout.PACKED_INT8_CONV2D}, set())
# NOTE(review): grouped under "channels packed" by name only; the packing semantics
# of these three layouts are not visible here -- confirm against VkMemoryLayout.
PACKED_INT8_CHANNELS_PACKED_BUFFER = TensorRepSet(
    {
        VkMemoryLayout.PACKED_INT8_4W4C,
        VkMemoryLayout.PACKED_INT8_4C1W,
        VkMemoryLayout.PACKED_INT8_CONV2D,
    },
    set(),
)
# Special use RepSets
# Empty repset: no buffer layouts and no texture layouts
NO_STORAGE = TensorRepSet(set(), set())
# Built from `universal_memory_layout_set` (defined elsewhere in this file) for both
# storage types
ALL_STORAGES_REPSET = TensorRepSet(
    universal_memory_layout_set, universal_memory_layout_set
)
class TensorRepSetList:
    """
    This class is a wrapper around a list of TensorRepSet instances that automatically
    applies a "broadcasting" mechanism. The broadcasting mechanism allows for a single
    underlying TensorRepSet to be used for multiple tensors: when the list holds
    exactly one entry, every positive index reads from / writes to that entry.
    """

    def __init__(
        self,
        tensor_repsets: Union[TensorRepSet, List[TensorRepSet]],
    ):
        """
        Wrap either a single TensorRepSet or a list of them. A list is stored as is
        (not copied); a single repset is wrapped in a one element list.
        """
        self.vals: List[TensorRepSet] = (
            tensor_repsets if isinstance(tensor_repsets, list) else [tensor_repsets]
        )

    def __len__(self) -> int:
        """
        Return the number of underlying repsets, without broadcasting.
        """
        return len(self.vals)

    def __getitem__(self, idx: int) -> TensorRepSet:
        """
        Return the repset for position `idx`.

        A single stored repset is broadcast to every positive index. An index that is
        out of range (including any index into an empty list) yields an empty
        TensorRepSet, so that callers can always call TensorRepSet methods such as
        `is_empty()` on the result.
        """
        if idx > 0 and len(self) == 1:
            return self.vals[0]
        if idx >= len(self.vals):
            # Return a fresh empty repset rather than a builtin `set()`; the latter
            # does not match the declared return type and has no TensorRepSet API.
            return TensorRepSet(set(), set())
        return self.vals[idx]

    def __setitem__(self, idx: int, val: TensorRepSet) -> None:
        """
        Store `val` at position `idx`. When only a single repset is stored, writes to
        any positive index are redirected to that single entry.
        """
        if idx > 0 and len(self.vals) == 1:
            self.vals[0] = val
        else:
            self.vals[idx] = val

    def __str__(self) -> str:
        return f"[{', '.join(str(ts) for ts in self.vals)}]"

    def append(self, val: TensorRepSet) -> None:
        """
        Append `val` as a new underlying repset.
        """
        self.vals.append(val)

    def any_is_empty(self) -> bool:
        """
        Return True if no repsets are stored, or if any stored repset is empty.
        """
        if not self.vals:
            return True
        return any(tensor_repset.is_empty() for tensor_repset in self.vals)
class OpRepSets:
    """
    This class is responsible for representing and managing the set of valid tensor
    representations that may be used for all input and output tensors of an operator.
    It is also responsible for maintaining synchronization rules between tensors
    participating in the computation.

    Currently, three synchronization rules exist:
    1. All input tensors must use the same representation (e.g. binary ops)
    2. The "primary" input and output tensors must use the same representation
       (e.g. group norm; the output is a tuple of out, mean, rstd; out must be the same
       representation as the first input x, but mean and rstd may use different
       representations as out)
    3. All output tensors must use the same representation (e.g. choose qparams)

    Note that "primary" input and output tensor refers to the first non-weight input
    tensor and the first output tensor. Note that some operators (such as arange) do
    not have any tensor inputs.

    Currently, the above three synchronization rules are sufficient to describe the
    representation requirements of all ET-VK operators.

    This class also provides utilities to constrain the repsets; when applying the
    constraints, the synchronization rules will be maintained.
    """

    def __init__(  # noqa: C901
        self,
        inputs_repsets: TensorRepSetList,
        outputs_repsets: TensorRepSetList,
        op_node: torch.fx.Node,
        texture_limits: ImageExtents,
    ):
        """
        Derive the synchronization flags and the initial, filtered repsets of every
        argument and output of `op_node`.

        `inputs_repsets` / `outputs_repsets` are the repsets received from the operator
        registration. `texture_limits` is forwarded to the filtering helpers so that
        texture layouts that cannot hold the actual tensors are removed.
        """
        self.op_node = op_node
        # inputs_repset_list is received from the operator registration. If a different
        # repset is defined for each input tensor, then assume that the input tensor
        # representations do not need to be synchronized.
        if len(inputs_repsets) > 1:
            self.sync_args_repr = False
        # Otherwise, default to True
        else:
            self.sync_args_repr = True
        # outputs_repset_list is received from the operator registration. If a different
        # repset is defined for each output tensor, then assume that the output tensor
        # representations do not need to be synchronized.
        if len(outputs_repsets) > 1:
            self.sync_outs_repr = False
        else:
            self.sync_outs_repr = True
        # Try to determine the index of the "primary" argument, i.e. the first non
        # constant tensor argument. For the vast majority of operators with tensor
        # arguments, this will be the first argument.
        self.primary_arg_idx: Optional[int] = None
        for i, arg_node in enumerate(self.op_node.args):
            arg_node_repset = inputs_repsets[i]
            if not is_tensor_arg_node(arg_node):
                continue
            if arg_node_repset is None:
                continue
            if arg_node_repset.is_empty():
                continue
            self.primary_arg_idx = i
            break
        # If the repset of the primary input and the primary output are the same, then
        # assume they need to be the same.
        self.sync_primary_io_repr = self.primary_arg_idx is not None
        if self.primary_arg_idx is not None:
            if inputs_repsets[self.primary_arg_idx] != outputs_repsets[0]:
                self.sync_primary_io_repr = False
        # Now, go through the arguments of the operator and create a filtered repset
        # for each based on the actual tensor value.
        args_repset_list = TensorRepSetList([])
        # Running intersection of the filtered repsets of all constrained tensor args
        common_arg_repset = ALL_STORAGES_REPSET
        for i, arg_node in enumerate(op_node.args):
            arg_repset = inputs_repsets[i]
            # Use ALL_STORAGES_REPSET for non-tensor nodes so they don't cause the op
            # repsets to appear empty
            if not is_tensor_arg_node(arg_node):
                args_repset_list.append(ALL_STORAGES_REPSET)
            # NO_STORAGE is used to denote that an input is either a non tensor arg or
            # a weight tensor that is not prepacked. Similar to the above, use
            # ALL_STORAGES_REPSET in this case.
            elif arg_repset.is_empty():
                args_repset_list.append(ALL_STORAGES_REPSET)
            else:
                assert not arg_repset.is_empty()
                arg_repset = self.filter_invalid_reprs_for_arg(
                    arg_repset, arg_node, texture_limits
                )
                args_repset_list.append(arg_repset)
                common_arg_repset = common_arg_repset.make_intersect(arg_repset)
        # Repeat for output tensors.
        outs_repset_list = TensorRepSetList([])
        # Running intersection of the filtered repsets of all outputs
        common_out_repset = ALL_STORAGES_REPSET
        if num_tensors_in_node(op_node) == 1:
            common_out_repset = filter_invalid_reprs(
                op_node.meta["val"], outputs_repsets[0], texture_limits
            )
            outs_repset_list.append(common_out_repset)
        # Multiple output tensors
        else:
            for i, val in enumerate(op_node.meta["val"]):
                assert isinstance(val, FakeTensor)
                out_repset = filter_invalid_reprs(
                    val, outputs_repsets[i], texture_limits
                )
                outs_repset_list.append(out_repset)
                common_out_repset = common_out_repset.make_intersect(out_repset)
        # Apply synchronization rules between the primary input and output
        primary_repset = NO_STORAGE
        if self.sync_primary_io_repr:
            # When all args (or all outs) are synced, the primary tensor is represented
            # by the common repset; otherwise by its own filtered repset.
            primary_in_repset = (
                common_arg_repset
                if self.sync_args_repr
                else args_repset_list[self.primary_arg_idx]
            )
            primary_out_repset = (
                common_out_repset if self.sync_outs_repr else outs_repset_list[0]
            )
            primary_repset = primary_in_repset.make_intersect(primary_out_repset)
            args_repset_list[self.primary_arg_idx] = primary_repset.copy()
            outs_repset_list[0] = primary_repset.copy()
        # Apply synchronization rules; if either all inputs/outputs must use the same
        # representation, then only use a single underlying repset.
        if self.sync_args_repr:
            common_repset = (
                primary_repset if self.sync_primary_io_repr else common_arg_repset
            )
            for i in range(len(args_repset_list)):
                args_repset_list[i] = common_repset.copy()
        if self.sync_outs_repr:
            common_repset = (
                primary_repset if self.sync_primary_io_repr else common_out_repset
            )
            for i in range(len(outs_repset_list)):
                outs_repset_list[i] = common_repset.copy()
        # Save the resulting repsets
        self.args_repset_list = args_repset_list
        self.outs_repset_list = outs_repset_list
        # Check that synchronization rules are respected.
        self.assert_sync_contraints()

    def __str__(self) -> str:
        return f"OpRepSets(ins={self.args_repset_list}, outs={self.outs_repset_list})"

    def filter_invalid_reprs_for_arg(
        self, arg_repsets: TensorRepSet, arg_node: Any, texture_limits: ImageExtents
    ) -> TensorRepSet:
        """
        Helper function to call filter_invalid_reprs with the tensor value(s) that
        correspond to `arg_node`. Three kinds of argument are handled:
        * a node holding a single tensor
        * a list of nodes that each hold a single tensor
        * the first argument of a getitem node (a node holding multiple tensors)

        Raises NotImplementedError for any other kind of argument.
        """
        if isinstance(arg_node, torch.fx.Node) and is_single_tensor_node(arg_node):
            return filter_invalid_reprs(
                arg_node.meta["val"], arg_repsets, texture_limits
            )
        elif isinstance(arg_node, list) and all(
            is_single_tensor_node(n) for n in arg_node
        ):
            return filter_invalid_reprs_for_node_list(
                arg_repsets, arg_node, texture_limits
            )
        # Special case for getitem; return the repset of the particular val in the
        # list of tensors that is being extracted.
        elif (
            self.op_node.target == operator.getitem and arg_node == self.op_node.args[0]
        ):
            idx = self.op_node.args[1]
            assert isinstance(idx, int)
            return filter_invalid_reprs(
                arg_node.meta["val"][idx], arg_repsets, texture_limits
            )
        raise NotImplementedError(f"Unhandled node type {arg_node}")

    def assert_sync_contraints(self) -> None:
        """
        Assert that every enabled synchronization rule holds for the current repsets.

        For each pair of repsets that must be synchronized (all args, all outs, and/or
        primary arg vs. first output), check that they have compatible packed dim info
        sets. Pairs in which either repset is empty are skipped.
        """
        if self.sync_args_repr:
            for i in range(len(self.args_repset_list)):
                for j in range(i + 1, len(self.args_repset_list)):
                    ri = self.args_repset_list[i]
                    rj = self.args_repset_list[j]
                    if not ri.is_empty() and not rj.is_empty():
                        assert ri.has_compatible_packed_dim_info_set(
                            rj
                        ), f"Synced arg repsets {i} and {j} have incompatible packed dim info: {ri} vs {rj}"
        if self.sync_outs_repr:
            for i in range(len(self.outs_repset_list)):
                for j in range(i + 1, len(self.outs_repset_list)):
                    ri = self.outs_repset_list[i]
                    rj = self.outs_repset_list[j]
                    if not ri.is_empty() and not rj.is_empty():
                        assert ri.has_compatible_packed_dim_info_set(
                            rj
                        ), f"Synced out repsets {i} and {j} have incompatible packed dim info: {ri} vs {rj}"
        if self.sync_primary_io_repr:
            primary_arg = self.args_repset_list[self.primary_arg_idx]
            primary_out = self.outs_repset_list[0]
            if not primary_arg.is_empty() and not primary_out.is_empty():
                assert primary_arg.has_compatible_packed_dim_info_set(
                    primary_out
                ), f"Primary arg and out repsets have incompatible packed dim info: {primary_arg} vs {primary_out}"

    def any_is_empty(self) -> bool:
        """
        Return True if the argument repset list or the output repset list reports that
        any of its repsets is empty.
        """
        return (
            self.args_repset_list.any_is_empty() or self.outs_repset_list.any_is_empty()
        )

    def get_arg_repset(self, i: int) -> TensorRepSet:
        """
        Return the current repset of the i-th argument.
        """
        return self.args_repset_list[i]

    def get_out_repset(self, i: int) -> TensorRepSet:
        """
        Return the current repset of the i-th output.
        """
        return self.outs_repset_list[i]

    def try_constrain_with_arg_repset(
        self, arg_i: int, source_repset: TensorRepSet
    ) -> bool:
        """
        Attempt to constrain the repsets of the tensors participating in this operator
        based on an "existing" repset of an argument. The existing repset can have two
        sources:
        * A representation may have been determined for the argument already from a
          prior operator
        * The output repset of the operator which produces the argument

        If the existing repset of the argument is compatible with the current operator,
        then constrain the repsets of this operator and apply synchronization rules.

        This process tries to minimize the number of transition nodes that will need to
        be inserted by tag_memory_meta_pass.py by maintaining existing representations
        for as long as possible.

        Returns True if the repsets were constrained, and False if nothing was changed
        (the argument repset already equals `source_repset`, has nothing in common with
        it, or the first output is not packed-dim compatible with it while the primary
        input/output are synchronized).
        """
        arg_current_repset = self.args_repset_list[arg_i]
        if arg_current_repset == source_repset:
            return False
        if not arg_current_repset.any_in_common(source_repset):
            return False
        if self.sync_primary_io_repr:
            if not self.get_out_repset(0).has_compatible_packed_dim_info_set(
                source_repset
            ):
                return False
        # If this point is reached, then it is possible to constrain
        narrowed = arg_current_repset.make_intersect(source_repset)
        self.args_repset_list[arg_i] = narrowed
        # Propagate to other synced args via packed-dim compatibility
        if self.sync_args_repr:
            for i in range(len(self.args_repset_list)):
                if i != arg_i:
                    self.args_repset_list[i] = self.args_repset_list[
                        i
                    ].constrain_to_compatible_packed_dim(narrowed)
        # Propagate to output via packed-dim compatibility
        if self.sync_primary_io_repr and (
            arg_i == self.primary_arg_idx or self.sync_args_repr
        ):
            self.outs_repset_list[0] = self.outs_repset_list[
                0
            ].constrain_to_compatible_packed_dim(narrowed)
            # Propagate to other synced outputs via packed-dim compatibility
            if self.sync_outs_repr:
                for i in range(len(self.outs_repset_list)):
                    if i != 0:
                        self.outs_repset_list[i] = self.outs_repset_list[
                            i
                        ].constrain_to_compatible_packed_dim(self.outs_repset_list[0])
        self.assert_sync_contraints()
        return True

    def try_constrain_with_out_repset(self, required_repset: TensorRepSet) -> bool:
        """
        Attempt to constrain the output repsets of the tensors participating in this
        operator based on the repset required by a downstream operator.

        Returns True if the repsets were constrained, and False if nothing was changed
        (the first output repset already equals `required_repset` or has nothing in
        common with it).
        """
        out_current_repset = self.outs_repset_list[0]
        if out_current_repset == required_repset:
            return False
        if not out_current_repset.any_in_common(required_repset):
            return False
        narrowed = out_current_repset.make_intersect(required_repset)
        self.outs_repset_list[0] = narrowed
        # Propagate to other synced outputs via packed-dim compatibility
        if self.sync_outs_repr:
            for i in range(len(self.outs_repset_list)):
                if i != 0:
                    self.outs_repset_list[i] = self.outs_repset_list[
                        i
                    ].constrain_to_compatible_packed_dim(narrowed)
        # Propagate to primary arg via packed-dim compatibility
        if self.sync_primary_io_repr:
            self.args_repset_list[self.primary_arg_idx] = self.args_repset_list[
                self.primary_arg_idx
            ].constrain_to_compatible_packed_dim(narrowed)
            # Propagate to other synced args via packed-dim compatibility
            if self.sync_args_repr:
                for i in range(len(self.args_repset_list)):
                    if i != self.primary_arg_idx:
                        self.args_repset_list[i] = self.args_repset_list[
                            i
                        ].constrain_to_compatible_packed_dim(
                            self.args_repset_list[self.primary_arg_idx]
                        )
        self.assert_sync_contraints()
        return True

    def pick_representations(self) -> Tuple[TensorReprList, TensorReprList]:
        """
        For each tensor participating in the op, pick a representation for it among the
        possible representation sets. Returns the list of picked argument
        representations (one per entry of the op node's args) followed by the list of
        picked output representations (one per output tensor of the op node).
        """
        args_repr_list = TensorReprList([])
        outs_repr_list = TensorReprList([])
        for i in range(len(self.op_node.args)):
            arg_repset = self.args_repset_list[i]
            args_repr_list.append(arg_repset.make_tensor_repr())
        for i in range(num_tensors_in_node(self.op_node)):
            out_repset = self.outs_repset_list[i]
            outs_repr_list.append(out_repset.make_tensor_repr())
        return args_repr_list, outs_repr_list
##
## TensorSpec Utils
##
def has_node_spec_attr(node: torch.fx.Node, attr: str) -> bool:
    """
    Return True if `node` has a "spec" entry in its metadata and that entry exposes
    an attribute named `attr`.
    """
    if "spec" not in node.meta:
        return False
    return hasattr(node.meta["spec"], attr)
def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
    """
    Set attribute `attr` on the TensorSpec(s) stored in `node.meta["spec"]`.

    If the spec is a single TensorSpec, `value` is assigned to it directly. If the
    spec is a list/tuple of TensorSpecs, then:
    * when `value` is a list/tuple of the same length, each entry of `value` is
      assigned to the TensorSpec at the same position
    * otherwise `value` is assigned as is to every TensorSpec in the collection

    Raises:
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        setattr(spec, attr, value)
        return
    if not isinstance(spec, (list, tuple)):
        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")

    # Decide whether `value` carries one entry per tensor or must be broadcast
    one_value_per_tensor = isinstance(value, (list, tuple)) and len(spec) == len(value)
    values = value if one_value_per_tensor else [value] * len(spec)
    for tensor_spec, tensor_value in zip(spec, values):
        assert isinstance(tensor_spec, TensorSpec)
        setattr(tensor_spec, attr, tensor_value)
def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True):
    """
    Read attribute `attr` from the TensorSpec(s) stored in `node.meta["spec"]`.

    For a single TensorSpec, the attribute value is returned, or None if the spec
    does not have it. For a list/tuple of specs, either the value for the first spec
    is returned (`return_first=True`) or a list with one value per spec; a missing
    attribute is reported as None in both cases.

    Raises:
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        return getattr(spec, attr, None)
    if isinstance(spec, (list, tuple)):
        if return_first:
            return getattr(spec[0], attr, None)
        return [getattr(tensor_spec, attr, None) for tensor_spec in spec]
    raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}")
def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
    """
    Return the "vk_storage_type" attribute of the node's spec (of the first spec for
    a collection of specs), or None if the attribute is not set.
    """
    storage_type = get_node_spec_attr(node, "vk_storage_type")
    return storage_type
def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
    """
    Return the "vk_memory_layout" attribute of the node's spec (of the first spec for
    a collection of specs), or None if the attribute is not set.
    """
    memory_layout = get_node_spec_attr(node, "vk_memory_layout")
    return memory_layout
def has_node_repr(node) -> bool:
    """
    Return True if `node` has an "etvk_node_repr" attribute on its spec. If a
    list/tuple of nodes is given, every node in it must have the attribute.
    """
    nodes = node if isinstance(node, (list, tuple)) else [node]
    return all(has_node_spec_attr(n, "etvk_node_repr") for n in nodes)
def set_node_repr(node: torch.fx.Node, node_repr: Union[TensorRepr, TensorReprList]):
    """
    Attach `node_repr` to the spec(s) of `node` as the "etvk_node_repr" attribute.

    A TensorReprList is expanded to one entry per tensor in the node; a single
    TensorRepr is passed through unchanged.
    """
    if not isinstance(node_repr, TensorReprList):
        set_node_spec_attr(node, "etvk_node_repr", node_repr)
        return
    # Convert to a regular list so that `set_node_spec_attr` can attach each entry
    # to a separate TensorSpec
    num_tensors = num_tensors_in_node(node)
    per_tensor_reprs = [node_repr[idx] for idx in range(num_tensors)]
    set_node_spec_attr(node, "etvk_node_repr", per_tensor_reprs)
def get_node_repr(node) -> Union[TensorRepr, TensorReprList]:
    """
    Return the "etvk_node_repr" attribute stored on the spec(s) of `node`.

    Raises:
        NotImplementedError: if `node` is a list/tuple of nodes.
    """
    if not isinstance(node, (list, tuple)):
        return get_node_spec_attr(node, "etvk_node_repr", False)
    raise NotImplementedError("get_node_repr not implemented for list of nodes")
##
## Graph Pattern Matching
##
def maybe_skip_q_dq_arg_chain(
    arg: torch.fx.node.Argument,
) -> Tuple[Optional[torch.fx.Node], Optional[torch.fx.Node], Optional[torch.fx.Node]]:
    """
    Check if the given node argument is the end of a Quantize/Dequantize chain
    produced by the quant workflow, optionally followed by a view copy.

    Returns a `(source, quant_node, dequant_node)` tuple:
    * `(None, None, None)` if `arg` is not a node at all
    * `(arg, None, None)` if `arg` is a node that is not part of a Q/DQ chain
    * otherwise, the tensor that feeds the chain, the quantize node and the
      dequantize node
    """
    if not isinstance(arg, torch.fx.Node):
        return None, None, None

    # The chain may end in the dequant node itself, or in a view copy of it
    ends_q_dq_chain = is_dequant_node(arg) or (
        is_view_copy_node(arg) and is_dequant_node(arg.args[0])  # pyre-ignore[6]
    )
    if not ends_q_dq_chain:
        return arg, None, None

    dequant_node = arg.args[0] if is_view_copy_node(arg) else arg
    quant_node = dequant_node.args[0]  # pyre-ignore[16]
    assert isinstance(quant_node, torch.fx.Node)
    source_arg = quant_node.args[0]
    assert isinstance(source_arg, torch.fx.Node)
    assert isinstance(dequant_node, torch.fx.Node)
    return source_arg, quant_node, dequant_node
def trace_args_until_placeholder(
    node: torch.fx.node.Argument, max_search_depth: int = 4
) -> Tuple[Optional[torch.fx.Node], List[torch.fx.Node]]:
    """
    Follow the first argument of each node, starting at `node`, until a placeholder
    node is reached, taking at most `max_search_depth` steps.

    Returns the placeholder and the list of nodes visited (starting with `node` and
    ending with the placeholder). Returns `(None, [])` when `node` is not a node, or
    when no placeholder is reached: the chain runs into a node without args, into a
    first argument that is not a node, or past the depth limit.
    """
    if not isinstance(node, torch.fx.Node):
        return None, []

    current = node
    visited = [current]
    for _ in range(max_search_depth):
        # Stop at the placeholder, or when there is nothing left to follow
        if current.op == "placeholder" or len(current.args) == 0:
            break
        first_arg = current.args[0]
        if not isinstance(first_arg, torch.fx.Node):
            return None, []
        current = first_arg
        visited.append(current)

    if current.op != "placeholder":
        return None, []
    return current, visited
def is_in_4bit_range(tensor: torch.Tensor) -> bool:
    """
    Return True if `tensor` has dtype int8 or uint8 and all of its values lie in the
    signed 4-bit range [-8, 7].
    """
    if tensor.dtype not in (torch.int8, torch.uint8):
        return False
    if tensor.min().item() < -8:
        return False
    return tensor.max().item() <= 7
def is_in_8bit_range(tensor: torch.Tensor) -> bool:
    """
    Return True if `tensor` has dtype int8 or uint8 and all of its values lie in the
    signed 8-bit range [-128, 127].
    """
    if tensor.dtype not in (torch.int8, torch.uint8):
        return False
    if tensor.min().item() < -128:
        return False
    return tensor.max().item() <= 127
##
## Misc
##
def normalize_dims(dims: Union[int, List[int]], ndim: int) -> Union[int, List[int]]:
    """
    Convert negative dimension indices to their non-negative equivalent by adding
    `ndim`. Non-negative indices are returned unchanged; no range check is done.

    Accepts a single int (returns an int) or a list of ints (returns a new list).
    """
    if isinstance(dims, int):
        return dims + ndim if dims < 0 else dims
    return [dim + ndim if dim < 0 else dim for dim in dims]
def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int:
    """
    Map a dimension index of an `ndim`-dimensional tensor to the index of the same
    dimension when the dimension order is reversed (NCHW order -> WHCN order).

    Negative indices are accepted. The normalized index must fall within [0, ndim).
    """
    # A negative index counts from the end
    normalized_dim = nchw_dim + ndim if nchw_dim < 0 else nchw_dim
    assert normalized_dim >= 0 and normalized_dim < ndim
    # Reversing the dim order mirrors the index around the middle
    return (ndim - 1) - normalized_dim
def get_tensor_val_str(tensor_val: FakeTensor) -> str:
    """
    Return a short "<dtype>: <shape>" description of `tensor_val`.
    """
    dtype = tensor_val.dtype
    shape = tensor_val.shape
    return f"{dtype}: {shape}"
def get_node_val_str(node: torch.fx.Node) -> str:
    """
    Return a string describing the value produced by `node`:
    * for a single tensor node, the dtype/shape description of the tensor
    * for a tensor collection node, a bracketed, comma separated list of the
      dtype/shape descriptions of each tensor
    * otherwise, the string form of `node.meta["val"]`, or of the node itself if it
      has no "val" metadata
    """
    if is_single_tensor_node(node):
        tensor_val = node.meta["val"]
        assert isinstance(tensor_val, FakeTensor)
        return get_tensor_val_str(tensor_val)
    if is_tensor_collection_node(node):
        tensor_vals = node.meta["val"]
        assert isinstance(tensor_vals, (list, tuple))
        return f"[{', '.join(get_tensor_val_str(t) for t in tensor_vals)}]"
    if "val" in node.meta:
        return str(node.meta["val"])
    return str(node)
def get_arg_node_val_str(arg_node: Any) -> str:
    """
    Return a string describing an operator argument. Nodes are described via
    get_node_val_str, lists/tuples are described element by element (recursively)
    inside square brackets, and anything else is converted with str().
    """
    if isinstance(arg_node, torch.fx.Node):
        return get_node_val_str(arg_node)
    if not isinstance(arg_node, (list, tuple)):
        return str(arg_node)
    elements = ", ".join(get_arg_node_val_str(element) for element in arg_node)
    return f"[{elements}]"
def node_io_str(node: torch.fx.Node) -> str:
    """
    Return a one line "<output> = <target>(<arg>, <arg>,  ...)" description of `node`,
    where the output and each positional argument are described by their value
    strings. Keyword arguments are not included.
    """
    target = node.target
    if isinstance(target, EdgeOpOverload):
        target_name = target.__name__
    elif isinstance(target, torch._ops.OpOverload):
        target_name = target.name()
    else:
        target_name = str(target)

    prefix = f"{get_node_val_str(node)} = {target_name}("
    # Every argument is followed by ", ", including the last one
    rendered_args = "".join(get_arg_node_val_str(arg) + ", " for arg in node.args)
    return prefix + rendered_args + " ...)"
def update_program_state_dict(
program: ExportedProgram,
buffer_name: str,
updated_tensor: torch.Tensor,
) -> None:
target_name = None
kind = None
# Iterate over all the tensors in the graph signature, and find
# the one corresponding to the parameter/buffer name
for input_ in program.graph_signature.input_specs:
if (
input_.kind in (InputKind.BUFFER, InputKind.PARAMETER)
and isinstance(input_.arg, TensorArgument)
and input_.arg.name == buffer_name
):
kind = input_.kind
target_name = input_.target
break
# Assert that we found the parameter/buffer
assert (
target_name is not None
), f"could not find {buffer_name} in source program signature"
assert target_name in program.state_dict, f"could not find {target_name}"
if kind == InputKind.PARAMETER:
updated_tensor = torch.nn.Parameter(updated_tensor, requires_grad=False)
# Finally, overwrite the current tensor with updated tensor
program.state_dict[target_name] = updated_tensor
def align_width_and_update_state_dict(
ep: ExportedProgram,
node: torch.fx.Node,
cur_tensor: torch.Tensor,
align_to: int = 4,
force_update: bool = False,
) -> torch.Tensor:
"""
Align the width of the given tensor to the given alignment value and update the
state dict of the program with the aligned tensor.
"""
added_padding = False
cur_width = cur_tensor.shape[-1]
# Only align the width of the tensor if it is not already aligned
if cur_width % align_to != 0:
num_padding = align_to - (cur_width % align_to)
# Align the width of the tensor to the given alignment value
aligned_tensor = torch.nn.functional.pad(
cur_tensor, (0, num_padding)
).contiguous()
added_padding = True
else:
aligned_tensor = cur_tensor
if added_padding or force_update:
update_program_state_dict(ep, node.name, aligned_tensor)
# FakeTensor needs to match updated tensor
cur_fake_tensor = node.meta["val"]
node.meta["val"] = FakeTensorConverter().from_real_tensor(
cur_fake_tensor.fake_mode,
aligned_tensor,
)
return aligned_tensor