Added utility to count memory reads/writes in Inductor (#89203)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89203
Approved by: https://github.com/jansel, https://github.com/ngimel
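
Adds a way to measure how many bytes Inductor kernels read and write,
without generating or running any code. compile_fx() gains an
inner_compile hook, and the new count_bytes_inner() lowers the graph as
usual but, instead of compiling it, records the bytes accessed by each
scheduled node into torch._inductor.metrics (num_bytes_accessed, with a
per-node breakdown in nodes_num_elem). Memory dependencies now expose a
dtype-aware numbytes_hint() (replacing numel_hint()), and
GraphLowering.count_bytes() does the per-node accounting.

A minimal usage sketch, mirroring the helper in test/inductor/test_perf.py
(assumes a CUDA device and fp32 inputs, so bytes // 4 is an element count):

    import torch
    import torch._dynamo
    from torch._dynamo.optimizations.backends import register_backend
    from torch._inductor import metrics
    from torch._inductor.compile_fx import compile_fx, count_bytes_inner

    @register_backend
    def count_bytes_inductor(gm, example_inputs):
        # Go through the normal Inductor frontend, but only count bytes.
        return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner)

    def count_numel(f, *args):
        metrics.reset()
        torch._dynamo.optimize("count_bytes_inductor")(f)(*args)
        return metrics.num_bytes_accessed // 4

    x = torch.randn(10, 10, device="cuda")
    # cos() fuses into the reduction: 100 elements read, 10 written -> 110
    print(count_numel(lambda t: t.cos().sum(dim=1), x))
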
diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py
new file mode 100644
index 0000000..d473ff4
--- /dev/null
+++ b/test/inductor/test_perf.py
@@ -0,0 +1,434 @@
+# Owner(s): ["module: inductor"]
+import contextlib
+from unittest.mock import patch
+
+import torch._dynamo
+import torch._inductor.config as config
+from torch._dynamo.optimizations.backends import register_backend
+from torch._inductor import metrics
+from torch._inductor.compile_fx import compile_fx, count_bytes_inner
+from torch.testing._internal.common_utils import (
+    TEST_WITH_ROCM,
+    TestCase as TorchTestCase,
+)
+from torch.testing._internal.inductor_utils import HAS_CUDA
+
+aten = torch.ops.aten
+
+
+@register_backend
+def count_bytes_inductor(gm, example_inputs):
+    return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner)
+
+
+@torch._dynamo.optimize("count_bytes_inductor")
+def f(x):
+    return torch.cat([x, x.cos()])
+
+
+def count_numel(f, *args):
+    """
+    Assumes all inputs are fp32
+    """
+    metrics.reset()
+    torch._dynamo.optimize("count_bytes_inductor")(f)(*args)
+    print(metrics.nodes_num_elem)
+    return str(metrics.num_bytes_accessed // 4)
+
+
+DEVICE = "cuda"
+
+
+def T(*size, dtype=torch.float32, device=DEVICE):
+    return torch.randn(size, dtype=dtype, device=device)
+
+
+def TI(*size, mx=10, dtype=torch.int32, device=DEVICE):
+    return torch.randint(0, mx, size, dtype=dtype, device=device)
+
+
+class TestCase(TorchTestCase):
+    device = DEVICE
+    pass
+
+
+class NumBytesMetricTests(TestCase):
+    """
+    Primarily used for sanity testing that the num_bytes_accessed metric is correct.
+    """
+
+    def test_pointwise(self):
+        def f(x):
+            return x.cos()
+
+        inp = (T(10),)
+        self.assertExpectedInline(count_numel(f, *inp), """20""")
+
+        def f(x, y):
+            return x + y
+
+        inp = (T(10), T(10))
+        self.assertExpectedInline(count_numel(f, *inp), """30""")
+
+        def f(x, y):
+            return x + y
+
+        inp = (T(10, 10), T(10))
+        self.assertExpectedInline(count_numel(f, *inp), """210""")
+
+        def f(x):
+            return x + x
+
+        inp = (T(10),)
+        self.assertExpectedInline(count_numel(f, *inp), """20""")
+
+        def f(x):
+            return x + x.t()
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+        def f(a, b, c):
+            return a.cos(), b.sin() + c.sin()
+
+        inp = (T(10), T(10), T(10))
+        self.assertExpectedInline(count_numel(f, *inp), """50""")
+
+    def test_reduction(self):
+        def f(x):
+            return x.sum(dim=1)
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """110""")
+
+        def f(x):
+            return x.sum(dim=0)
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """110""")
+
+    def test_extern(self):
+        def f(x):
+            return torch.mm(x, x)
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+        def f(a, b):
+            return torch.mm(a, b)
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """300""")
+
+        def f(x):
+            x = x.cos()
+            x = torch.mm(x, x)
+            x = x.cos()
+            return x
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """600""")
+
+        def f(x):
+            a = x.cos()
+            b = x.sin()
+            x = torch.mm(a, b)
+            return x
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """600""")
+
+    def test_cat(self):
+        def f(a, b):
+            return torch.cat([a.sin(), b.sin()])
+
+        inp = (T(10), T(10))
+        self.assertExpectedInline(count_numel(f, *inp), """40""")
+
+        def f(a, b):
+            return torch.cat([a, b])
+
+        inp = (T(10), T(10))
+        self.assertExpectedInline(count_numel(f, *inp), """40""")
+
+        def f(a, b):
+            return torch.cat([a.cos(), b])
+
+        inp = (T(10), T(10))
+        self.assertExpectedInline(count_numel(f, *inp), """40""")
+
+        def f(a):
+            return torch.cat([a.cos(), a.sin()])
+
+        inp = (T(10),)
+        self.assertExpectedInline(count_numel(f, *inp), """30""")
+
+    def test_index(self):
+        def f(a, b):
+            return a[b]
+
+        inp = (T(10), TI(10, mx=10))
+        self.assertExpectedInline(count_numel(f, *inp), """30""")
+
+
+class FusionTests(TestCase):
+    """
+    Tests that things can be fused into a single kernel
+    """
+
+    def test_horizontal_reduction_pointwise(self):
+        def f(a):
+            b = a.sum(dim=1)
+            c = a.cos()
+            return b, c
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """210""")
+
+    def test_horizontal_reduction_reduction(self):
+        def f(a):
+            b = a.sum(dim=1)
+            c = a.amax(dim=1)
+            return b, c
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """120""")
+
+    def test_horizontal_reduction_pointwise2(self):
+        def f(a, b):
+            c = a.sum(dim=1)
+            b = b.cos()
+            return b + c
+
+        inp = (T(10, 10), T(10))
+        self.assertExpectedInline(count_numel(f, *inp), """120""")
+
+    def test_horizontal_reduction_outer_pointwise(self):
+        def f(a, b):
+            c = a.sum(dim=0)
+            b = b.cos()
+            return b + c
+
+        inp = (T(10, 10), T(10))
+        self.assertExpectedInline(count_numel(f, *inp), """120""")
+
+    def test_horizontal_sum_pw_broadcast(self):
+        def f(a, b):
+            a = a.sum(dim=1, keepdim=True)
+            b = b.cos()
+            return a * b
+
+        inp = (T(10, 10), T(10))
+        self.assertExpectedInline(count_numel(f, *inp), """210""")
+
+    def test_vertical_sum_pw(self):
+        def f(a):
+            a = a.cos()
+            a = a.sum(dim=1)
+            return a.cos()
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """110""")
+
+    def test_norm_chain(self):
+        def f(a):
+            b = a.sum(dim=1, keepdim=True)
+            a = a * b
+            b = a.sum(dim=1, keepdim=True)
+            a = a * b
+            b = a.sum(dim=1, keepdim=True)
+            a = a * b
+            return a
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+    def test_softmax_inner(self):
+        def f(a):
+            return torch.softmax(a, dim=1)
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+    def test_layer_norm(self):
+        # TODO: Suboptimal! We shouldn't need to save normalization stats.
+        mod = torch.nn.LayerNorm(10, device=self.device)
+
+        def f(x):
+            return mod(x)
+
+        inp = (T(10, 10),)
+        with torch.no_grad():
+            self.assertExpectedInline(count_numel(f, *inp), """220""")
+
+    def test_double_softmax(self):
+        def f(x):
+            x = torch.softmax(x, dim=1)
+            x = torch.softmax(x, dim=1)
+            return x
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+    def test_softmax_backward(self):
+        def f(grad_out, out):
+            return aten._softmax_backward_data(grad_out, out, 1, torch.float32)
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """300""")
+
+    def test_neighbor(self):
+        def f(a, b):
+            return ((a - b) ** 2).sum(dim=-1).amax(dim=1)
+
+        inp = (T(10, 1, 4), T(1, 10, 4))
+        self.assertExpectedInline(count_numel(f, *inp), """90""")
+
+    def test_factory_reduction(self):
+        def f():
+            a = torch.ones(10, device=self.device)
+            b = torch.ones(10, 10, device=self.device)
+            return (a + b).sum(dim=-1)
+
+        inp = ()
+        self.assertExpectedInline(count_numel(f, *inp), """10""")
+
+    def test_index_pointwise(self):
+        def f(a, b):
+            return a[b].cos()
+
+        inp = (T(10, 10), TI(20, mx=10))
+        self.assertExpectedInline(count_numel(f, *inp), """320""")
+
+    def test_index_reduction(self):
+        def f(a, b):
+            return a[b].cos().sum(dim=1)
+
+        inp = (T(10, 10), TI(20, mx=10))
+        self.assertExpectedInline(count_numel(f, *inp), """140""")
+
+
+class SchedulerFusionTests(TestCase):
+    """
+    Tests the fusion group creation heuristic (i.e. cases where we can't fuse
+    everything into a single kernel).
+    Disables inductor rematerialization to make the tests easier to reason about.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls._stack = contextlib.ExitStack()
+        cls._stack.enter_context(patch.object(config, "realize_bytes_threshold", 0))
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._stack.close()
+        super().tearDownClass()
+
+    def test_fusion_choice1(self):
+        # Doesn't matter where we break fusion group here
+        def f(a):
+            c = a.cos()
+            d = torch.mm(c, c)
+            e = c.cos()
+            return d + e
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """700""")
+
+    def test_fusion_choice2(self):
+        # We should materialize e (it's smaller!)
+        # [c, e]: 210, [f]: 210, [d]: 200
+        def f(a):
+            c = a.cos()
+            d = torch.mm(c, c)
+            e = c.sum(dim=1)
+            f = d + e
+            return f
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """620""")
+
+    def test_fusion_choice3(self):
+        # We should materialize e.
+        # [c, e]: 300, [f]: 300, [d]: 200
+        def f(a):
+            c = a.cos()
+            d = torch.mm(c, c)
+            e = c + a
+            f = d + e
+            return f, e
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """800""")
+
+
+class TilingTests(TestCase):
+    def test_tiling_simple(self):
+        def f(a, b):
+            return a + b.t()
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """300""")
+
+        def f(a, b):
+            return a.t() + b
+
+        inp = (T(10, 10), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """300""")
+
+    def test_tiling_three(self):
+        def f(a, b, c):
+            return a + b.permute(1, 2, 0) + c.permute(2, 0, 1)
+
+        inp = (T(10, 10, 10), T(10, 10, 10), T(10, 10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """4000""")
+
+
+# Test cases where we don't do the right thing yet.
+class WouldBeNiceIfItWorked:
+    def test_horizontal(self):
+        def f(a):
+            b = a.sum(dim=0)
+            c = a.cos()
+            return b, c
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """210""")
+
+    # TODO: We aren't fusing outer dim softmaxes
+    def test_softmax_outer(self):
+        def f(a):
+            return torch.softmax(a, dim=0)
+
+        inp = (T(10, 10),)
+        self.assertExpectedInline(count_numel(f, *inp), """200""")
+
+    # TODO: The greedy fusion strategy results in suboptimal grouping
+    @patch.object(config, "realize_bytes_threshold", 0)
+    def test_fusion_choice4(self):
+        def f(a, b, b2):
+            c = a + b
+            d = torch.mm(c, c)
+            e = c + b + b2
+            f = d + e + b2
+            return f, e
+
+        inp = (T(10, 10), T(10, 10, dtype=torch.float16), T(10, 10))
+        self.assertExpectedInline(count_numel(f, *inp), """1000""")
+
+    # TODO: We materialize the intermediate if we don't unroll the reduction
+    def test_neighbor(self):
+        def f(a, b):
+            return ((a - b) ** 2).sum(dim=-1).amax(dim=1)
+
+        inp = (T(10, 1, 8), T(1, 10, 8))
+        self.assertExpectedInline(count_numel(f, *inp), """170""")
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    if HAS_CUDA and not TEST_WITH_ROCM:
+        run_tests(needs="filelock")
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 813daee..c482e55 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -12,7 +12,7 @@
 import torch.fx
 from torch._subclasses.fake_tensor import FakeTensor
 
-from . import config, overrides
+from . import config, metrics, overrides
 from .debug import DebugContext
 from .decomposition import select_decomp_table
 from .graph import GraphLowering
@@ -84,6 +84,22 @@
 
 
 @DebugContext.wrap
+def count_bytes_inner(gm, example_inputs, num_fixed=0, **kwargs):
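+    # Lower the FX graph but, instead of generating code, tally the bytes
+    # read/written by the scheduled nodes into the global metrics and return
+    # the unmodified eager forward.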
+    shape_env = None
+    for inp in example_inputs:
+        if isinstance(inp, FakeTensor) and inp.fake_mode.shape_env is not None:
+            shape_env = inp.fake_mode.shape_env
+
+    graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed)
+    with V.set_graph_handler(graph):
+        graph.run(*example_inputs)
+        num_bytes, nodes_num_elem = graph.count_bytes()
+        metrics.num_bytes_accessed += num_bytes
+        metrics.nodes_num_elem += nodes_num_elem
+    return make_boxed_func(gm.forward)
+
+
+@DebugContext.wrap
 @torch.utils._python_dispatch._disable_current_modes()
 def compile_fx_inner(
     gm: torch.fx.GraphModule,
@@ -326,7 +342,11 @@
 _graph_counter = itertools.count(0)
 
 
-def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]):
+def compile_fx(
+    model_: torch.fx.GraphModule,
+    example_inputs_: List[torch.Tensor],
+    inner_compile=compile_fx_inner,
+):
     """Main entrypoint to a compile given FX graph"""
 
     if not is_aot_autograd_safe_to_run(model_, example_inputs_):
@@ -348,7 +368,7 @@
     @dynamo_utils.dynamo_timed
     def fw_compiler(model: torch.fx.GraphModule, example_inputs):
         fixed = len(example_inputs) - num_example_inputs
-        return compile_fx_inner(
+        return inner_compile(
             model,
             example_inputs,
             num_fixed=fixed,
@@ -359,7 +379,7 @@
     @dynamo_utils.dynamo_timed
     def bw_compiler(model: torch.fx.GraphModule, example_inputs):
         fixed = count_tangents(model)
-        return compile_fx_inner(
+        return inner_compile(
             model,
             example_inputs,
             num_fixed=fixed,
diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py
index 27c92f8..5434d7a 100644
--- a/torch/_inductor/dependencies.py
+++ b/torch/_inductor/dependencies.py
@@ -9,7 +9,14 @@
 
 from . import config
 from .codegen.common import index_prevent_reordering
-from .utils import sympy_product, sympy_str, sympy_subs, sympy_symbol, VarRanges
+from .utils import (
+    get_dtype_size,
+    sympy_product,
+    sympy_str,
+    sympy_subs,
+    sympy_symbol,
+    VarRanges,
+)
 from .virtualized import V
 
 log = logging.getLogger(__name__)
@@ -69,11 +76,18 @@
             return MemoryDep(renames[self.name], self.index, self.size)
         return self
 
-    def numel_hint(self):
+    def numbytes_hint(self):
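+        # Bytes touched by this access: product of the iteration sizes whose
+        # canonicalized index variables actually appear in the index expression
+        # (so broadcast dims cost nothing), times the element size of the dtype.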
         vars = set(self.index.free_symbols)
+        size_vars_used = []
+        for var in vars:
+            if var.name.startswith(canonicalization_prefix()):
+                # Sometimes with indirect indexing we have very weird symbol names
+                assert " " not in var.name
+                size_vars_used.append(int(var.name[len(canonicalization_prefix()) :]))
+
         return V.graph.sizevars.size_hint(
-            sympy_product([s for s in self.size if s in vars])
-        )
+            sympy_product([self.size[i] for i in size_vars_used])
+        ) * get_dtype_size(V.graph.get_dtype(self.name))
 
     def is_contiguous(self) -> bool:
         return isinstance(self.index, (sympy.Symbol, sympy.Integer))
@@ -88,8 +102,21 @@
             return StarDep(renames[self.name])
         return self
 
-    def numel_hint(self):
-        return 1
+    def numbytes_hint(self):
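+        # A StarDep covers the whole buffer, so report its full size in bytes;
+        # fall back to 1 when the buffer is unknown or has a MultiOutputLayout.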
+        from .ir import MultiOutputLayout
+
+        if self.name in V.graph.name_to_buffer:
+            buf = V.graph.name_to_buffer[self.name]
+        elif self.name in V.graph.graph_inputs:
+            buf = V.graph.graph_inputs[self.name]
+        else:
+            return 1
+        if hasattr(buf, "layout") and isinstance(buf.layout, MultiOutputLayout):
+            # NB: Too annoying to acquire, should only be used for instrumentation
+            return 1
+        return V.graph.sizevars.size_hint(
+            sympy_product(buf.get_size())
+        ) * get_dtype_size(buf.get_dtype())
 
     def is_contiguous(self) -> bool:
         return False
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 5114ffa..a47d9c1 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -27,7 +27,7 @@
     needs_realized_inputs,
 )
 from .sizevars import SizeVarAllocator
-from .utils import dynamo_utils, gather_origins
+from .utils import dynamo_utils, gather_origins, get_dtype_size, sympy_product
 from .virtualized import V
 
 log = logging.getLogger(__name__)
@@ -356,6 +356,47 @@
         self.scheduler.codegen()
         return self.wrapper_code.generate()
 
+    def count_bytes(self):
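+        # Schedule the lowered buffers (applying fusion) and return the total
+        # bytes read/written across all kernels, plus a per-node breakdown.
+        # No code is generated.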
+        from .scheduler import FusedSchedulerNode, NopKernelSchedulerNode, Scheduler
+
+        scheduler = Scheduler(self.buffers)
+
+        def get_read_write_buffers_sizes(node):
+            if isinstance(node, NopKernelSchedulerNode):
+                return 0
+            reads = set(dep.name for dep in node.read_writes.reads)
+            writes = set(dep.name for dep in node.read_writes.writes)
+
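+            # Within a fused group, a write only reaches memory if some user of
+            # the buffer lives outside the group; purely internal intermediates
+            # are not counted.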
+            def is_materialized(buf):
+                buf_uses = set(
+                    [user.node for user in scheduler.name_to_node[buf].users]
+                )
+                return len(buf_uses - set(node.snodes)) > 0
+
+            if isinstance(node, FusedSchedulerNode):
+                writes = set([dep for dep in writes if is_materialized(dep)])
+            node_bytes = 0
+            for buf in reads | writes:
+                if buf in self.name_to_buffer:
+                    buf = self.name_to_buffer[buf]
+                elif buf in self.graph_inputs:
+                    buf = self.graph_inputs[buf]
+                else:
+                    continue
+
+                node_bytes += V.graph.sizevars.size_hint(
+                    sympy_product(buf.get_size())
+                ) * get_dtype_size(buf.get_dtype())
+            return node_bytes
+
+        total_bytes = 0
+        node_counts = []
+        for node in scheduler.nodes:
+            num_bytes = get_read_write_buffers_sizes(node)
+            node_counts.append((node, num_bytes // 4))
+            total_bytes += num_bytes
+        return total_bytes, node_counts
+
     @dynamo_utils.dynamo_timed
     def compile_to_module(self):
         from .codecache import PyCodeCache
diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py
index 582c5ac..f7e0528 100644
--- a/torch/_inductor/metrics.py
+++ b/torch/_inductor/metrics.py
@@ -1,12 +1,17 @@
 # counter for tracking how many kernels have been generated
 generated_kernel_count = 0
 generated_cpp_vec_kernel_count = 0
+num_bytes_accessed = 0
+nodes_num_elem = []
 
 
 # reset all counters
 def reset():
     global generated_kernel_count
     global generated_cpp_vec_kernel_count
+    global num_bytes_accessed, nodes_num_elem
 
     generated_kernel_count = 0
     generated_cpp_vec_kernel_count = 0
+    num_bytes_accessed = 0
+    nodes_num_elem.clear()
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index cb71a44..8609617 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -977,7 +977,7 @@
         common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
             node2.read_writes.reads | node2.read_writes.writes
         )
-        return sum(dep.numel_hint() for dep in common_memory_deps)
+        return sum(dep.numbytes_hint() for dep in common_memory_deps)
 
     def score_fusion_key(self, nodes):
         """
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 08e95b9..62357be 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -375,3 +375,8 @@
 
 def argsort(seq):
     return sorted(range(len(seq)), key=seq.__getitem__)
+
+
+@functools.lru_cache(8)
+def get_dtype_size(dtype):
+    return torch.empty((), dtype=dtype).element_size()
diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py
index 5d40d05..27e60b1 100644
--- a/torch/_inductor/virtualized.py
+++ b/torch/_inductor/virtualized.py
@@ -74,7 +74,7 @@
 
     @staticmethod
     def indirect_indexing(index_var):
-        return sympy_symbol(str(index_var))
+        return sympy_symbol(f"({str(index_var)})")
 
     @classmethod
     def _init_cls(cls):