[vulkan][test] benchmark sub op (#82221)
Differential Revision: [D38153928](https://our.internmc.facebook.com/intern/diff/D38153928/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82221
Approved by: https://github.com/SS-JIA
diff --git a/aten/src/ATen/test/vulkan_perf_test.cpp b/aten/src/ATen/test/vulkan_perf_test.cpp
index ca8604a..db9030c 100644
--- a/aten/src/ATen/test/vulkan_perf_test.cpp
+++ b/aten/src/ATen/test/vulkan_perf_test.cpp
@@ -709,6 +709,100 @@
#endif
}
+// Benchmarks float subtraction (at::sub) on the Vulkan backend.
+//
+// state.range(0..3) supply the batches, channels, height and width of two
+// random float CPU tensors, which are copied to Vulkan once before the timed
+// loop. Each iteration measures at::sub() on the Vulkan tensors together with
+// the .cpu() call on its result, and passes the elapsed
+// high_resolution_clock time, in seconds, to state.SetIterationTime().
+static void sub_op_benchmark(benchmark::State& state) {
+  // Guard: mark the run as skipped instead of returning silently before the
+  // benchmark loop has run, so a missing Vulkan backend shows up in the report.
+  if (!at::is_vulkan_available()) {
+    state.SkipWithError("Vulkan is not available");
+    return;
+  }
+
+  // Arrange
+  const auto batches = state.range(0);
+  const auto channels = state.range(1);
+  const auto height = state.range(2);
+  const auto width = state.range(3);
+  const auto in_cpu1 = at::rand(
+      {batches, channels, height, width},
+      at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_cpu2 = at::rand(
+      {batches, channels, height, width},
+      at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_vulkan1 = in_cpu1.vulkan();
+  const auto in_vulkan2 = in_cpu2.vulkan();
+
+#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__)
+  // Reset the query pool so the results extracted after the loop start from
+  // this point rather than including the setup above.
+  at::native::vulkan::api::context()->reset_querypool();
+#endif
+
+  // Act
+  for (auto _ : state) {
+    auto start = std::chrono::high_resolution_clock::now();
+    // The result is unused; the call is made only for its cost.
+    const auto vulkan_out = at::sub(in_vulkan1, in_vulkan2).cpu();
+    auto end = std::chrono::high_resolution_clock::now();
+    auto elapsed =
+        std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
+    state.SetIterationTime(elapsed.count());
+  }
+
+#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__)
+  at::native::vulkan::api::context()->querypool().extract_results();
+  at::native::vulkan::api::context()->querypool().print_results();
+  // NOTE(review): get_total_op_ns() / 1000000.0 looks like milliseconds, while
+  // Google Benchmark documents SetIterationTime() as taking seconds. This call
+  // also comes after the per-iteration calls above -- confirm the intended
+  // unit and how the two are meant to combine.
+  state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("sub") / 1000000.0);
+#endif
+}
+
+// Benchmarks quantized subtraction (quantized_sub) on the Vulkan backend.
+//
+// state.range(0..3) supply the batches, channels, height and width of two
+// random float CPU tensors. Both are copied to Vulkan and quantized to QUInt8
+// (scale 0.1, zero point 10) before the timed loop. Each iteration measures
+// quantized_sub() with output scale 0.15 / zero point 15 together with the
+// vulkan_to_cpu() call on its result, and passes the elapsed
+// high_resolution_clock time, in seconds, to state.SetIterationTime().
+static void sub_op_q_benchmark(benchmark::State& state) {
+  // Guard: mark the run as skipped instead of returning silently before the
+  // benchmark loop has run, so a missing Vulkan backend shows up in the report.
+  if (!at::is_vulkan_available()) {
+    state.SkipWithError("Vulkan is not available");
+    return;
+  }
+
+  // Arrange
+  const auto batches = state.range(0);
+  const auto channels = state.range(1);
+  const auto height = state.range(2);
+  const auto width = state.range(3);
+  const auto in_cpu1 = at::rand(
+      {batches, channels, height, width},
+      at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_cpu2 = at::rand(
+      {batches, channels, height, width},
+      at::device(at::kCPU).dtype(at::kFloat));
+  const auto in_vulkan1 = in_cpu1.vulkan();
+  const auto in_vulkan2 = in_cpu2.vulkan();
+  const double scale = 0.1;
+  const int zero_point = 10;
+  // CPU-side quantized tensor; inside this function it is only handed to
+  // vulkan_to_cpu() in the timed loop.
+  const auto out_cpu1 = at::quantize_per_tensor(
+      in_cpu1, scale, zero_point, c10::ScalarType::QUInt8);
+  const auto out_vulkan1 = at::native::vulkan::ops::quantize_per_tensor(
+      in_vulkan1, scale, zero_point, c10::ScalarType::QUInt8);
+  const auto out_vulkan2 = at::native::vulkan::ops::quantize_per_tensor(
+      in_vulkan2, scale, zero_point, c10::ScalarType::QUInt8);
+
+#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__)
+  // Reset the query pool so the results extracted after the loop start from
+  // this point rather than including the setup above.
+  at::native::vulkan::api::context()->reset_querypool();
+#endif
+
+  // Act
+  const double scale2 = 0.15;
+  const int zero_point2 = 15;
+  for (auto _ : state) {
+    auto start = std::chrono::high_resolution_clock::now();
+    const auto vulkan_sub = at::native::vulkan::ops::quantized_sub(
+        out_vulkan1, out_vulkan2, scale2, zero_point2);
+    // vulkan_to_cpu() is defined elsewhere in this file; presumably it brings
+    // the quantized Vulkan result back to the CPU using out_cpu1 as a
+    // reference -- TODO confirm. The result is unused; only the cost matters.
+    const auto vulkan_out = vulkan_to_cpu(vulkan_sub, out_cpu1);
+    auto end = std::chrono::high_resolution_clock::now();
+    auto elapsed =
+        std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
+    state.SetIterationTime(elapsed.count());
+  }
+
+#if defined(USE_VULKAN_GPU_DIAGNOSTICS) && defined(__ANDROID__)
+  at::native::vulkan::api::context()->querypool().extract_results();
+  at::native::vulkan::api::context()->querypool().print_results();
+  // NOTE(review): get_total_op_ns() / 1000000.0 looks like milliseconds, while
+  // Google Benchmark documents SetIterationTime() as taking seconds. This call
+  // also comes after the per-iteration calls above -- confirm the intended
+  // unit and how the two are meant to combine.
+  state.SetIterationTime(at::native::vulkan::api::context()->querypool().get_total_op_ns("quantized_sub") / 1000000.0);
+#endif
+}
+
static void CommonBenchmarkSettings(benchmark::internal::Benchmark* b) {
b->Unit(benchmark::kMillisecond);
b->ArgNames({"N", "C", "H", "W"});
@@ -764,6 +858,18 @@
->Threads(1)
->Iterations(10)
->Args({1, 7, 137, 199});
+BENCHMARK(sub_op_benchmark) // registers the float Vulkan at::sub benchmark
+ ->Apply(CommonBenchmarkSettings) // shared setup: millisecond unit, N/C/H/W argument names
+ ->UseManualTime() // the benchmark body supplies timings via state.SetIterationTime()
+ ->Threads(1)
+ ->Iterations(100) // explicit iteration count
+ ->Args({3, 40, 221, 193}); // read via state.range(0..3) as batches, channels, height, width
+BENCHMARK(sub_op_q_benchmark) // registers the quantized Vulkan quantized_sub benchmark
+ ->Apply(CommonBenchmarkSettings) // shared setup: millisecond unit, N/C/H/W argument names
+ ->UseManualTime() // the benchmark body supplies timings via state.SetIterationTime()
+ ->Threads(1)
+ ->Iterations(100) // explicit iteration count
+ ->Args({3, 40, 221, 193}); // read via state.range(0..3) as batches, channels, height, width
BENCHMARK_MAIN();