Added utility to count memory reads/writes in Inductor (#89203)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89203
Approved by: https://github.com/jansel, https://github.com/ngimel
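
The new count_bytes_inner hook plugs into compile_fx through its inner_compile
argument; test/inductor/test_perf.py registers it as a "count_bytes_inductor"
dynamo backend. A minimal sketch of querying the counters directly (the
backend wrapper name, the example function, and the fp32 // 4 conversion mirror
the test helpers and are illustrative only, not part of the API):

    import torch
    import torch._dynamo
    from torch._inductor import metrics
    from torch._inductor.compile_fx import compile_fx, count_bytes_inner

    def count_bytes_backend(gm, example_inputs):
        # count_bytes_inner lowers each graph and records its memory traffic
        # instead of generating code.
        return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner)

    @torch._dynamo.optimize(count_bytes_backend)
    def f(x):
        return torch.cat([x, x.cos()])

    metrics.reset()
    f(torch.randn(10, device="cuda"))
    print(metrics.num_bytes_accessed)       # total bytes read + written
    print(metrics.num_bytes_accessed // 4)  # element count, assuming fp32 tensors
    print(metrics.nodes_num_elem)           # per-scheduler-node breakdown
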
diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py
new file mode 100644
index 0000000..d473ff4
--- /dev/null
+++ b/test/inductor/test_perf.py
@@ -0,0 +1,434 @@
+# Owner(s): ["module: inductor"]
+import contextlib
+from unittest.mock import patch
+
+import torch._dynamo
+import torch._inductor.config as config
+from torch._dynamo.optimizations.backends import register_backend
+from torch._inductor import metrics
+from torch._inductor.compile_fx import compile_fx, count_bytes_inner
+from torch.testing._internal.common_utils import (
+ TEST_WITH_ROCM,
+ TestCase as TorchTestCase,
+)
+from torch.testing._internal.inductor_utils import HAS_CUDA
+
+aten = torch.ops.aten
+
+
+@register_backend
+def count_bytes_inductor(gm, example_inputs):
+ return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner)
+
+
+@torch._dynamo.optimize("count_bytes_inductor")
+def f(x):
+ return torch.cat([x, x.cos()])
+
+
+def count_numel(f, *args):
+    """
+    Counts the number of elements read/written, assuming all inputs are fp32
+    (i.e. bytes accessed divided by 4).
+    """
+ metrics.reset()
+ torch._dynamo.optimize("count_bytes_inductor")(f)(*args)
+ print(metrics.nodes_num_elem)
+ return str(metrics.num_bytes_accessed // 4)
+
+
+DEVICE = "cuda"
+
+
+def T(*size, dtype=torch.float32, device=DEVICE):
+ return torch.randn(size, dtype=dtype, device=device)
+
+
+def TI(*size, mx=10, dtype=torch.int32, device=DEVICE):
+ return torch.randint(0, mx, size, dtype=dtype, device=device)
+
+
+class TestCase(TorchTestCase):
+    device = DEVICE
+
+
+class NumBytesMetricTests(TestCase):
+ """
+    Primarily used for sanity testing that the num_bytes_accessed metric is correct.
+ """
+
+ def test_pointwise(self):
+ def f(x):
+ return x.cos()
+
+ inp = (T(10),)
+ self.assertExpectedInline(count_numel(f, *inp), """20""")
+
+ def f(x, y):
+ return x + y
+
+ inp = (T(10), T(10))
+ self.assertExpectedInline(count_numel(f, *inp), """30""")
+
+ def f(x, y):
+ return x + y
+
+ inp = (T(10, 10), T(10))
+ self.assertExpectedInline(count_numel(f, *inp), """210""")
+
+ def f(x):
+ return x + x
+
+ inp = (T(10),)
+ self.assertExpectedInline(count_numel(f, *inp), """20""")
+
+ def f(x):
+ return x + x.t()
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+ def f(a, b, c):
+ return a.cos(), b.sin() + c.sin()
+
+ inp = (T(10), T(10), T(10))
+ self.assertExpectedInline(count_numel(f, *inp), """50""")
+
+ def test_reduction(self):
+ def f(x):
+ return x.sum(dim=1)
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """110""")
+
+ def f(x):
+ return x.sum(dim=0)
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """110""")
+
+ def test_extern(self):
+ def f(x):
+ return torch.mm(x, x)
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+ def f(a, b):
+ return torch.mm(a, b)
+
+ inp = (T(10, 10), T(10, 10))
+ self.assertExpectedInline(count_numel(f, *inp), """300""")
+
+ def f(x):
+ x = x.cos()
+ x = torch.mm(x, x)
+ x = x.cos()
+ return x
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """600""")
+
+ def f(x):
+ a = x.cos()
+ b = x.sin()
+ x = torch.mm(a, b)
+ return x
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """600""")
+
+ def test_cat(self):
+ def f(a, b):
+ return torch.cat([a.sin(), b.sin()])
+
+ inp = (T(10), T(10))
+ self.assertExpectedInline(count_numel(f, *inp), """40""")
+
+ def f(a, b):
+ return torch.cat([a, b])
+
+ inp = (T(10), T(10))
+ self.assertExpectedInline(count_numel(f, *inp), """40""")
+
+ def f(a, b):
+ return torch.cat([a.cos(), b])
+
+ inp = (T(10), T(10))
+ self.assertExpectedInline(count_numel(f, *inp), """40""")
+
+ def f(a):
+ return torch.cat([a.cos(), a.sin()])
+
+ inp = (T(10),)
+ self.assertExpectedInline(count_numel(f, *inp), """30""")
+
+ def test_index(self):
+ def f(a, b):
+ return a[b]
+
+ inp = (T(10), TI(10, mx=10))
+ self.assertExpectedInline(count_numel(f, *inp), """30""")
+
+
+class FusionTests(TestCase):
+ """
+    Tests that the patterns below are fused into a single kernel
+ """
+
+ def test_horizontal_reduction_pointwise(self):
+ def f(a):
+ b = a.sum(dim=1)
+ c = a.cos()
+ return b, c
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """210""")
+
+ def test_horizontal_reduction_reduction(self):
+ def f(a):
+ b = a.sum(dim=1)
+ c = a.amax(dim=1)
+ return b, c
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """120""")
+
+ def test_horizontal_reduction_pointwise2(self):
+ def f(a, b):
+ c = a.sum(dim=1)
+ b = b.cos()
+ return b + c
+
+ inp = (T(10, 10), T(10))
+ self.assertExpectedInline(count_numel(f, *inp), """120""")
+
+ def test_horizontal_reduction_outer_pointwise(self):
+ def f(a, b):
+ c = a.sum(dim=0)
+ b = b.cos()
+ return b + c
+
+ inp = (T(10, 10), T(10))
+ self.assertExpectedInline(count_numel(f, *inp), """120""")
+
+ def test_horizontal_sum_pw_broadcast(self):
+ def f(a, b):
+ a = a.sum(dim=1, keepdim=True)
+ b = b.cos()
+ return a * b
+
+ inp = (T(10, 10), T(10))
+ self.assertExpectedInline(count_numel(f, *inp), """210""")
+
+ def test_vertical_sum_pw(self):
+ def f(a):
+ a = a.cos()
+ a = a.sum(dim=1)
+ return a.cos()
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """110""")
+
+ def test_norm_chain(self):
+ def f(a):
+ b = a.sum(dim=1, keepdim=True)
+ a = a * b
+ b = a.sum(dim=1, keepdim=True)
+ a = a * b
+ b = a.sum(dim=1, keepdim=True)
+ a = a * b
+ return a
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+ def test_softmax_inner(self):
+ def f(a):
+ return torch.softmax(a, dim=1)
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+ def test_layer_norm(self):
+ # TODO: Suboptimal! We shouldn't need to save normalization stats.
+ mod = torch.nn.LayerNorm(10, device=self.device)
+
+ def f(x):
+ return mod(x)
+
+ inp = (T(10, 10),)
+ with torch.no_grad():
+ self.assertExpectedInline(count_numel(f, *inp), """220""")
+
+ def test_double_softmax(self):
+ def f(x):
+ x = torch.softmax(x, dim=1)
+ x = torch.softmax(x, dim=1)
+ return x
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+ def test_softmax_backward(self):
+ def f(grad_out, out):
+ return aten._softmax_backward_data(grad_out, out, 1, torch.float32)
+
+ inp = (T(10, 10), T(10, 10))
+ self.assertExpectedInline(count_numel(f, *inp), """300""")
+
+ def test_neighbor(self):
+ def f(a, b):
+ return ((a - b) ** 2).sum(dim=-1).amax(dim=1)
+
+ inp = (T(10, 1, 4), T(1, 10, 4))
+ self.assertExpectedInline(count_numel(f, *inp), """90""")
+
+ def test_factory_reduction(self):
+ def f():
+ a = torch.ones(10, device=self.device)
+ b = torch.ones(10, 10, device=self.device)
+ return (a + b).sum(dim=-1)
+
+ inp = ()
+ self.assertExpectedInline(count_numel(f, *inp), """10""")
+
+ def test_index_pointwise(self):
+ def f(a, b):
+ return a[b].cos()
+
+ inp = (T(10, 10), TI(20, mx=10))
+ self.assertExpectedInline(count_numel(f, *inp), """320""")
+
+ def test_index_reduction(self):
+ def f(a, b):
+ return a[b].cos().sum(dim=1)
+
+ inp = (T(10, 10), TI(20, mx=10))
+ self.assertExpectedInline(count_numel(f, *inp), """140""")
+
+
+class SchedulerFusionTests(TestCase):
+ """
+    Tests the fusion group creation heuristic, i.e. cases where we can't fuse
+    everything into a single kernel.
+    Disables Inductor rematerialization so the tests are easier to reason about.
+ """
+
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._stack = contextlib.ExitStack()
+ cls._stack.enter_context(patch.object(config, "realize_bytes_threshold", 0))
+
+ @classmethod
+ def tearDownClass(cls):
+ cls._stack.close()
+ super().tearDownClass()
+
+ def test_fusion_choice1(self):
+ # Doesn't matter where we break fusion group here
+ def f(a):
+ c = a.cos()
+ d = torch.mm(c, c)
+ e = c.cos()
+ return d + e
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """700""")
+
+ def test_fusion_choice2(self):
+ # We should materialize e (it's smaller!)
+ # [c, e]: 210, [f]: 210, [d]: 200
+ def f(a):
+ c = a.cos()
+ d = torch.mm(c, c)
+ e = c.sum(dim=1)
+ f = d + e
+ return f
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """620""")
+
+ def test_fusion_choice3(self):
+ # We should materialize e.
+ # [c, e]: 300, [f]: 300, [d]: 200
+ def f(a):
+ c = a.cos()
+ d = torch.mm(c, c)
+ e = c + a
+ f = d + e
+ return f, e
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """800""")
+
+
+class TilingTests(TestCase):
+ def test_tiling_simple(self):
+ def f(a, b):
+ return a + b.t()
+
+ inp = (T(10, 10), T(10, 10))
+ self.assertExpectedInline(count_numel(f, *inp), """300""")
+
+ def f(a, b):
+ return a.t() + b
+
+ inp = (T(10, 10), T(10, 10))
+ self.assertExpectedInline(count_numel(f, *inp), """300""")
+
+ def test_tiling_three(self):
+ def f(a, b, c):
+ return a + b.permute(1, 2, 0) + c.permute(2, 0, 1)
+
+ inp = (T(10, 10, 10), T(10, 10, 10), T(10, 10, 10))
+ self.assertExpectedInline(count_numel(f, *inp), """4000""")
+
+
+# Test cases where we don't do the right thing yet.
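+# NB: not a TestCase subclass, so these are not collected by the test runner.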
+class WouldBeNiceIfItWorked:
+ def test_horizontal(self):
+ def f(a):
+ b = a.sum(dim=0)
+ c = a.cos()
+ return b, c
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """210""")
+
+ # TODO: We aren't fusing outer dim softmaxes
+ def test_softmax_outer(self):
+ def f(a):
+ return torch.softmax(a, dim=0)
+
+ inp = (T(10, 10),)
+ self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+ # TODO: The greedy fusion strategy results in suboptimal grouping
+ @patch.object(config, "realize_bytes_threshold", 0)
+ def test_fusion_choice4(self):
+ def f(a, b, b2):
+ c = a + b
+ d = torch.mm(c, c)
+ e = c + b + b2
+ f = d + e + b2
+ return f, e
+
+ inp = (T(10, 10), T(10, 10, dtype=torch.float16), T(10, 10))
+ self.assertExpectedInline(count_numel(f, *inp), """1000""")
+
+ # TODO: We materialize the intermediate if we don't unroll the reduction
+ def test_neighbor(self):
+ def f(a, b):
+ return ((a - b) ** 2).sum(dim=-1).amax(dim=1)
+
+ inp = (T(10, 1, 8), T(1, 10, 8))
+ self.assertExpectedInline(count_numel(f, *inp), """170""")
+
+
+if __name__ == "__main__":
+ from torch._dynamo.test_case import run_tests
+
+ if HAS_CUDA and not TEST_WITH_ROCM:
+ run_tests(needs="filelock")
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 813daee..c482e55 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -12,7 +12,7 @@
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
-from . import config, overrides
+from . import config, metrics, overrides
from .debug import DebugContext
from .decomposition import select_decomp_table
from .graph import GraphLowering
@@ -84,6 +84,22 @@
@DebugContext.wrap
+def count_bytes_inner(gm, example_inputs, num_fixed=0, **kwargs):
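+    # Lower the FX graph as usual but skip codegen: record how many bytes the
+    # scheduled kernels would read/write, then fall back to running gm eagerly.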
+ shape_env = None
+ for inp in example_inputs:
+ if isinstance(inp, FakeTensor) and inp.fake_mode.shape_env is not None:
+ shape_env = inp.fake_mode.shape_env
+
+ graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
+ with V.set_graph_handler(graph):
+ graph.run(*example_inputs)
+ num_bytes, nodes_num_elem = graph.count_bytes()
+ metrics.num_bytes_accessed += num_bytes
+ metrics.nodes_num_elem += nodes_num_elem
+ return make_boxed_func(gm.forward)
+
+
+@DebugContext.wrap
@torch.utils._python_dispatch._disable_current_modes()
def compile_fx_inner(
gm: torch.fx.GraphModule,
@@ -326,7 +342,11 @@
_graph_counter = itertools.count(0)
-def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]):
+def compile_fx(
+ model_: torch.fx.GraphModule,
+ example_inputs_: List[torch.Tensor],
+ inner_compile=compile_fx_inner,
+):
"""Main entrypoint to a compile given FX graph"""
if not is_aot_autograd_safe_to_run(model_, example_inputs_):
@@ -348,7 +368,7 @@
@dynamo_utils.dynamo_timed
def fw_compiler(model: torch.fx.GraphModule, example_inputs):
fixed = len(example_inputs) - num_example_inputs
- return compile_fx_inner(
+ return inner_compile(
model,
example_inputs,
num_fixed=fixed,
@@ -359,7 +379,7 @@
@dynamo_utils.dynamo_timed
def bw_compiler(model: torch.fx.GraphModule, example_inputs):
fixed = count_tangents(model)
- return compile_fx_inner(
+ return inner_compile(
model,
example_inputs,
num_fixed=fixed,
diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py
index 27c92f8..5434d7a 100644
--- a/torch/_inductor/dependencies.py
+++ b/torch/_inductor/dependencies.py
@@ -9,7 +9,14 @@
from . import config
from .codegen.common import index_prevent_reordering
-from .utils import sympy_product, sympy_str, sympy_subs, sympy_symbol, VarRanges
+from .utils import (
+ get_dtype_size,
+ sympy_product,
+ sympy_str,
+ sympy_subs,
+ sympy_symbol,
+ VarRanges,
+)
from .virtualized import V
log = logging.getLogger(__name__)
@@ -69,11 +76,18 @@
return MemoryDep(renames[self.name], self.index, self.size)
return self
- def numel_hint(self):
+ def numbytes_hint(self):
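+        # Only the canonicalized size vars that actually appear in the index
+        # expression contribute to the bytes this dep touches.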
vars = set(self.index.free_symbols)
+ size_vars_used = []
+ for var in vars:
+ if var.name.startswith(canonicalization_prefix()):
+ # Sometimes with indirect indexing we have very weird symbol names
+ assert " " not in var.name
+ size_vars_used.append(int(var.name[len(canonicalization_prefix()) :]))
+
return V.graph.sizevars.size_hint(
- sympy_product([s for s in self.size if s in vars])
- )
+ sympy_product([self.size[i] for i in size_vars_used])
+ ) * get_dtype_size(V.graph.get_dtype(self.name))
def is_contiguous(self) -> bool:
return isinstance(self.index, (sympy.Symbol, sympy.Integer))
@@ -88,8 +102,21 @@
return StarDep(renames[self.name])
return self
- def numel_hint(self):
- return 1
+ def numbytes_hint(self):
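+        # A StarDep depends on the whole buffer, so count its full size in bytes.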
+ from .ir import MultiOutputLayout
+
+ if self.name in V.graph.name_to_buffer:
+ buf = V.graph.name_to_buffer[self.name]
+ elif self.name in V.graph.graph_inputs:
+ buf = V.graph.graph_inputs[self.name]
+ else:
+ return 1
+ if hasattr(buf, "layout") and isinstance(buf.layout, MultiOutputLayout):
+ # NB: Too annoying to acquire, should only be used for instrumentation
+ return 1
+ return V.graph.sizevars.size_hint(
+ sympy_product(buf.get_size())
+ ) * get_dtype_size(buf.get_dtype())
def is_contiguous(self) -> bool:
return False
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 5114ffa..a47d9c1 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -27,7 +27,7 @@
needs_realized_inputs,
)
from .sizevars import SizeVarAllocator
-from .utils import dynamo_utils, gather_origins
+from .utils import dynamo_utils, gather_origins, get_dtype_size, sympy_product
from .virtualized import V
log = logging.getLogger(__name__)
@@ -356,6 +356,47 @@
self.scheduler.codegen()
return self.wrapper_code.generate()
+ def count_bytes(self):
+ from .scheduler import FusedSchedulerNode, NopKernelSchedulerNode, Scheduler
+
+ scheduler = Scheduler(self.buffers)
+
+ def get_read_write_buffers_sizes(node):
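+            # Sum the sizes (in bytes) of every buffer this scheduler node reads or writes.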
+ if isinstance(node, NopKernelSchedulerNode):
+ return 0
+ reads = set(dep.name for dep in node.read_writes.reads)
+ writes = set(dep.name for dep in node.read_writes.writes)
+
+ def is_materialized(buf):
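+                # A write only costs memory traffic if some user of the buffer
+                # lives outside this fusion group.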
+ buf_uses = set(
+ [user.node for user in scheduler.name_to_node[buf].users]
+ )
+ return len(buf_uses - set(node.snodes)) > 0
+
+ if isinstance(node, FusedSchedulerNode):
+ writes = set([dep for dep in writes if is_materialized(dep)])
+ node_bytes = 0
+ for buf in reads | writes:
+ if buf in self.name_to_buffer:
+ buf = self.name_to_buffer[buf]
+ elif buf in self.graph_inputs:
+ buf = self.graph_inputs[buf]
+ else:
+ continue
+
+ node_bytes += V.graph.sizevars.size_hint(
+ sympy_product(buf.get_size())
+ ) * get_dtype_size(buf.get_dtype())
+ return node_bytes
+
+ total_bytes = 0
+ node_counts = []
+ for node in scheduler.nodes:
+ num_bytes = get_read_write_buffers_sizes(node)
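+            # NB: // 4 converts bytes to an element count assuming fp32 data.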
+ node_counts.append((node, num_bytes // 4))
+ total_bytes += num_bytes
+ return total_bytes, node_counts
+
@dynamo_utils.dynamo_timed
def compile_to_module(self):
from .codecache import PyCodeCache
diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py
index 582c5ac..f7e0528 100644
--- a/torch/_inductor/metrics.py
+++ b/torch/_inductor/metrics.py
@@ -1,12 +1,17 @@
# counter for tracking how many kernels have been generated
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
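+# counters for the memory traffic recorded by count_bytes_inner (compile_fx.py)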
+num_bytes_accessed = 0
+nodes_num_elem = []
# reset all counters
def reset():
global generated_kernel_count
global generated_cpp_vec_kernel_count
+ global num_bytes_accessed, nodes_num_elem
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
+ num_bytes_accessed = 0
+ nodes_num_elem.clear()
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index cb71a44..8609617 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -977,7 +977,7 @@
common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
node2.read_writes.reads | node2.read_writes.writes
)
- return sum(dep.numel_hint() for dep in common_memory_deps)
+ return sum(dep.numbytes_hint() for dep in common_memory_deps)
def score_fusion_key(self, nodes):
"""
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 08e95b9..62357be 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -375,3 +375,8 @@
def argsort(seq):
return sorted(range(len(seq)), key=seq.__getitem__)
+
+
+@functools.lru_cache(8)
+def get_dtype_size(dtype):
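+    # Size in bytes of a single element of `dtype`; cached since dtypes are few.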
+ return torch.empty((), dtype=dtype).element_size()
diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py
index 5d40d05..27e60b1 100644
--- a/torch/_inductor/virtualized.py
+++ b/torch/_inductor/virtualized.py
@@ -74,7 +74,7 @@
@staticmethod
def indirect_indexing(index_var):
- return sympy_symbol(str(index_var))
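+        # Wrap the (possibly compound) expression in parentheses when turning it
+        # into a symbol, so indirect-indexing symbols stay distinguishable from
+        # canonicalized size variables in numbytes_hint.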
+ return sympy_symbol(f"({str(index_var)})")
@classmethod
def _init_cls(cls):