[Profiler] Add speedup estimate for FP32 pattern and Extra CUDA Copy Pattern (#81501)
Summary: The main idea is that we can run some baseline benchmarks after we are done matching the events. This gives us ability to accurate measure speed gain because system performance varies from machine to machine.
Test Plan: I did some manually testing on all the models in torchbench, as well as added a simple test in test_profiler.py
Differential Revision: [D37894566](https://our.internmc.facebook.com/intern/diff/D37894566)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81501
Approved by: https://github.com/robieta
diff --git a/test/test_profiler.py b/test/test_profiler.py
index 3b8d4bd..a97a92c 100644
--- a/test/test_profiler.py
+++ b/test/test_profiler.py
@@ -1562,7 +1562,7 @@
)
num_matched = []
for _, fn in cases:
- with profile(with_stack=True) as prof:
+ with profile(with_stack=True, record_shapes=True) as prof:
fn()
pattern = ExtraCUDACopyPattern(prof)
num_matched.append(len(pattern.matched_events()))
@@ -1617,5 +1617,15 @@
self.assertEqual(num_matched, has_tf32)
+ @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
+ def test_profiler_extra_cuda_copy_pattern_benchmark(self):
+ with profile(with_stack=True, record_shapes=True) as prof:
+ x = torch.ones((100, 100)).to("cuda")
+ x = torch.ones((50, 50)).to("cuda")
+ pattern = ExtraCUDACopyPattern(prof)
+ shapes_factor_map = pattern.benchmark(pattern.matched_events())
+ self.assertEqual(len(shapes_factor_map), 2)
+
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index 7b122e3..795cf51 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -113,9 +113,13 @@
PyCCall = ...
TorchOp = ...
+class _Inputs:
+ shapes: List[List[int]]
+ dtypes: List[str]
+
class _ExtraFields_TorchOp:
- inputs: List[List[int]]
allow_tf32_cublas: bool
+ inputs: _Inputs
...
class _ExtraFields_Backend:
diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py
index 078f1f5..abad734 100644
--- a/torch/profiler/_pattern_matcher.py
+++ b/torch/profiler/_pattern_matcher.py
@@ -4,6 +4,7 @@
import torch
from torch.profiler import profile
+import torch.utils.benchmark as benchmark
from torch.profiler._utils import index_of_first_match
from torch._C._autograd import (_ProfilerEvent, _ExtraFields_TorchOp,
_ExtraFields_Backend, _ExtraFields_Allocation,
@@ -19,8 +20,10 @@
In subclass, define description and skip property.
'''
- def __init__(self, prof: profile):
+ def __init__(self, prof: profile, should_benchmark: bool = False):
self.prof = prof
+ self.should_benchmark = should_benchmark
+ self.name = "Please specify a name for pattern"
self.description = "Please specify a description for pattern"
assert prof.profiler is not None and prof.profiler.kineto_results is not None
self.event_tree = prof.profiler.kineto_results.experimental_event_tree(
@@ -49,6 +52,17 @@
for child_event in curr_event.children:
stack.append(child_event)
+ def summary(self, events: List[_ProfilerEvent]):
+ default_summary = f"{self.name}: {len(events)} events matched."
+ if self.should_benchmark:
+ summary = self.benchmark_summary(events)
+ # If benchmark summary is not empty, use it.
+ return summary if summary else default_summary
+ return default_summary
+
+ def benchmark_summary(self, events: List[_ProfilerEvent]):
+ return ""
+
def match(self, event: _ProfilerEvent):
'''
Return True if the event matches the pattern.
@@ -92,8 +106,8 @@
class NamePattern(Pattern):
- def __init__(self, prof: profile, name: str):
- super().__init__(prof)
+ def __init__(self, prof: profile, name: str, should_benchmark: bool = False):
+ super().__init__(prof, should_benchmark)
self.description = f"Matched Name Event: {name}"
self.name = name
@@ -118,8 +132,9 @@
If at any step we failed, it is not a match.
'''
- def __init__(self, prof: profile):
- super().__init__(prof)
+ def __init__(self, prof: profile, should_benchmark: bool = False):
+ super().__init__(prof, should_benchmark)
+ self.name = "Extra CUDA Copy Pattern"
self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initalize it on GPU."
self.init_ops = {
"aten::fill_", "aten::zero_", "aten::normal_", "aten::uniform_"
@@ -127,7 +142,7 @@
@property
def skip(self):
- return self.prof.with_stack is False
+ return not self.prof.with_stack or not self.prof.record_shapes
def match(self, event):
# TODO: We should also check tensor identities
@@ -149,6 +164,33 @@
return event.name() in self.init_ops
# TODO: Check if tensor is reused
+ def benchmark(self, events: List[_ProfilerEvent]):
+ shapes_factor_map = {input_shapes(event)[0]: 0.0 for event in events}
+ for shape in shapes_factor_map:
+ to_timer = benchmark.Timer(stmt='torch.ones(shape).to("cuda")',
+ globals={'shape': shape})
+ de_timer = benchmark.Timer(stmt='torch.ones(shape, device="cuda")',
+ globals={'shape': shape})
+ to_time = to_timer.timeit(10).mean
+ de_time = de_timer.timeit(10).mean
+ shapes_factor_map[shape] = de_time / to_time
+ return shapes_factor_map
+
+ def report(self, event: _ProfilerEvent):
+ msg = f"{self.description}\n{source_code_location(event)}"
+ return msg
+
+ def benchmark_summary(self, events: List[_ProfilerEvent]):
+ shapes_factor_map = self.benchmark(events)
+ original_time = sum(event.duration_time_ns for event in events) / 1e3
+ new_time = sum(
+ shapes_factor_map[input_shapes(event)[0]] * event.duration_time_ns
+ for event in events) / 1e3
+ return (
+ f"{self.name}: {len(events)} events matched. "
+ f"Total Estimated Speedup: {original_time - new_time}us ({original_time/new_time}X)"
+ )
+
class ForLoopIndexingPattern(Pattern):
'''
@@ -167,8 +209,9 @@
We also keep a dictionary to avoid duplicate match in the for loop.
'''
- def __init__(self, prof: profile):
- super().__init__(prof)
+ def __init__(self, prof: profile, should_benchmark: bool = False):
+ super().__init__(prof, should_benchmark)
+ self.name = "For Loop Indexing Pattern"
self.description = "For loop indexing detected. Vectorization recommended."
self.visited: Set[int] = set()
@@ -220,8 +263,9 @@
class FP32MatMulPattern(Pattern):
- def __init__(self, prof: profile):
- super().__init__(prof)
+ def __init__(self, prof: profile, should_benchmark: bool = False):
+ super().__init__(prof, should_benchmark)
+ self.name = "FP32 MatMul Pattern"
self.description = (
"You are currently using GPU that supports TF32. "
"Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
@@ -229,9 +273,10 @@
@property
def skip(self):
+ # Anything less than sm_80 is not Ampere which doesn't support TF32
has_tf32 = all(
int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list())
- return has_tf32 is False
+ return has_tf32 is False or super().skip or not self.prof.record_shapes
def match(self, event: _ProfilerEvent):
# If we saw this pattern once, we don't need to match it again
@@ -246,6 +291,40 @@
def report(self, event: _ProfilerEvent):
return self.description
+ def benchmark(self, events: List[_ProfilerEvent]):
+ shapes_factor_map = {input_shapes(event): 0.0 for event in events}
+ for shape in shapes_factor_map:
+ matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32)
+ matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32)
+ fp32_timer = benchmark.Timer(stmt='torch.mm(matrixA, matrixB)',
+ globals={
+ "matrixA": matrixA,
+ "matrixB": matrixB
+ })
+ tf32_timer = benchmark.Timer(
+ stmt='torch.mm(matrixA, matrixB)',
+ setup='torch.backends.cuda.matmul.allow_tf32 = True',
+ globals={
+ "matrixA": matrixA,
+ "matrixB": matrixB
+ })
+ torch.backends.cuda.matmul.allow_tf32 = False
+ fp32_time = fp32_timer.timeit(10).mean
+ tf32_time = tf32_timer.timeit(10).mean
+ shapes_factor_map[shape] = tf32_time / fp32_time
+ return shapes_factor_map
+
+ def benchmark_summary(self, events: List[_ProfilerEvent]):
+ shapes_factor_map = self.benchmark(events)
+ original_time = sum(event.duration_time_ns for event in events) / 1e3
+ new_time = sum(
+ shapes_factor_map[input_shapes(event)] * event.duration_time_ns
+ for event in events) / 1e3
+ return (
+ f"{self.name}: {len(events)} events matched. "
+ f"Total Estimated Speedup: {original_time - new_time}us ({original_time/new_time}X)"
+ )
+
def source_code_location(event: _ProfilerEvent):
while event:
@@ -259,20 +338,35 @@
return "No source code location found"
-def report_all_anti_patterns(prof):
+def input_shapes(event: _ProfilerEvent):
+ assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
+ return tuple([tuple(shape) for shape in event.extra_fields.inputs.shapes])
+
+
+def report_all_anti_patterns(prof, should_benchmark: bool = False):
anti_patterns = [
- ExtraCUDACopyPattern(prof),
- ForLoopIndexingPattern(prof),
- FP32MatMulPattern(prof)
+ ExtraCUDACopyPattern(prof, should_benchmark),
+ ForLoopIndexingPattern(prof, should_benchmark),
+ FP32MatMulPattern(prof, should_benchmark)
]
reported = set()
- print(f"{'-'*40}TorchTidy Report{'-'*40}")
+ summaries = []
+ message_list = [f"{'-'*40}TorchTidy Report{'-'*40}"]
+ message_list.append("Matched Events:")
for anti_pattern in anti_patterns:
- for event in anti_pattern.matched_events():
+ matched_events = anti_pattern.matched_events()
+ if not matched_events:
+ continue
+ summaries.append(anti_pattern.summary(matched_events))
+ for event in matched_events:
report_msg = anti_pattern.report(event)
if report_msg not in reported:
- print(report_msg)
+ message_list.append(report_msg)
reported.add(report_msg)
+ message_list.append("Summary:")
+ message_list += summaries
+ message_list.append(f"{'-'*40}TorchTidy Report{'-'*40}")
+ print("\n".join(message_list))
def event_type(event: _ProfilerEvent):