# blob: b233c157f67d54fe31de0b1b03b1bb5c1070caa1 [file] [log] [blame]
import torch
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx._symbolic_trace import symbolic_trace
import itertools
import operator
from typing import Dict, List
def get_first_dim(t: torch.Tensor) -> int:
    """
    Return the size of the leading dimension of a tensor.

    This exists as a free function so that it can be used as the target of a
    call_function node by the merge_matmul transformation below:
    torch.Tensor.shape is an attribute and cannot be such a target, and going
    through this helper also avoids a separate getitem op in the graph.

    Arguments:
        t: The tensor whose first dimension is wanted.
    Returns:
        The first entry of t.shape.
    """
    shape = t.shape
    return shape[0]
def legalize_graph(gm: GraphModule):
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.

    This is used by the merge_matmul transformation below, which disturbs the topologically
    sorted order of its input GraphModule, so that this order is restored before further
    transformation.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.
    Raises:
        RuntimeError: If some nodes can never be placed because their inputs never become
            available (for example, a dependency cycle). Without this check the sorting
            loop below would spin forever.
    """
    # Build an adjacency list representation of node dependencies in the graph. This also
    # serves as the set of nodes that still need to be inserted into the new, topologically
    # sorted graph.
    dependencies = {node: node.all_input_nodes.copy() for node in gm.graph.nodes}

    # Construct a new graph that will contain all nodes in topologically sorted order.
    new_graph = Graph()
    value_remap: Dict[Node, Node] = {}

    # Copy over all nodes with no dependencies first, in their original relative order.
    for node, deps in dependencies.items():
        if not deps:
            value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

    # Remove the copied over nodes from the adjacency list.
    for copied_node in value_remap:
        del dependencies[copied_node]

    # While there are still nodes to insert into the new graph:
    while dependencies:
        copied_this_round = []

        # Copy over all nodes whose dependencies already exist in the new graph.
        # value_remap is updated while iterating, so a node whose inputs were copied
        # earlier in this same round is picked up immediately.
        for node, deps in dependencies.items():
            if all(dep in value_remap for dep in deps):
                value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
                copied_this_round.append(node)

        # No progress means the remaining nodes can never be scheduled; fail loudly
        # instead of looping forever.
        if not copied_this_round:
            stuck = ", ".join(str(node) for node in dependencies)
            raise RuntimeError(
                "Unable to topologically sort graph; the following nodes have "
                f"unsatisfiable dependencies (possible cycle): {stuck}"
            )

        # Delete all nodes copied over in this iteration from dependencies.
        for copied_node in copied_this_round:
            del dependencies[copied_node]

    # Replace the old graph with the new, topologically sorted one.
    gm.graph = new_graph
def may_depend_on(a: Node, b: Node, search_depth: int = 6):
    """
    Conservatively decide whether node a depends on node b in a torch.fx.Graph.

    Arguments:
        a: The node that may have a dependency on b.
        b: The node that a may have a dependency on.
        search_depth: For indirect dependencies, how many levels of inputs
                      to walk up from a while looking for b. If the budget
                      runs out before a conclusion is reached, the function
                      makes the conservative assumption that there is a
                      dependency.
    Returns:
        True if a may depend on b, False if it definitely does not.
    """
    # A node is treated as depending on itself.
    if a == b:
        return True

    inputs = a.all_input_nodes

    # A node without inputs cannot depend on anything else.
    if not inputs:
        return False

    # Out of search budget with no verdict: assume a dependency exists.
    if search_depth == 0:
        return True

    # Otherwise a depends on b exactly when one of its inputs may.
    return any(may_depend_on(inp, b, search_depth - 1) for inp in inputs)
def are_nodes_independent(nodes: List[Node]):
    """
    Check if all of the given nodes are pairwise-data independent.

    Arguments:
        nodes: The nodes to check for data dependencies.
    Returns:
        True if no pair in nodes may have a data dependency (in either
        direction, as judged by may_depend_on); False as soon as one pair
        may. An empty or single-element list is trivially independent.
    """
    # Check every unordered pair in both directions; stop at the first possible dependency.
    return not any(
        may_depend_on(first, second) or may_depend_on(second, first)
        for first, second in itertools.combinations(nodes, 2)
    )
def merge_matmul(in_mod: torch.nn.Module):
    """
    A graph transformation that merges matrix multiplication operations that share the same
    right-hand side operand into one large matrix multiplication.

               ____      _________        _________
      ----    |    |    |         |     M|  A * C  |
    M| A  |  T| B  | * K|    C    | =    |---------|
      ---- ,  |    |    |         |     T|  B * C  |
       K       ----      ---------        ---------
                K            R                R

    Arguments:
        in_mod: The module to transform. It is symbolically traced and the traced
            GraphModule is the one that gets rewritten.
    Returns:
        The rewritten GraphModule, in which every group of mutually independent
        torch.matmul calls sharing a RHS operand is replaced by a cat of the LHS
        operands, a single matmul, and a split of the result.

    NOTE(review): the LHS operands are concatenated with torch.cat's default dim and
    the result is split by each operand's first dimension, which presumably assumes
    2-D LHS operands - confirm before applying to batched matmuls.
    """
    gm = symbolic_trace(in_mod)

    # Map from each RHS matmul operand to the matmuls of which it is the RHS.
    # Keys are Nodes, except for get_attr operands which are keyed by attribute name.
    rhs_users: Dict[Node, List[Node]] = {}

    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        # Only matmuls called with exactly two positional operands are considered;
        # anything else (e.g. an operand passed by keyword) is left untouched.
        if len(node.args) != 2:
            continue

        rhs = node.args[1]

        # TODO: Properly handle aliasing caused by get_attr. For now,
        # use the attribute name as the operand if the node is a
        # get_attr.
        rhs_key = rhs.target if rhs.op == "get_attr" else rhs
        rhs_users.setdefault(rhs_key, []).append(node)

    for rhs_key, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        # Merge the matmul.
        # Collect the LHS operands and the single RHS operand. The LHS operands are
        # taken straight from the matmul nodes; a RHS that was keyed by attribute
        # name gets a fresh get_attr node.
        lhs_nodes = [mm.args[0] for mm in mms]
        rhs_node = gm.graph.get_attr(rhs_key) if isinstance(rhs_key, str) else rhs_key

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs_nodes,), {})

        # Multiply the concatenated LHS operands with the one RHS. This will produce
        # the same results as all the individual matmuls involving rhs in the original graph,
        # but they will all be concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs_node), {})

        # Split the result of the merged matmul using the shapes of the LHS operands
        # to ascertain how large each chunk should be.
        merge_mm_sizes = [
            gm.graph.call_function(get_first_dim, (operand,), {}) for operand in lhs_nodes
        ]
        merge_mm_split = gm.graph.call_function(
            torch.split, (merge_mm, merge_mm_sizes), {}
        )
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs_nodes))
        ]

        # Replace all uses of the original, unmerged matmuls with the equivalent
        # split chunk from the merged matmul.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

    # All of the new nodes created above were inserted at the end, so we need to sort
    # the nodes topologically to make sure all definitions precede uses.
    legalize_graph(gm)

    gm.recompile()
    gm.graph.lint()
    return gm