[pytorch-vulkan] Support zero-dim (#111680)
Summary:
1. Add zero-dim (Tensor with 1 element) support.
2. New operator `_local_scalar_dense` that map a zero-dim tensor into a Scalar
3. `sum_dim`:
3.1. Add zero-dim support.
3.2. Fix bug in negative indices when handling multi-dim reduction call
3.3. Add unittests to test new coverages
4. Add `aten::sum` support.
5. Change bug in `add_tensor` (and other binary ops), when `other` is zero dim, we will use broadcast instead.
Test Plan:
## Devserver
Full Paste: P858982150
```
[yipjustin@31799.od ~/fbsource (8593e7559)]$ LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck2 run fbcode/mode/dev-nosan -c pt.has_backtraces=1 //xplat/caffe2:pt_vulkan_api_test_bin --
File changed: fbsource//xplat/caffe2/aten/src/ATen/test/vulkan_api_test.cpp
Buck UI: https://www.internalfb.com/buck2/90cad0ff-ac98-4dbf-8d6f-0e419c06208d
Network: Up: 43KiB Down: 1.4MiB (reSessionID-dfc3a318-fd1a-4ad6-b077-c454ebb4c6a8)
Jobs completed: 6. Time elapsed: 26.4s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 1, local: 1)
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
[==========] Running 385 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 385 tests from VulkanAPITest
[ RUN ] VulkanAPITest.zero_size_tensor
[ OK ] VulkanAPITest.zero_size_tensor (9 ms)
[ RUN ] VulkanAPITest.zero_dim_tensor_1
[ OK ] VulkanAPITest.zero_dim_tensor_1 (84 ms)
[ RUN ] VulkanAPITest.zero_dim_tensor_2
[ OK ] VulkanAPITest.zero_dim_tensor_2 (22 ms)
[ RUN ] VulkanAPITest.local_scalar_dense
[ OK ] VulkanAPITest.local_scalar_dense (10 ms)
...
[ OK ] VulkanAPITest.lstm_prepack_success (2 ms)
[ RUN ] VulkanAPITest.querypool_flushed_shader_log
xplat/caffe2/aten/src/ATen/test/vulkan_api_test.cpp:7484: Skipped
QueryPool is not available
[ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log (0 ms)
[----------] 385 tests from VulkanAPITest (46915 ms total)
[----------] Global test environment tear-down
[==========] 385 tests from 1 test suite ran. (46915 ms total)
[ PASSED ] 382 tests.
[ SKIPPED ] 1 test, listed below:
[ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log
[ FAILED ] 2 tests, listed below:
[ FAILED ] VulkanAPITest.conv2d_pw_prepack
[ FAILED ] VulkanAPITest.conv2d_pw_prepack_bc
2 FAILED TESTS
YOU HAVE 7 DISABLED TESTS
```
## M1 MAC
P859975219
```
buck run //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 --target-platforms ovr_config//platform/macos:arm64-fbsource -- --gtest_filter="*"
Using additional configuration options from .buckconfig.local
Building: finished in 0.2 sec (100%) 269/2875 jobs, 0/2875 updated
Total time: 0.2 sec
BUILD SUCCEEDED
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
[==========] Running 384 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 384 tests from VulkanAPITest
[ RUN ] VulkanAPITest.zero_size_tensor
[ OK ] VulkanAPITest.zero_size_tensor (40 ms)
[ RUN ] VulkanAPITest.zero_dim_tensor_1
[ OK ] VulkanAPITest.zero_dim_tensor_1 (7 ms)
[ RUN ] VulkanAPITest.zero_dim_tensor_2
[ OK ] VulkanAPITest.zero_dim_tensor_2 (1 ms)
[ RUN ] VulkanAPITest.local_scalar_dense
[ OK ] VulkanAPITest.local_scalar_dense (0 ms)
[ RUN ] VulkanAPITest.copy_to_texture
[ OK ] VulkanAPITest.copy_to_texture (45 ms)
...
[ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log (0 ms)
[----------] 384 tests from VulkanAPITest (5127 ms total)
[----------] Global test environment tear-down
[==========] 384 tests from 1 test suite ran. (5127 ms total)
[ PASSED ] 382 tests.
[ SKIPPED ] 1 test, listed below:
[ SKIPPED ] VulkanAPITest.querypool_flushed_shader_log
[ FAILED ] 1 test, listed below:
[ FAILED ] VulkanAPITest.normal_large
1 FAILED TEST
YOU HAVE 5 DISABLED TESTS
```
Differential Revision: D50347338
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111680
Approved by: https://github.com/SS-JIA
diff --git a/aten/src/ATen/native/vulkan/api/Tensor.cpp b/aten/src/ATen/native/vulkan/api/Tensor.cpp
index b601b1c..35ea94c 100644
--- a/aten/src/ATen/native/vulkan/api/Tensor.cpp
+++ b/aten/src/ATen/native/vulkan/api/Tensor.cpp
@@ -168,14 +168,30 @@
// packed dimension.
else {
TORCH_CHECK(
- ndim >= 1 && ndim <= 4,
- "Texture storage only valid for 1 <= ndim <= 4, received: ",
+ ndim >= 0 && ndim <= 4,
+ "Texture storage only valid for 0 <= ndim <= 4, received: ",
ndim);
c10::SmallVector<int64_t, 6u> gpu_sizes(ndim == 4 ? 4 : 3);
// Channel dim will be be aligned to the next multiple of 4
switch (ndim) {
+ case 0:
+ switch (memory_layout) {
+ case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
+ // 0-dimension tensors only has 1 element. Hence it is always {4, 1,
+ // 1} when stored as image textures. Channels need to be multiple of
+ // 4 due to packing.
+ gpu_sizes[0] = 4;
+ gpu_sizes[1] = 1;
+ gpu_sizes[2] = 1;
+ break;
+ default:
+ TORCH_CHECK(
+ false,
+ "Invalid memory format used to create vTensor with zero-dim!");
+ }
+ break;
case 1:
switch (memory_layout) {
case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
diff --git a/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp b/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp
index 03931e4..754fa49 100644
--- a/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp
+++ b/aten/src/ATen/native/vulkan/ops/BinaryOp.cpp
@@ -402,13 +402,6 @@
const Tensor& self_arg,
const Tensor& other_arg,
const Scalar& alpha) {
- if (other_arg.dim() == 0) {
- return binary_op_scalar(
- self_arg,
- other_arg.item(),
- c10::optional<Scalar>(),
- VK_KERNEL(add_scalar));
- }
return binary_op_tensor(
self_arg, other_arg, c10::optional<Scalar>(alpha), VK_KERNEL(add));
}
@@ -444,13 +437,6 @@
const Tensor& self_arg,
const Tensor& other_arg,
const Scalar& alpha) {
- if (other_arg.dim() == 0) {
- return binary_op_scalar(
- self_arg,
- other_arg.item(),
- c10::optional<Scalar>(-1 * alpha.to<float>()),
- VK_KERNEL(add_scalar));
- }
return binary_op_tensor(
self_arg, other_arg, c10::optional<Scalar>(alpha), VK_KERNEL(sub));
}
@@ -474,13 +460,6 @@
}
Tensor mul_tensor(const Tensor& self_arg, const Tensor& other_arg) {
- if (other_arg.dim() == 0) {
- return binary_op_scalar(
- self_arg,
- other_arg.item(),
- c10::optional<Scalar>(),
- VK_KERNEL(mul_scalar));
- }
return binary_op_tensor(
self_arg, other_arg, c10::optional<Scalar>(), VK_KERNEL(mul));
}
@@ -507,13 +486,6 @@
}
Tensor div_tensor(const Tensor& self_arg, const Tensor& other_arg) {
- if (other_arg.dim() == 0) {
- return binary_op_scalar(
- self_arg,
- 1.0 / other_arg.item().to<float>(),
- c10::optional<Scalar>(),
- VK_KERNEL(mul_scalar));
- }
return binary_op_tensor(
self_arg, other_arg, c10::optional<Scalar>(), VK_KERNEL(div));
}
diff --git a/aten/src/ATen/native/vulkan/ops/Scalar.cpp b/aten/src/ATen/native/vulkan/ops/Scalar.cpp
new file mode 100644
index 0000000..52350f3
--- /dev/null
+++ b/aten/src/ATen/native/vulkan/ops/Scalar.cpp
@@ -0,0 +1,33 @@
+#include <ATen/native/vulkan/ops/Common.h>
+
+#include <torch/library.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+namespace ops {
+namespace {
+
+using namespace api::utils;
+
+Scalar _local_scalar_dense(const Tensor& self) {
+ TORCH_CHECK(
+ self.dtype() == ScalarType::Float, "Only float dtype is supported");
+ return Scalar(self.cpu().item<float>());
+}
+
+#ifdef USE_VULKAN_API
+
+TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
+ m.impl(
+ TORCH_SELECTIVE_NAME("aten::_local_scalar_dense"),
+ TORCH_FN(_local_scalar_dense));
+}
+
+#endif /* USE_VULKAN_API */
+
+} // namespace
+} // namespace ops
+} // namespace vulkan
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/vulkan/ops/Sum.cpp b/aten/src/ATen/native/vulkan/ops/Sum.cpp
index 0033731..14d9f22 100644
--- a/aten/src/ATen/native/vulkan/ops/Sum.cpp
+++ b/aten/src/ATen/native/vulkan/ops/Sum.cpp
@@ -16,8 +16,8 @@
bool keepdim,
const optional<ScalarType> dtype) {
TORCH_CHECK(
- self.dim() >= 2 && self.dim() <= 4,
- "Vulkan sum.dim_IntList supports 2d, 3d, 4d tensors as input!");
+ self.dim() >= 1 && self.dim() <= 4,
+ "Vulkan sum.dim_IntList supports 1d, 2d, 3d, 4d tensors as input!");
// Get the global Vulkan context
api::Context* const context = api::context();
@@ -26,9 +26,6 @@
const Tensor input = self.is_vulkan() ? self : self.vulkan();
const vTensor& v_input = convert(input);
- // Normalize dim into range [0, self.dim()]
- dim = utils::normalize(dim, self.dim());
-
// Create the output texture
std::vector<int64_t> output_size = self.sizes().vec();
uint32_t dim_size = output_size[dim];
@@ -102,16 +99,19 @@
std::set<int64_t> dims_set;
if (opt_dim.has_value()) {
auto dims = opt_dim.value();
- for (const auto& d : dims) {
+ for (const auto& dim : dims) {
+ // Do dim check before normalization to report to specified wrong dim
+ // value to user
TORCH_CHECK(
- d >= -self.dim() && d < self.dim(),
+ dim >= -self.dim() && dim <= self.dim() - 1,
"Vulkan sum.dim_IntList dimension out of range expected to be in range of [",
-self.dim(),
",",
self.dim() - 1,
"], but got ",
- d);
- int64_t dim_normalized = utils::normalize(d, self.dim());
+ dim);
+ // Normalize dim into range [0, self.dim() - 1]
+ int64_t dim_normalized = utils::normalize(dim, self.dim());
if (dims_set.find(dim_normalized) != dims_set.end()) {
TORCH_CHECK(
false,
@@ -122,6 +122,8 @@
dims_set.insert(dim_normalized);
}
Tensor result = self;
+ // Reduce the higher dimensionalities first, otherwise when keepdim is
+ // false, it will be reducing the wrong dimension.
for (auto it = dims_set.rbegin(); it != dims_set.rend(); ++it) {
result = sum_dim(result, *it, keepdim, dtype);
}
@@ -130,11 +132,21 @@
return self;
}
+Tensor sum(const Tensor& self, const c10::optional<ScalarType> dtype) {
+ std::vector<int64_t> dims;
+ for (int64_t d = 0; d < self.dim(); d++) {
+ dims.push_back(d);
+ }
+
+ return sum_dim_IntList(self, dims, false, dtype);
+}
+
#ifdef USE_VULKAN_API
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
m.impl(
TORCH_SELECTIVE_NAME("aten::sum.dim_IntList"), TORCH_FN(sum_dim_IntList));
+ m.impl(TORCH_SELECTIVE_NAME("aten::sum"), TORCH_FN(sum));
}
#endif /* USE_VULKAN_API */
diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp
index 6d61107..c15fb63 100644
--- a/aten/src/ATen/test/vulkan_api_test.cpp
+++ b/aten/src/ATen/test/vulkan_api_test.cpp
@@ -282,6 +282,34 @@
ASSERT_TRUE(at::equal(out_vk, cpu));
}
+TEST_F(VulkanAPITest, zero_dim_tensor_1) {
+ auto cpu = at::rand({}, at::device(at::kCPU).dtype(at::kFloat));
+ auto vv = cpu.item<float>();
+
+ auto vk = cpu.vulkan();
+ auto out_vk = vk.cpu();
+ ASSERT_TRUE(almostEqual(cpu, out_vk));
+
+ auto vk_vv = out_vk.item<float>();
+ EXPECT_NEAR(vv, vk_vv, kTolerance);
+}
+
+TEST_F(VulkanAPITest, zero_dim_tensor_2) {
+ float v = 3.14f;
+ auto cpu = at::empty({}, at::device(at::kCPU).dtype(at::kFloat)) + v;
+ auto vk = at::empty({}, at::device(at::kVulkan).dtype(at::kFloat)) + v;
+
+ ASSERT_TRUE(almostEqual(cpu, vk.cpu()));
+}
+
+TEST_F(VulkanAPITest, local_scalar_dense) {
+ float v = 8.31f;
+ // Force the zero-dim tensor to a non-zero constant v.
+ auto vk = at::zeros({}, at::device(at::kVulkan).dtype(at::kFloat)) + v;
+ c10::Scalar scalar = at::_local_scalar_dense(vk);
+ EXPECT_NEAR(v, scalar.toFloat(), kTolerance);
+}
+
TEST_F(VulkanAPITest, copy_to_texture) {
using namespace at::native::vulkan;
at::Tensor test_tensors[] = {
@@ -436,6 +464,10 @@
test_add({1, 15, 5, 4}, {21, 1, 5, 4}, 1.8f);
}
+TEST_F(VulkanAPITest, add_zero_dim) {
+ test_add({2, 6, 5, 6}, {}, 1.5f);
+}
+
TEST_F(VulkanAPITest, add_) {
auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat));
auto a_vulkan = a_cpu.vulkan();
@@ -1874,6 +1906,10 @@
test_div({1, 15, 5, 4}, {21, 1, 5, 4});
}
+TEST_F(VulkanAPITest, div_zero_dim) {
+ test_div({1, 15, 5, 4}, {});
+}
+
TEST_F(VulkanAPITest, div_) {
auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat));
auto a_vulkan = a_cpu.vulkan();
@@ -3281,6 +3317,10 @@
test_mul({1, 15, 5, 4}, {21, 1, 5, 4});
}
+TEST_F(VulkanAPITest, mul_zero_dim) {
+ test_mul({1, 15, 5, 4}, {});
+}
+
TEST_F(VulkanAPITest, mul_) {
auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat));
auto a_vulkan = a_cpu.vulkan();
@@ -3482,6 +3522,10 @@
test_pow({1, 1, 5, 5}, {8, 8, 1, 1}); // mul4ch
}
+TEST_F(VulkanAPITest, pow_zero_dim) {
+ test_mul({1, 15, 5, 4}, {});
+}
+
void test_pow_(const at::IntArrayRef input_shape, const at::IntArrayRef other_shape) {
const auto cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat));
const auto other_cpu = at::rand(other_shape, at::device(at::kCPU).dtype(at::kFloat));
@@ -4151,6 +4195,10 @@
test_sub({1, 15, 5, 4}, {21, 1, 5, 4}, 1.8f);
}
+TEST_F(VulkanAPITest, sub_zero_dim) {
+ test_sub({1, 15, 5, 4}, {}, 1.8f);
+}
+
TEST_F(VulkanAPITest, sub_) {
auto a_cpu = at::rand({61, 17, 29, 83}, at::device(at::kCPU).dtype(at::kFloat));
auto a_vulkan = a_cpu.vulkan();
@@ -4361,9 +4409,15 @@
ASSERT_TRUE(check);
}
+TEST_F(VulkanAPITest, sum_dim_1d) {
+ test_sum_dim({7}, {-1});
+ test_sum_dim({3}, {0});
+}
+
TEST_F(VulkanAPITest, sum_dim_2d) {
test_sum_dim({2, 3}, {-1});
test_sum_dim({2, 7}, {-2});
+ test_sum_dim({2, 7}, {-1, -2});
}
TEST_F(VulkanAPITest, sum_dim_3d) {
@@ -4374,8 +4428,13 @@
test_sum_dim({10, 7, 5}, {0, 1});
test_sum_dim({10, 7, 5}, {0, 2});
test_sum_dim({10, 7, 5}, {1, 2});
+
test_sum_dim({10, 7, 5}, {-1, -2});
- test_sum_dim({10, 7, 5}, {0, -2});
+ test_sum_dim({10, 7, 5}, {-1, -3});
+ test_sum_dim({10, 7, 5}, {-2, -3});
+
+ test_sum_dim({10, 7, 5}, {0, 1, 2});
+ test_sum_dim({10, 7, 5}, {-1, -2, -3});
}
TEST_F(VulkanAPITest, sum_dim_4d) {
@@ -4398,6 +4457,18 @@
test_sum_dim({10, 7, 5, 6}, {3, 2, 1});
test_sum_dim({10, 7, 5, 6}, {3, -2, 1});
test_sum_dim({10, 7, 5, 6}, {-3, -2, -1});
+
+ test_sum_dim({10, 7, 5, 6}, {-1, -2, -3});
+ test_sum_dim({10, 7, 5, 6}, {-1, -2, -4});
+ test_sum_dim({10, 7, 5, 6}, {-1, -3, -4});
+ test_sum_dim({10, 7, 5, 6}, {-2, -3, -4});
+
+ test_sum_dim({10, 7, 5, 6}, {-1, -2, -3, -4});
+}
+
+TEST_F(VulkanAPITest, sum_dim_keepdim_1d) {
+ test_sum_dim({5}, {-1}, true);
+ test_sum_dim({3}, {-1}, true);
}
TEST_F(VulkanAPITest, sum_dim_keepdim_2d) {
@@ -4413,6 +4484,8 @@
test_sum_dim({9, 5, 7}, {0, 1}, true);
test_sum_dim({5, 9, 7}, {0, 2}, true);
test_sum_dim({7, 9, 5}, {1, 2}, true);
+
+ test_sum_dim({7, 9, 5}, {0, 1, 2}, true);
}
TEST_F(VulkanAPITest, sum_dim_keepdim_4d) {
@@ -4431,8 +4504,38 @@
test_sum_dim({7, 11, 9, 5}, {-1, -2, -3}, true);
test_sum_dim({11, 7, 9, 5}, {-1, -2, -4}, true);
test_sum_dim({9, 5, 7, 11}, {-2, -3, -4}, true);
+
+ test_sum_dim({9, 5, 7, 11}, {-1, -2, -3, -4}, true);
}
+void test_sum(const at::IntArrayRef input_shape) {
+ const auto in_cpu = at::rand(input_shape, at::device(at::kCPU).dtype(at::kFloat));
+ const auto in_vulkan = in_cpu.vulkan();
+
+ const auto out_cpu = at::sum(in_cpu);
+ const auto out_vulkan = at::sum(in_vulkan);
+
+ ASSERT_TRUE(out_vulkan.dim() == 0);
+ const auto check = almostEqual(out_cpu, out_vulkan.cpu());
+ if (!check) {
+ std::cout << "sum test failed with input shape: "
+ << input_shape << std::endl;
+ showRtol(out_cpu, out_vulkan.cpu());
+ }
+
+ ASSERT_TRUE(check);
+}
+
+TEST_F(VulkanAPITest, sum_test) {
+ test_sum({6});
+ test_sum({5, 6});
+ test_sum({0, 3, 1});
+ test_sum({3, 3, 1});
+ test_sum({7, 6, 6});
+ test_sum({7, 8, 5, 6});
+}
+
+
void test_uniform(at::Tensor a_vulkan, const float a_min, const float a_max) {
auto a_cpu = a_vulkan.cpu();
ASSERT_TRUE(a_cpu.max().item<float>() <= a_max);