Migrate test in cpp/api/ to use gtest (#11556)
Summary:
The second part of T32009899
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11556
Differential Revision: D9888224
Pulled By: zrphercule
fbshipit-source-id: cb0d0ba5d9c7ad601ee3bce0d932ce9cbbc40908
diff --git a/test/cpp/api/static.cpp b/test/cpp/api/static.cpp
index 121478c..827ff25 100644
--- a/test/cpp/api/static.cpp
+++ b/test/cpp/api/static.cpp
@@ -1,4 +1,5 @@
-#include <catch.hpp>
+
+#include "gtest/gtest.h"
#include <torch/detail/static.h>
#include <torch/nn/module.h>
@@ -22,43 +23,35 @@
return true;
}
-TEST_CASE("static") {
- SECTION("all_of") {
- REQUIRE(torch::all_of<>::value == true);
- REQUIRE(torch::all_of<true>::value == true);
- REQUIRE(torch::all_of<true, true, true>::value == true);
- REQUIRE(torch::all_of<false>::value == false);
- REQUIRE(torch::all_of<false, false, false>::value == false);
- REQUIRE(torch::all_of<true, true, false>::value == false);
- }
- SECTION("any_of") {
- REQUIRE(torch::any_of<>::value == false);
- REQUIRE(torch::any_of<true>::value == true);
- REQUIRE(torch::any_of<true, true, true>::value == true);
- REQUIRE(torch::any_of<false>::value == false);
- REQUIRE(torch::any_of<true, true, false>::value == true);
- }
- SECTION("enable_if_module_t") {
- REQUIRE(f(torch::nn::LinearImpl(1, 2)) == true);
- REQUIRE(f(5) == false);
- }
- SECTION("check_not_lvalue_references") {
- REQUIRE(torch::detail::check_not_lvalue_references<int>() == true);
- REQUIRE(
- torch::detail::check_not_lvalue_references<float, int, char>() == true);
- REQUIRE(
- torch::detail::check_not_lvalue_references<float, int&, char>() ==
- false);
- REQUIRE(torch::detail::check_not_lvalue_references<std::string>() == true);
- REQUIRE(
- torch::detail::check_not_lvalue_references<std::string&>() == false);
- }
- SECTION("apply") {
- std::vector<int> v;
- torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);
- REQUIRE(v.size() == 5);
- for (size_t i = 0; i < v.size(); ++i) {
- REQUIRE(v.at(i) == 1 + i);
- }
+TEST(TestStatic, All_Of){
+ EXPECT_TRUE(torch::all_of<>::value);
+ EXPECT_TRUE(torch::all_of<true>::value);
+ EXPECT_TRUE((torch::all_of<true, true, true>::value));
+ EXPECT_FALSE(torch::all_of<false>::value);
+ EXPECT_FALSE((torch::all_of<false, false, false>::value));
+ EXPECT_FALSE((torch::all_of<true, true, false>::value));
+}
+TEST(TestStatic, Any_Of){
+ EXPECT_FALSE(torch::any_of<>::value);
+ EXPECT_TRUE(bool((torch::any_of<true>::value)));
+ EXPECT_TRUE(bool((torch::any_of<true, true, true>::value)));
+ EXPECT_FALSE(bool((torch::any_of<false>::value)));
+}
+TEST(TestStatic, Enable_If_Module){
+ EXPECT_TRUE(f(torch::nn::LinearImpl(1, 2)));
+ EXPECT_FALSE(f(5));
+ EXPECT_TRUE(torch::detail::check_not_lvalue_references<int>());
+ EXPECT_TRUE((torch::detail::check_not_lvalue_references<float, int, char>()));
+ EXPECT_FALSE(
+ (torch::detail::check_not_lvalue_references<float, int&, char>()));
+ EXPECT_TRUE(torch::detail::check_not_lvalue_references<std::string>());
+ EXPECT_FALSE(torch::detail::check_not_lvalue_references<std::string&>());
+}
+TEST(TestStatic, Apply){
+ std::vector<int> v;
+ torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);
+ EXPECT_EQ(v.size(), 5);
+ for (size_t i = 0; i < v.size(); ++i) {
+ EXPECT_EQ(v.at(i), i + 1);
}
}
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 68eee29..69b4963 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -436,8 +436,10 @@
endif()
if (BUILD_TEST AND NOT NO_API AND NOT USE_ROCM)
- set(TORCH_API_TEST_DIR "${TORCH_ROOT}/test/cpp/api")
+ #Catch test of api.
+ #TODO: Change all these tests to Google test.
+ set(TORCH_API_TEST_DIR "${TORCH_ROOT}/test/cpp/api")
set(TORCH_API_TEST_SOURCES
${TORCH_API_TEST_DIR}/any.cpp
${TORCH_API_TEST_DIR}/cursor.cpp
@@ -450,7 +452,6 @@
${TORCH_API_TEST_DIR}/parallel.cpp
${TORCH_API_TEST_DIR}/rnn.cpp
${TORCH_API_TEST_DIR}/sequential.cpp
- ${TORCH_API_TEST_DIR}/static.cpp
${TORCH_API_TEST_DIR}/tensor_cuda.cpp
${TORCH_API_TEST_DIR}/tensor.cpp
${TORCH_API_TEST_DIR}/jit.cpp
@@ -471,19 +472,36 @@
target_link_libraries(test_api torch ${TORCH_CUDA_LIBRARIES} ${CUDA_NVRTC_LIB} ${CUDA_CUDA_LIB})
+ #Google test of api.
+ set(TORCH_API_GTEST_DIR "${TORCH_ROOT}/test/cpp/api/")
+ add_executable(gtest_api
+ ${TORCH_API_GTEST_DIR}/static.cpp
+ )
+ target_include_directories(gtest_api PRIVATE ${ATen_CPU_INCLUDE})
+ target_link_libraries(gtest_api torch gtest_main)
+ if (USE_CUDA)
+ target_link_libraries(gtest_api ${CUDA_LIBRARIES} ${CUDA_NVRTC_LIB} ${CUDA_CUDA_LIB} ${TORCH_CUDA_LIBRARIES})
+ endif()
+
+ #Adding compile options for both tests.
if (NOT MSVC)
if (APPLE)
target_compile_options(test_api PRIVATE
# Clang has an unfixed bug leading to spurious missing braces
# warnings, see https://bugs.llvm.org/show_bug.cgi?id=21629
-Wno-missing-braces)
- else()
+ target_compile_options(gtest_api PRIVATE
+ -Wno-missing-braces)
+ else()
target_compile_options(test_api PRIVATE
# Considered to be flaky. See the discussion at
# https://github.com/pytorch/pytorch/pull/9608
-Wno-maybe-uninitialized
# gcc gives nonsensical warnings about variadic.h
-Wno-unused-but-set-parameter)
+ target_compile_options(gtest_api PRIVATE
+ -Wno-maybe-uninitialized
+ -Wno-unused-but-set-parameter)
endif()
endif()
endif()