canonical example of torch.add benchmark (#23402)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23402
This diff tries to make torch.add as a canonical example for op benchmark. Once it lands, we will also modify all other op benchmarks to be uniform with this example. With that, when people are adding new ops, they can copy paste any existing code.
Test Plan:
buck run mode/dev-nosan caffe2/benchmarks/operator_benchmark/pt:add_test -- --iterations 3
```
# ----------------------------------------
# PyTorch/Caffe2 Operator Micro-benchmarks
# ----------------------------------------
# Tag : short
# Benchmarking PyTorch: add
# Mode: Eager
# Name: add_M8_N16_K32_devicecpu
# Input: M: 8, N: 16, K: 32, device: cpu
Forward Execution Time (us) : 146.586
# Benchmarking PyTorch: add
# Mode: Eager
# Name: add_M8_N16_K32_devicecuda
# Input: M: 8, N: 16, K: 32, device: cuda
Forward Execution Time (us) : 92.151
# Benchmarking PyTorch: add
# Mode: Eager
# Name: add_M16_N16_K64_devicecpu
# Input: M: 16, N: 16, K: 64, device: cpu
Forward Execution Time (us) : 428.421
# Benchmarking PyTorch: add
# Mode: Eager
# Name: add_M16_N16_K64_devicecuda
# Input: M: 16, N: 16, K: 64, device: cuda
Forward Execution Time (us) : 89.811
# Benchmarking PyTorch: add
# Mode: Eager
# Name: add_M64_N64_K128_devicecpu
# Input: M: 64, N: 64, K: 128, device: cpu
Forward Execution Time (us) : 11857.012
# Benchmarking PyTorch: add
# Mode: Eager
# Name: add_M64_N64_K128_devicecuda
# Input: M: 64, N: 64, K: 128, device: cuda
Forward Execution Time (us) : 93.918
# Benchmarking PyTorch: add
# Mode: Eager
# Name: add_M8_N16_K32_devicecpu_bwdall
# Input: M: 8, N: 16, K: 32, device: cpu
Backward Execution Time (us) : 990.125
# Benchmarking PyTorch: add
# Mode: Eager
# Name: add_M8_N16_K32_devicecpu_bwd1
# Input: M: 8, N: 16, K: 32, device: cpu
Backward Execution Time (us) : 781.217
# Benchmarking PyTorch: add
# Mode: Eager
# Name: add_M8_N16_K32_devicecpu_bwd2
# Input: M: 8, N: 16, K: 32, device: cpu
Backward Execution Time (us) : 777.307
```
Reviewed By: zheng-xq
Differential Revision: D16501974
fbshipit-source-id: f1eec010eabf11ce4fcf6cfe6f85cd5241a7022d
diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py
index 0371c74..fe558dc 100644
--- a/benchmarks/operator_benchmark/benchmark_pytorch.py
+++ b/benchmarks/operator_benchmark/benchmark_pytorch.py
@@ -94,11 +94,17 @@
""" this is a globally unique name which can be used to
label a specific test
"""
+
+ # This is a list of attributes which will not be included
+ # in the test name.
+ skip_key_list = ['device']
+
test_name_str = []
for key in kargs:
value = kargs[key]
test_name_str.append(
- key + str(value if type(value) != bool else int(value)))
+ ('' if key in skip_key_list else key)
+ + str(value if type(value) != bool else int(value)))
name = (self.module_name() + '_' +
'_'.join(test_name_str)).replace(" ", "")
return name
diff --git a/benchmarks/operator_benchmark/pt/add_test.py b/benchmarks/operator_benchmark/pt/add_test.py
index 3be7fb2..1e4c522 100644
--- a/benchmarks/operator_benchmark/pt/add_test.py
+++ b/benchmarks/operator_benchmark/pt/add_test.py
@@ -11,33 +11,49 @@
# Configs for PT add operator
add_long_configs = op_bench.cross_product_configs(
M=[8, 64, 128],
- N=range(2, 10, 3),
- K=[2 ** x for x in range(0, 3)],
+ N=range(2, 128, 64),
+ K=[8 ** x for x in range(0, 3)],
+ device=['cpu', 'cuda'],
tags=["long"]
)
add_short_configs = op_bench.config_list(
+ attr_names=["M", "N", "K"],
attrs=[
[64, 64, 64],
[64, 64, 128],
],
- attr_names=["M", "N", "K"],
- tags=["short"],
+ cross_product_configs={
+ 'device': ['cpu', 'cuda'],
+ },
+ tags=["short"],
)
class AddBenchmark(op_bench.TorchBenchmarkBase):
- def init(self, M, N, K):
- self.input_one = torch.rand(M, N, K)
- self.input_two = torch.rand(M, N, K)
+ def init(self, M, N, K, device):
+ self.input_one = torch.rand(M, N, K, device=device, requires_grad=self.auto_set())
+ self.input_two = torch.rand(M, N, K, device=device, requires_grad=self.auto_set())
self.set_module_name("add")
def forward(self):
return torch.add(self.input_one, self.input_two)
+# The generated test names based on add_short_configs will be in the following pattern:
+# add_M8_N16_K32_devicecpu
+# add_M8_N16_K32_devicecuda
+# add_M8_N16_K32_devicecpu_bwdall
+# add_M8_N16_K32_devicecpu_bwd1
+# add_M8_N16_K32_devicecpu_bwd2
+# add_M8_N16_K32_devicecuda_bwdall
+# add_M8_N16_K32_devicecuda_bwd1
+# add_M8_N16_K32_devicecuda_bwd2
+# ...
+# Those names can be used to filter tests.
op_bench.generate_pt_test(add_long_configs + add_short_configs, AddBenchmark)
+op_bench.generate_pt_gradient_test(add_long_configs + add_short_configs, AddBenchmark)
if __name__ == "__main__":