Use the CATCH_ prefix for Catch test macros to avoid name conflicts with Caffe2.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11780
Differential Revision: D9889925
Pulled By: gchanan
fbshipit-source-id: 5eca849c36ced00b8ae7482b7945b445a3e1687e
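
The ATen tests previously used Catch's unprefixed assertion macros (TEST_CASE, SECTION, REQUIRE, ...), which collide with identically named macros when the tests are built together with Caffe2. Defining CATCH_CONFIG_PREFIX_ALL before including catch.hpp makes Catch expose only the CATCH_-prefixed forms, and the new aten/src/ATen/test/catch_utils.hpp centralizes that definition so every test file includes Catch the same way. A minimal sketch of a test written against the new header (the test name and assertions are hypothetical, assuming catch.hpp is on the include path):

#define CATCH_CONFIG_MAIN
#include "catch_utils.hpp" // defines CATCH_CONFIG_PREFIX_ALL, then includes <catch.hpp>
#include <vector>

CATCH_TEST_CASE("catch prefix smoke test", "[cpu]") {
  CATCH_REQUIRE(2 + 2 == 4);                         // was: REQUIRE(...)
  CATCH_SECTION("out-of-range access throws") {      // was: SECTION(...)
    _CATCH_REQUIRE_THROWS(std::vector<int>().at(0)); // was: REQUIRE_THROWS(...)
  }
}

The remainder of the diff is the mechanical rename from the unprefixed macros to their CATCH_-prefixed forms.
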
diff --git a/aten/src/ATen/test/apply_test.cpp b/aten/src/ATen/test/apply_test.cpp
index 986f599..fc39ecc 100644
--- a/aten/src/ATen/test/apply_test.cpp
+++ b/aten/src/ATen/test/apply_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "cuda.h"
#include "cuda_runtime.h"
@@ -11,111 +11,111 @@
*/
#ifndef _WIN32
-TEST_CASE("2D Contiguous", "Collapses a 2D contiguous tensor to 1D contiguous") {
+CATCH_TEST_CASE("2D Contiguous", "Collapses a 2D contiguous tensor to 1D contiguous") {
int sizes[] = {4, 4};
int strides[] = {4, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 2, sizes, strides};
ti.collapseDims();
- REQUIRE(ti.dims == 1);
- REQUIRE(ti.sizes[0] == (4 * 4));
+ CATCH_REQUIRE(ti.dims == 1);
+ CATCH_REQUIRE(ti.sizes[0] == (4 * 4));
}
-TEST_CASE("3D Contiguous", "Collapses a 3D contiguous tensor to a 1D contiguous") {
+CATCH_TEST_CASE("3D Contiguous", "Collapses a 3D contiguous tensor to a 1D contiguous") {
int sizes[] = {6, 3, 7};
int strides[] = {3 * 7, 7, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
ti.collapseDims();
- REQUIRE(ti.dims == 1);
- REQUIRE(ti.sizes[0] == (6 * 3 * 7));
+ CATCH_REQUIRE(ti.dims == 1);
+ CATCH_REQUIRE(ti.sizes[0] == (6 * 3 * 7));
}
-TEST_CASE("3D Partial Collapse", "Collapses a 3D noncontiguous tensor to a 2D tensor") {
+CATCH_TEST_CASE("3D Partial Collapse", "Collapses a 3D noncontiguous tensor to a 2D tensor") {
int sizes[] = {4, 3, 2};
int strides[] = {3 * 3, 3, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
ti.collapseDims();
- REQUIRE(ti.dims == 2);
- REQUIRE(ti.sizes[0] == (4 * 3));
- REQUIRE(ti.sizes[1] == 2);
+ CATCH_REQUIRE(ti.dims == 2);
+ CATCH_REQUIRE(ti.sizes[0] == (4 * 3));
+ CATCH_REQUIRE(ti.sizes[1] == 2);
}
-TEST_CASE("2D Strided Collapse", "Collapses a 2D skip contiguous tensor to a 1D skip contiguous tensor") {
+CATCH_TEST_CASE("2D Strided Collapse", "Collapses a 2D skip contiguous tensor to a 1D skip contiguous tensor") {
int sizes[] = {3, 2};
int strides[] = {2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 2, sizes, strides};
ti.collapseDims();
- REQUIRE(ti.dims == 1);
- REQUIRE(ti.sizes[0] == (3 * 2));
- REQUIRE(ti.strides[0] == 2);
+ CATCH_REQUIRE(ti.dims == 1);
+ CATCH_REQUIRE(ti.sizes[0] == (3 * 2));
+ CATCH_REQUIRE(ti.strides[0] == 2);
}
-TEST_CASE("4D Partial Strided Collapse", "Collapses a 4D tensor to a 2D tensor"){
+CATCH_TEST_CASE("4D Partial Strided Collapse", "Collapses a 4D tensor to a 2D tensor"){
int sizes[] = {3, 6, 5, 2};
int strides[] = {6 * 22, 22, 2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
ti.collapseDims();
- REQUIRE(ti.dims == 2);
- REQUIRE(ti.sizes[0] == (3 * 6));
- REQUIRE(ti.strides[0] == 22);
- REQUIRE(ti.sizes[1] == (5 * 2));
- REQUIRE(ti.strides[1] == 2);
+ CATCH_REQUIRE(ti.dims == 2);
+ CATCH_REQUIRE(ti.sizes[0] == (3 * 6));
+ CATCH_REQUIRE(ti.strides[0] == 22);
+ CATCH_REQUIRE(ti.sizes[1] == (5 * 2));
+ CATCH_REQUIRE(ti.strides[1] == 2);
}
-TEST_CASE("Collapsing Zeros and Ones", "Collapses a 5D tensor to a 1D tensor") {
+CATCH_TEST_CASE("Collapsing Zeros and Ones", "Collapses a 5D tensor to a 1D tensor") {
int sizes[] = {1, 10, 1, 5, 4};
int strides[] = {4, 0, 16, 0, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 5, sizes, strides};
ti.collapseDims();
- REQUIRE(ti.dims == 2);
- REQUIRE(ti.sizes[0] == (10 * 5));
- REQUIRE(ti.strides[0] == 0);
- REQUIRE(ti.sizes[1] == 4);
- REQUIRE(ti.strides[1] == 1);
+ CATCH_REQUIRE(ti.dims == 2);
+ CATCH_REQUIRE(ti.sizes[0] == (10 * 5));
+ CATCH_REQUIRE(ti.strides[0] == 0);
+ CATCH_REQUIRE(ti.sizes[1] == 4);
+ CATCH_REQUIRE(ti.strides[1] == 1);
}
-TEST_CASE("Collapsing to a Point Tensor", "Collapses a 3D tensor to a point tensor") {
+CATCH_TEST_CASE("Collapsing to a Point Tensor", "Collapses a 3D tensor to a point tensor") {
int sizes[] = {1, 1, 1};
int strides[] = {17, 12, 3};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
- REQUIRE(ti.collapseDims() == 0);
- REQUIRE(ti.dims == 1);
- REQUIRE(ti.sizes[0] == 1);
- REQUIRE(ti.strides[0] == 1);
+ CATCH_REQUIRE(ti.collapseDims() == 0);
+ CATCH_REQUIRE(ti.dims == 1);
+ CATCH_REQUIRE(ti.sizes[0] == 1);
+ CATCH_REQUIRE(ti.strides[0] == 1);
}
-TEST_CASE("Excluding in a 4D Contiguous", "Collapses a 4D tensor to a 3D tensor") {
+CATCH_TEST_CASE("Excluding in a 4D Contiguous", "Collapses a 4D tensor to a 3D tensor") {
int sizes[] = {3, 6, 5, 2};
int strides[] = {6 * 22, 22, 2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
- REQUIRE(ti.collapseDims(1) == 1);
- REQUIRE(ti.dims == 3);
- REQUIRE(ti.sizes[0] == 3);
- REQUIRE(ti.strides[0] == (6 * 22));
- REQUIRE(ti.sizes[1] == 6);
- REQUIRE(ti.strides[1] == 22);
- REQUIRE(ti.sizes[2] == (5 * 2));
- REQUIRE(ti.strides[2] == 2);
+ CATCH_REQUIRE(ti.collapseDims(1) == 1);
+ CATCH_REQUIRE(ti.dims == 3);
+ CATCH_REQUIRE(ti.sizes[0] == 3);
+ CATCH_REQUIRE(ti.strides[0] == (6 * 22));
+ CATCH_REQUIRE(ti.sizes[1] == 6);
+ CATCH_REQUIRE(ti.strides[1] == 22);
+ CATCH_REQUIRE(ti.sizes[2] == (5 * 2));
+ CATCH_REQUIRE(ti.strides[2] == 2);
}
-TEST_CASE("Roving Exclusion", "Collapses a 4D tensor to a 3D tensor") {
+CATCH_TEST_CASE("Roving Exclusion", "Collapses a 4D tensor to a 3D tensor") {
int sizes[] = {3, 6, 5, 2};
int strides[] = {6 * 22, 22, 2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
- REQUIRE(ti.collapseDims(2) == 1);
- REQUIRE(ti.dims == 3);
- REQUIRE(ti.sizes[0] == (3 * 6));
- REQUIRE(ti.strides[0] == 22);
- REQUIRE(ti.sizes[1] == 5);
- REQUIRE(ti.strides[1] == 4);
- REQUIRE(ti.sizes[2] == 2);
- REQUIRE(ti.strides[2] == 2);
+ CATCH_REQUIRE(ti.collapseDims(2) == 1);
+ CATCH_REQUIRE(ti.dims == 3);
+ CATCH_REQUIRE(ti.sizes[0] == (3 * 6));
+ CATCH_REQUIRE(ti.strides[0] == 22);
+ CATCH_REQUIRE(ti.sizes[1] == 5);
+ CATCH_REQUIRE(ti.strides[1] == 4);
+ CATCH_REQUIRE(ti.sizes[2] == 2);
+ CATCH_REQUIRE(ti.strides[2] == 2);
}
-TEST_CASE("Invalid Exclusion", "Attempts to exclude a nonexisting dimension") {
+CATCH_TEST_CASE("Invalid Exclusion", "Attempts to exclude a nonexisting dimension") {
int sizes[] = {1, 1, 1};
int strides[] = {17, 12, 3};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
- REQUIRE_THROWS(ti.collapseDims(5));
+ _CATCH_REQUIRE_THROWS(ti.collapseDims(5));
}
#endif
diff --git a/aten/src/ATen/test/apply_utils_test.cpp b/aten/src/ATen/test/apply_utils_test.cpp
index 38027ba..22be6de 100644
--- a/aten/src/ATen/test/apply_utils_test.cpp
+++ b/aten/src/ATen/test/apply_utils_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"
@@ -108,32 +108,32 @@
});
}
-TEST_CASE("apply utils test 2-dim small contiguous", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 2-dim small contiguous", "[cpu]") {
manual_seed(123, at::kCPU);
test(CPU(kDouble), {2, 1}, -1, -1);
}
-TEST_CASE("apply utils test 2-dim small", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 2-dim small", "[cpu]") {
manual_seed(123, at::kCPU);
test(CPU(kDouble), {2, 1});
}
-TEST_CASE("apply utils test 2-dim", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 2-dim", "[cpu]") {
manual_seed(123, at::kCPU);
test(CPU(kDouble), {20, 10});
}
-TEST_CASE("apply utils test 3-dim", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 3-dim", "[cpu]") {
manual_seed(123, at::kCPU);
test(CPU(kDouble), {3, 4, 2});
}
-TEST_CASE("apply utils test 3-dim medium", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 3-dim medium", "[cpu]") {
manual_seed(123, at::kCPU);
test(CPU(kDouble), {3, 40, 2});
}
-TEST_CASE("apply utils test 10-dim", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 10-dim", "[cpu]") {
manual_seed(123, at::kCPU);
test(CPU(kDouble), {3, 4, 2, 5, 2, 1, 3, 4, 2, 3});
}
diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp
index 9498812..c64fdec 100644
--- a/aten/src/ATen/test/basic.cpp
+++ b/aten/src/ATen/test/basic.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "ATen/core/Reduction.h"
@@ -20,66 +20,66 @@
using Catch::Matchers::StartsWith;
static void test(Type & type) {
- SECTION( "resize" ) {
+ CATCH_SECTION( "resize" ) {
auto a = type.tensor();
a.resize_({3,4});
- REQUIRE(a.numel() == 12);
+ CATCH_REQUIRE(a.numel() == 12);
a.resize_({5, 7});
- REQUIRE(a.numel() == 35);
+ CATCH_REQUIRE(a.numel() == 35);
}
- SECTION( "ones and dot" ) {
+ CATCH_SECTION( "ones and dot" ) {
Tensor b0 = ones({1, 1}, type);
- REQUIRE(2 == (b0+b0).sum().toCDouble());
+ CATCH_REQUIRE(2 == (b0+b0).sum().toCDouble());
Tensor b1 = ones({1, 2}, type);
- REQUIRE(4 == (b1+b1).sum().toCDouble());
+ CATCH_REQUIRE(4 == (b1+b1).sum().toCDouble());
Tensor b = ones({3, 4}, type);
- REQUIRE(24 == (b+b).sum().toCDouble());
- REQUIRE(12 == b.numel());
- REQUIRE(b.view(-1).dot(b.view(-1)).toCDouble() == 12);
+ CATCH_REQUIRE(24 == (b+b).sum().toCDouble());
+ CATCH_REQUIRE(12 == b.numel());
+ CATCH_REQUIRE(b.view(-1).dot(b.view(-1)).toCDouble() == 12);
}
- SECTION( "rand" ) {
+ CATCH_SECTION( "rand" ) {
for(auto i = 0; i < 10; i++) {
Tensor a = rand({3,4}, type.toScalarType(i % 2 == 0 ? kFloat : kDouble));
}
}
- SECTION( "sort" ) {
+ CATCH_SECTION( "sort" ) {
Tensor b = rand({3, 4}, type);
auto z = b.sort(1);
auto z_sorted = std::get<0>(z);
- REQUIRE(z_sorted[0][0].toCFloat() < z_sorted[0][1].toCFloat());
+ CATCH_REQUIRE(z_sorted[0][0].toCFloat() < z_sorted[0][1].toCFloat());
}
if(type.backend() != Backend::CUDA)
- SECTION( "randperm" ) {
+ CATCH_SECTION( "randperm" ) {
Tensor b = randperm(15, type);
Tensor rv, ri;
std::tie(rv, ri) = sort(b, 0);
- REQUIRE(rv[0].toCFloat() <= rv[1].toCFloat());
+ CATCH_REQUIRE(rv[0].toCFloat() <= rv[1].toCFloat());
}
- SECTION( "context" ) {
+ CATCH_SECTION( "context" ) {
std::stringstream ss;
ss << "context: " << std::hex << (int64_t)&globalContext() << std::endl;
}
- SECTION( "add" ) {
+ CATCH_SECTION( "add" ) {
Tensor a = rand({3, 4}, type);
Tensor b = rand({3, 4}, type);
Tensor c = add(a, add(a, b));
//TODO:0-dim Tensor d(3.f);
Scalar d = 3.f;
- REQUIRE( add(c, d).allclose(a + a + b + d) );
+ CATCH_REQUIRE( add(c, d).allclose(a + a + b + d) );
}
- SECTION( "loads of adds" ) {
+ CATCH_SECTION( "loads of adds" ) {
auto begin = std::chrono::high_resolution_clock::now();
Tensor d = ones({3, 4}, type);
Tensor r = zeros({3, 4}, type);
@@ -89,10 +89,10 @@
auto end = std::chrono::high_resolution_clock::now();
//TODO TEST PERF?
std::cout << std::dec << " " << std::chrono::duration_cast<std::chrono::milliseconds>(end-begin).count() << " ms" << std::endl;
- REQUIRE(norm(100000*d).toCDouble() == norm(r).toCDouble());
+ CATCH_REQUIRE(norm(100000*d).toCDouble() == norm(r).toCDouble());
}
- SECTION( "loads of adds (with copy)" ) {
+ CATCH_SECTION( "loads of adds (with copy)" ) {
auto begin = std::chrono::high_resolution_clock::now();
Tensor d = ones({3, 4}, type);
Tensor r = zeros({3, 4}, type);
@@ -102,59 +102,59 @@
auto end = std::chrono::high_resolution_clock::now();
//TODO TEST PERF?
std::cout << std::dec << " " << std::chrono::duration_cast<std::chrono::milliseconds>(end-begin).count() << " ms" << std::endl;
- REQUIRE(norm(100000*d).toCDouble() == norm(r).toCDouble());
+ CATCH_REQUIRE(norm(100000*d).toCDouble() == norm(r).toCDouble());
}
- SECTION( "isContiguous" ) {
+ CATCH_SECTION( "isContiguous" ) {
Tensor a = rand({3, 4}, type);
- REQUIRE(a.is_contiguous());
+ CATCH_REQUIRE(a.is_contiguous());
a = a.transpose(0, 1);
- REQUIRE(!a.is_contiguous());
+ CATCH_REQUIRE(!a.is_contiguous());
}
- SECTION( "permute" ) {
+ CATCH_SECTION( "permute" ) {
Tensor a = rand({3, 4, 5}, type);
Tensor b = a.permute({1, 2, 0});
- REQUIRE(b.sizes().equals({4, 5, 3}));
- REQUIRE(b.strides().equals({5, 1, 20}));
+ CATCH_REQUIRE(b.sizes().equals({4, 5, 3}));
+ CATCH_REQUIRE(b.strides().equals({5, 1, 20}));
}
- SECTION( "mm" ) {
+ CATCH_SECTION( "mm" ) {
Tensor a = rand({3, 4}, type);
Tensor b = rand({4}, type);
Tensor c = mv(a, b);
- REQUIRE(c.equal(addmv(zeros({3}, type), a, b, 0, 1)));
+ CATCH_REQUIRE(c.equal(addmv(zeros({3}, type), a, b, 0, 1)));
}
- SECTION( "squeeze" ) {
+ CATCH_SECTION( "squeeze" ) {
Tensor a = rand({2, 1}, type);
Tensor b = squeeze(a);
- REQUIRE(b.dim() == 1);
+ CATCH_REQUIRE(b.dim() == 1);
a = rand({1}, type);
b = squeeze(a);
//TODO 0-dim squeeze
- REQUIRE(a[0].equal(b));
+ CATCH_REQUIRE(a[0].equal(b));
}
- SECTION( "copy" ) {
+ CATCH_SECTION( "copy" ) {
Tensor a = zeros({4, 3}, type);
Tensor e = rand({4, 3}, type);
a.copy_(e);
- REQUIRE(a.equal(e));
+ CATCH_REQUIRE(a.equal(e));
}
- SECTION( "copy (broadcasting)" ) {
+ CATCH_SECTION( "copy (broadcasting)" ) {
Tensor a = zeros({4, 3}, type);
Tensor e = rand({3}, type);
a.copy_(e);
for (int i = 0; i < 4; ++i) {
- REQUIRE(a[i].equal(e));
+ CATCH_REQUIRE(a[i].equal(e));
}
}
- SECTION( "abs(value)" ) {
+ CATCH_SECTION( "abs(value)" ) {
Tensor r = at::abs(type.scalarTensor(-3));
- REQUIRE(r.toCInt() == 3);
+ CATCH_REQUIRE(r.toCInt() == 3);
}
//TODO(zach): operator overloads
@@ -168,120 +168,120 @@
}
#endif
- SECTION( "adding a value with a scalar" ) {
+ CATCH_SECTION( "adding a value with a scalar" ) {
Tensor a = rand({4, 3}, type);
- REQUIRE((ones({4,3}, type) + a).equal(add(a,1)));
+ CATCH_REQUIRE((ones({4,3}, type) + a).equal(add(a,1)));
}
- SECTION( "select" ) {
+ CATCH_SECTION( "select" ) {
Tensor a = rand({3, 7}, type);
auto a_13 = select(a, 1, 3);
auto a_13_02 = select(select(a, 1, 3), 0, 2);
- REQUIRE( a[0][3].equal(a_13[0]) );
- REQUIRE( a[2][3].equal(a_13_02) );
+ CATCH_REQUIRE( a[0][3].equal(a_13[0]) );
+ CATCH_REQUIRE( a[2][3].equal(a_13_02) );
}
- SECTION( "zero-dim" ) {
+ CATCH_SECTION( "zero-dim" ) {
Tensor a = type.scalarTensor(4); //rand(type, {1});
Tensor b = rand({3,4}, type);
- REQUIRE((a + a).dim() == 0);
- REQUIRE((1 + a).dim() == 0);
- REQUIRE((b + a).dim() == 2);
- REQUIRE((a + b).dim() == 2);
+ CATCH_REQUIRE((a + a).dim() == 0);
+ CATCH_REQUIRE((1 + a).dim() == 0);
+ CATCH_REQUIRE((b + a).dim() == 2);
+ CATCH_REQUIRE((a + b).dim() == 2);
auto c = rand({3,4}, type);
- REQUIRE(c[1][2].dim() == 0);
+ CATCH_REQUIRE(c[1][2].dim() == 0);
auto f = rand({3,4}, type);
f[2] = zeros({4}, type);
f[1][0] = -1;
- REQUIRE(f[2][0].toCDouble() == 0);
+ CATCH_REQUIRE(f[2][0].toCDouble() == 0);
}
- SECTION( "tensor from TH" ) {
+ CATCH_SECTION( "tensor from TH" ) {
int a = 4;
THFloatTensor *t = THFloatTensor_newWithSize2d(a, a);
THFloatTensor_fill(t, a);
Tensor tt = CPU(kFloat).unsafeTensorFromTH(t,false);
- REQUIRE_NOTHROW(tt);
+ CATCH_REQUIRE_NOTHROW(tt);
}
- SECTION( "toCFloat" ) {
+ CATCH_SECTION( "toCFloat" ) {
Tensor a = zeros({3,4});
Tensor b = ones({3,7});
Tensor c = cat({a,b},1);
- REQUIRE(c.size(1) == 11);
+ CATCH_REQUIRE(c.size(1) == 11);
Tensor e = rand({});
- REQUIRE(*e.data<float>() == e.sum().toCFloat());
+ CATCH_REQUIRE(*e.data<float>() == e.sum().toCFloat());
}
- SECTION( "to string" ) {
+ CATCH_SECTION( "to string" ) {
Tensor b = ones({3,7})*.0000001f;
std::stringstream s;
s << b << "\n";
std::string expect = "1e-07 *";
- REQUIRE(s.str().substr(0,expect.size()) == expect);
+ CATCH_REQUIRE(s.str().substr(0,expect.size()) == expect);
}
- SECTION("indexing by Scalar") {
+ CATCH_SECTION("indexing by Scalar") {
Tensor tensor = arange(0, 10, kInt);
Tensor one = ones({}, kInt);
for (int64_t i = 0; i < tensor.numel(); ++i) {
- REQUIRE(tensor[i].equal(one * i));
+ CATCH_REQUIRE(tensor[i].equal(one * i));
}
for (size_t i = 0; i < static_cast<uint64_t>(tensor.numel()); ++i) {
- REQUIRE(tensor[i].equal(one * static_cast<int64_t>(i)));
+ CATCH_REQUIRE(tensor[i].equal(one * static_cast<int64_t>(i)));
}
for (int i = 0; i < tensor.numel(); ++i) {
- REQUIRE(tensor[i].equal(one * i));
+ CATCH_REQUIRE(tensor[i].equal(one * i));
}
for (int16_t i = 0; i < tensor.numel(); ++i) {
- REQUIRE(tensor[i].equal(one * i));
+ CATCH_REQUIRE(tensor[i].equal(one * i));
}
for (int8_t i = 0; i < tensor.numel(); ++i) {
- REQUIRE(tensor[i].equal(one * i));
+ CATCH_REQUIRE(tensor[i].equal(one * i));
}
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
tensor[Scalar(3.14)].equal(one),
StartsWith(
"Can only index tensors with integral scalars"));
}
- SECTION("indexing by zero-dim tensor") {
+ CATCH_SECTION("indexing by zero-dim tensor") {
Tensor tensor = arange(0, 10, kInt);
Tensor one = ones({}, kInt);
for (int i = 0; i < tensor.numel(); ++i) {
- REQUIRE(tensor[one * i].equal(one * i));
+ CATCH_REQUIRE(tensor[one * i].equal(one * i));
}
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
tensor[ones({}) * 3.14].equal(one),
StartsWith(
"Can only index tensors with integral scalars"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
tensor[Tensor()].equal(one),
StartsWith("Can only index with tensors that are defined"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
tensor[ones({2, 3, 4}, kInt)].equal(one),
StartsWith("Can only index with tensors that are scalars (zero-dim)"));
}
- SECTION("dispatch") {
+ CATCH_SECTION("dispatch") {
Tensor tensor = randn({20, 20});
Tensor other = randn({20, 20});
auto result = tensor.m(relu).m(mse_loss, other, Reduction::ElementwiseMean);
- REQUIRE(result.allclose(mse_loss(relu(tensor), other)));
+ CATCH_REQUIRE(result.allclose(mse_loss(relu(tensor), other)));
}
- SECTION("core") {
+ CATCH_SECTION("core") {
int i = CoreTest();
- REQUIRE(i + 1 == CoreTest());
+ CATCH_REQUIRE(i + 1 == CoreTest());
}
}
-TEST_CASE( "basic tests CPU", "[cpu]" ) {
+CATCH_TEST_CASE( "basic tests CPU", "[cpu]" ) {
manual_seed(123, at::kCPU);
test(CPU(kFloat));
}
-TEST_CASE( "basic tests GPU", "[cuda]" ) {
+CATCH_TEST_CASE( "basic tests GPU", "[cuda]" ) {
manual_seed(123, at::kCUDA);
if(at::hasCUDA()) {
diff --git a/aten/src/ATen/test/broadcast_test.cpp b/aten/src/ATen/test/broadcast_test.cpp
index cd5c43d..822a1d7 100644
--- a/aten/src/ATen/test/broadcast_test.cpp
+++ b/aten/src/ATen/test/broadcast_test.cpp
@@ -1,154 +1,154 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "test_seed.h"
using namespace at;
-TEST_CASE( "broadcast", "[]" ) {
+CATCH_TEST_CASE( "broadcast", "[]" ) {
manual_seed(123, at::kCPU);
Type & T = CPU(kFloat);
// 0) pre-req tests:
- SECTION( "can't expand empty tensor" ) {
+ CATCH_SECTION( "can't expand empty tensor" ) {
auto empty = randn({0}, T);
- REQUIRE_THROWS(empty.expand({3}));
+ _CATCH_REQUIRE_THROWS(empty.expand({3}));
}
// 1) out-place function with 2 args
- SECTION( "out-place function with 2 args" ) {
+ CATCH_SECTION( "out-place function with 2 args" ) {
- SECTION( "basic" ) {
+ CATCH_SECTION( "basic" ) {
auto a = randn({3, 1}, T);
auto b = randn({5}, T);
std::vector<int64_t> expanded_sizes = {3, 5};
- REQUIRE((a + b).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes)));
+ CATCH_REQUIRE((a + b).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes)));
}
- SECTION( "with scalar" ) {
+ CATCH_SECTION( "with scalar" ) {
auto aScalar = ones({1}, T);
aScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
auto b = randn({3, 5}, T);
- REQUIRE((aScalar + b).equal(aScalar.expand(b.sizes()) + b.expand(b.sizes())));
+ CATCH_REQUIRE((aScalar + b).equal(aScalar.expand(b.sizes()) + b.expand(b.sizes())));
}
- SECTION( "old fallback behavior yields error" ) {
+ CATCH_SECTION( "old fallback behavior yields error" ) {
auto a = randn({3, 5}, T);
auto b = randn({5, 3}, T);
- REQUIRE_THROWS(a + b);
+ _CATCH_REQUIRE_THROWS(a + b);
}
- SECTION( "with mismatched sizes" ) {
+ CATCH_SECTION( "with mismatched sizes" ) {
auto a = randn({3, 5}, T);
auto b = randn({7, 5}, T);
- REQUIRE_THROWS(a + b);
+ _CATCH_REQUIRE_THROWS(a + b);
}
}
- SECTION( "out-place function with 3 args" ) {
+ CATCH_SECTION( "out-place function with 3 args" ) {
- SECTION( "basic" ) {
+ CATCH_SECTION( "basic" ) {
auto a = randn({3, 1, 1}, T);
auto b = randn({1, 2, 1}, T);
auto c = randn({1, 1, 5}, T);
std::vector<int64_t> expanded_sizes = {3, 2, 5};
- REQUIRE((a + b + c).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes) + c.expand(expanded_sizes)));
+ CATCH_REQUIRE((a + b + c).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes) + c.expand(expanded_sizes)));
}
- SECTION( "with scalar" ) {
+ CATCH_SECTION( "with scalar" ) {
auto aTensorScalar = ones({1}, T);
aTensorScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
auto b = randn({3, 2, 1}, T);
auto c = randn({1, 2, 5}, T);
std::vector<int64_t> expanded_sizes = {3, 2, 5};
- REQUIRE(aTensorScalar.addcmul(b, c).equal(
+ CATCH_REQUIRE(aTensorScalar.addcmul(b, c).equal(
aTensorScalar.expand(expanded_sizes).addcmul(b.expand(expanded_sizes), c.expand(expanded_sizes))));
}
- SECTION( "old fallback behavior yields error" ) {
+ CATCH_SECTION( "old fallback behavior yields error" ) {
auto a = randn({3, 2, 5}, T);
auto b = randn({2, 3, 5}, T);
auto c = randn({5, 3, 2}, T);
- REQUIRE_THROWS(a.addcmul(b, c));
+ _CATCH_REQUIRE_THROWS(a.addcmul(b, c));
}
- SECTION( "with mismatched sizes" ){
+ CATCH_SECTION( "with mismatched sizes" ){
auto a = randn({3, 2, 5}, T);
auto b = randn({2, 3, 5}, T);
auto c = randn({5, 5, 5}, T);
- REQUIRE_THROWS(a.addcmul(b, c));
+ _CATCH_REQUIRE_THROWS(a.addcmul(b, c));
}
}
- SECTION( "in-place function with 2 args" ) {
- SECTION( "basic" ) {
+ CATCH_SECTION( "in-place function with 2 args" ) {
+ CATCH_SECTION( "basic" ) {
auto a = randn({3, 5}, T);
auto b = randn({3, 1}, T);
- REQUIRE((a + b).equal(a + b.expand({3, 5})));
+ CATCH_REQUIRE((a + b).equal(a + b.expand({3, 5})));
}
- SECTION( "with scalar" ) {
+ CATCH_SECTION( "with scalar" ) {
auto a = randn({3, 5}, T);
auto bScalar = ones({1}, T);
bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
- REQUIRE((a + bScalar).equal(a + bScalar.expand(a.sizes())));
+ CATCH_REQUIRE((a + bScalar).equal(a + bScalar.expand(a.sizes())));
}
- SECTION( "error: would have to expand inplace arg" ) {
+ CATCH_SECTION( "error: would have to expand inplace arg" ) {
auto a = randn({1, 5}, T);
auto b = randn({3, 1}, T);
- REQUIRE_THROWS(a.add_(b));
+ _CATCH_REQUIRE_THROWS(a.add_(b));
}
}
- SECTION( "in-place function with 3 args" ) {
+ CATCH_SECTION( "in-place function with 3 args" ) {
auto a = randn({3, 5, 2}, T);
auto b = randn({3, 1, 2}, T);
auto c = randn({1, 5, 1}, T);
- SECTION( "basic" ) {
+ CATCH_SECTION( "basic" ) {
auto aClone = a.clone();
- REQUIRE(a.addcmul_(b, c).equal(aClone.addcmul_(b.expand(a.sizes()), c.expand(a.sizes()))));
+ CATCH_REQUIRE(a.addcmul_(b, c).equal(aClone.addcmul_(b.expand(a.sizes()), c.expand(a.sizes()))));
}
- SECTION( "with scalar" ) {
+ CATCH_SECTION( "with scalar" ) {
auto aClone = a.clone();
auto bScalar = ones({1}, T);
bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
- REQUIRE(a.addcmul_(bScalar, c).equal(aClone.addcmul_(bScalar.expand(a.sizes()), c.expand(a.sizes()))));
+ CATCH_REQUIRE(a.addcmul_(bScalar, c).equal(aClone.addcmul_(bScalar.expand(a.sizes()), c.expand(a.sizes()))));
}
- SECTION( "error: would have to expand inplace arg" ) {
+ CATCH_SECTION( "error: would have to expand inplace arg" ) {
auto a = randn({1, 3, 5}, T);
auto b = randn({4, 1, 1}, T);
auto c = randn({1, 3, 1}, T);
- REQUIRE_THROWS(a.addcmul_(b, c));
+ _CATCH_REQUIRE_THROWS(a.addcmul_(b, c));
}
}
- SECTION( "explicit dim specification" ) {
+ CATCH_SECTION( "explicit dim specification" ) {
auto a = randn({1}, T);
auto b = randn({5, 3}, T);
auto c = randn({3, 7}, T);
- SECTION( "basic" ) {
- REQUIRE(a.addmm(b, c).equal(a.expand({5,7}).addmm(b, c)));
+ CATCH_SECTION( "basic" ) {
+ CATCH_REQUIRE(a.addmm(b, c).equal(a.expand({5,7}).addmm(b, c)));
}
- SECTION( "with scalar" ) {
+ CATCH_SECTION( "with scalar" ) {
Tensor aScalar = ones({1}, T);
aScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
- REQUIRE(aScalar.addmm(b, c).equal(aScalar.expand({5, 7}).addmm(b, c)));
+ CATCH_REQUIRE(aScalar.addmm(b, c).equal(aScalar.expand({5, 7}).addmm(b, c)));
}
- SECTION( "with mismatched sizes" ) {
+ CATCH_SECTION( "with mismatched sizes" ) {
auto a = randn({3, 3}, T);
- REQUIRE_THROWS(a.addmm(b, c));
+ _CATCH_REQUIRE_THROWS(a.addmm(b, c));
}
}
}
diff --git a/aten/src/ATen/test/catch_utils.hpp b/aten/src/ATen/test/catch_utils.hpp
new file mode 100644
index 0000000..b9b0a87
--- /dev/null
+++ b/aten/src/ATen/test/catch_utils.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+#define CATCH_CONFIG_PREFIX_ALL
+#include <catch.hpp>
+
+// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes a warning;
+// define our own version that doesn't warn.
+#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
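
The underscore-prefixed helper above is needed because, in this Catch version, the prefixed CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes a warning, so the converted tests call INTERNAL_CATCH_THROWS directly to get an equivalent, warning-free assertion. A hedged sketch of how it is used, mirroring the broadcast test earlier in this diff (the tensor setup is illustrative):

#define CATCH_CONFIG_MAIN
#include "catch_utils.hpp"
#include "ATen/ATen.h"

CATCH_TEST_CASE("expanding an empty tensor throws", "[cpu]") {
  auto empty = at::randn({0}, at::CPU(at::kFloat));
  // _CATCH_REQUIRE_THROWS passes only if the expression throws;
  // the plain CATCH_REQUIRE_THROWS form causes the warning noted above.
  _CATCH_REQUIRE_THROWS(empty.expand({3}));
}
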
diff --git a/aten/src/ATen/test/cuda_half_test.cu b/aten/src/ATen/test/cuda_half_test.cu
index fa00e53..cce2671 100644
--- a/aten/src/ATen/test/cuda_half_test.cu
+++ b/aten/src/ATen/test/cuda_half_test.cu
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "ATen/cuda/NumericLimits.cuh"
@@ -82,9 +82,9 @@
kernel<<<1,1>>>();
}
-TEST_CASE( "half common math functions tests in device", "[cuda]" ) {
+CATCH_TEST_CASE( "half common math functions tests in device", "[cuda]" ) {
launch_function();
cudaError_t err = cudaDeviceSynchronize();
- REQUIRE(err == cudaSuccess);
+ CATCH_REQUIRE(err == cudaSuccess);
}
diff --git a/aten/src/ATen/test/cuda_optional_test.cu b/aten/src/ATen/test/cuda_optional_test.cu
index 9956dcf..b64c530 100644
--- a/aten/src/ATen/test/cuda_optional_test.cu
+++ b/aten/src/ATen/test/cuda_optional_test.cu
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "ATen/optional.h"
@@ -8,15 +8,15 @@
using namespace at;
-TEST_CASE( "optional in cuda files", "[cuda]" ) {
+CATCH_TEST_CASE( "optional in cuda files", "[cuda]" ) {
at::optional<int64_t> trivially_destructible;
at::optional<std::vector<int64_t>> non_trivially_destructible;
- REQUIRE(!trivially_destructible.has_value());
- REQUIRE(!non_trivially_destructible.has_value());
+ CATCH_REQUIRE(!trivially_destructible.has_value());
+ CATCH_REQUIRE(!non_trivially_destructible.has_value());
trivially_destructible = {5};
non_trivially_destructible = std::vector<int64_t>{5, 10};
- REQUIRE(trivially_destructible.has_value());
- REQUIRE(non_trivially_destructible.has_value());
+ CATCH_REQUIRE(trivially_destructible.has_value());
+ CATCH_REQUIRE(non_trivially_destructible.has_value());
}
diff --git a/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu b/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu
index f1eb5cb..a529f38 100644
--- a/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu
+++ b/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "test_seed.h"
@@ -22,7 +22,7 @@
}
}
-TEST_CASE( "test PackedTensorAccessor and Tensor.packed_accessor", "[cuda]" ) {
+CATCH_TEST_CASE( "test PackedTensorAccessor and Tensor.packed_accessor", "[cuda]" ) {
manual_seed(123, at::kCPU);
manual_seed(123, at::kCUDA);
@@ -38,9 +38,9 @@
test_tensor_packed_accessor_kernel<<<1, 1, 0, stream>>>(resa, t1a, t2a);
cudaError_t err = cudaDeviceSynchronize();
- REQUIRE(err == cudaSuccess);
+ CATCH_REQUIRE(err == cudaSuccess);
auto expected = mv(t1, t2);
- REQUIRE(res.allclose(expected));
+ CATCH_REQUIRE(res.allclose(expected));
}
diff --git a/aten/src/ATen/test/cuda_rng_test.cpp b/aten/src/ATen/test/cuda_rng_test.cpp
index d32903d..7b14174 100644
--- a/aten/src/ATen/test/cuda_rng_test.cpp
+++ b/aten/src/ATen/test/cuda_rng_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "cuda.h"
@@ -21,7 +21,7 @@
}
};
-TEST_CASE( "CUDA RNG test", "[cuda]" ) {
- SECTION( "multithread" )
+CATCH_TEST_CASE( "CUDA RNG test", "[cuda]" ) {
+ CATCH_SECTION( "multithread" )
testCudaRNGMultithread();
}
diff --git a/aten/src/ATen/test/cudnn_test.cpp b/aten/src/ATen/test/cudnn_test.cpp
index 31786e8..4391867 100644
--- a/aten/src/ATen/test/cudnn_test.cpp
+++ b/aten/src/ATen/test/cudnn_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "ATen/cudnn/Descriptors.h"
@@ -9,7 +9,7 @@
using namespace at;
using namespace at::native;
-TEST_CASE( "cudnn", "[cuda]" ) {
+CATCH_TEST_CASE( "cudnn", "[cuda]" ) {
manual_seed(123, at::kCUDA);
#if CUDNN_VERSION < 7000
@@ -18,8 +18,8 @@
desc1.initialize_rng(at::CUDA(kByte), handle, 0.5, 42);
desc2.set(handle, 0.5, desc1.state);
- REQUIRE(desc1.desc()->dropout == desc2.desc()->dropout);
- REQUIRE(desc1.desc()->nstates == desc2.desc()->nstates);
- REQUIRE(desc1.desc()->states == desc2.desc()->states);
+ CATCH_REQUIRE(desc1.desc()->dropout == desc2.desc()->dropout);
+ CATCH_REQUIRE(desc1.desc()->nstates == desc2.desc()->nstates);
+ CATCH_REQUIRE(desc1.desc()->states == desc2.desc()->states);
#endif
}
diff --git a/aten/src/ATen/test/dlconvertor_test.cpp b/aten/src/ATen/test/dlconvertor_test.cpp
index 4882929..bf0cf93 100644
--- a/aten/src/ATen/test/dlconvertor_test.cpp
+++ b/aten/src/ATen/test/dlconvertor_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "ATen/DLConvertor.h"
@@ -11,17 +11,17 @@
using namespace at;
-TEST_CASE( "dlconvertor", "[cpu]" ) {
+CATCH_TEST_CASE( "dlconvertor", "[cpu]" ) {
manual_seed(123, at::kCPU);
- INFO( "convert ATen to DLTensor" );
+ CATCH_INFO( "convert ATen to DLTensor" );
Tensor a = rand({3,4});
DLManagedTensor* dlMTensor = toDLPack(a);
- INFO( "convert DLTensor to ATen" );
+ CATCH_INFO( "convert DLTensor to ATen" );
Tensor b = fromDLPack(dlMTensor);
- REQUIRE(a.equal(b));
+ CATCH_REQUIRE(a.equal(b));
}
diff --git a/aten/src/ATen/test/half_test.cpp b/aten/src/ATen/test/half_test.cpp
index 3b29448..3217770 100644
--- a/aten/src/ATen/test/half_test.cpp
+++ b/aten/src/ATen/test/half_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include <ATen/ATen.h>
#include <iostream>
@@ -12,53 +12,53 @@
using namespace at;
-TEST_CASE( "half arithmetic", "[]" ) {
+CATCH_TEST_CASE( "half arithmetic", "[]" ) {
Half zero = 0;
Half one = 1;
- REQUIRE(zero + one == one);
- REQUIRE(zero + zero == zero);
- REQUIRE(zero * one == zero);
- REQUIRE(one * one == one);
- REQUIRE(one / one == one);
- REQUIRE(one - one == zero);
- REQUIRE(one - zero == one);
- REQUIRE(zero - one == -one);
- REQUIRE(one + one == Half(2));
- REQUIRE(one + one == 2);
+ CATCH_REQUIRE(zero + one == one);
+ CATCH_REQUIRE(zero + zero == zero);
+ CATCH_REQUIRE(zero * one == zero);
+ CATCH_REQUIRE(one * one == one);
+ CATCH_REQUIRE(one / one == one);
+ CATCH_REQUIRE(one - one == zero);
+ CATCH_REQUIRE(one - zero == one);
+ CATCH_REQUIRE(zero - one == -one);
+ CATCH_REQUIRE(one + one == Half(2));
+ CATCH_REQUIRE(one + one == 2);
}
-TEST_CASE( "half comparisons", "[]" ) {
+CATCH_TEST_CASE( "half comparisons", "[]" ) {
Half zero = 0;
Half one = 1;
- REQUIRE(zero < one);
- REQUIRE(zero < 1);
- REQUIRE(1 > zero);
- REQUIRE(0 >= zero);
- REQUIRE(0 != one);
- REQUIRE(zero == 0);
- REQUIRE(zero == zero);
- REQUIRE(zero == -zero);
+ CATCH_REQUIRE(zero < one);
+ CATCH_REQUIRE(zero < 1);
+ CATCH_REQUIRE(1 > zero);
+ CATCH_REQUIRE(0 >= zero);
+ CATCH_REQUIRE(0 != one);
+ CATCH_REQUIRE(zero == 0);
+ CATCH_REQUIRE(zero == zero);
+ CATCH_REQUIRE(zero == -zero);
}
-TEST_CASE( "half cast", "[]" ) {
+CATCH_TEST_CASE( "half cast", "[]" ) {
Half value = 1.5f;
- REQUIRE((int)value == 1);
- REQUIRE((short)value == 1);
- REQUIRE((long long)value == 1LL);
- REQUIRE((float)value == 1.5f);
- REQUIRE((double)value == 1.5);
- REQUIRE((bool)value == true);
- REQUIRE((bool)Half(0.0f) == false);
+ CATCH_REQUIRE((int)value == 1);
+ CATCH_REQUIRE((short)value == 1);
+ CATCH_REQUIRE((long long)value == 1LL);
+ CATCH_REQUIRE((float)value == 1.5f);
+ CATCH_REQUIRE((double)value == 1.5);
+ CATCH_REQUIRE((bool)value == true);
+ CATCH_REQUIRE((bool)Half(0.0f) == false);
}
-TEST_CASE( "half construction", "[]" ) {
- REQUIRE(Half((short)3) == Half(3.0f));
- REQUIRE(Half((unsigned short)3) == Half(3.0f));
- REQUIRE(Half(3) == Half(3.0f));
- REQUIRE(Half(3U) == Half(3.0f));
- REQUIRE(Half(3LL) == Half(3.0f));
- REQUIRE(Half(3ULL) == Half(3.0f));
- REQUIRE(Half(3.5) == Half(3.5f));
+CATCH_TEST_CASE( "half construction", "[]" ) {
+ CATCH_REQUIRE(Half((short)3) == Half(3.0f));
+ CATCH_REQUIRE(Half((unsigned short)3) == Half(3.0f));
+ CATCH_REQUIRE(Half(3) == Half(3.0f));
+ CATCH_REQUIRE(Half(3U) == Half(3.0f));
+ CATCH_REQUIRE(Half(3LL) == Half(3.0f));
+ CATCH_REQUIRE(Half(3ULL) == Half(3.0f));
+ CATCH_REQUIRE(Half(3.5) == Half(3.5f));
}
static std::string to_string(const Half& h) {
@@ -67,22 +67,22 @@
return ss.str();
}
-TEST_CASE( "half to string", "[]" ) {
- REQUIRE(to_string(Half(3.5f)) == "3.5");
- REQUIRE(to_string(Half(-100.0f)) == "-100");
+CATCH_TEST_CASE( "half to string", "[]" ) {
+ CATCH_REQUIRE(to_string(Half(3.5f)) == "3.5");
+ CATCH_REQUIRE(to_string(Half(-100.0f)) == "-100");
}
-TEST_CASE( "half numeric limits", "[]" ) {
+CATCH_TEST_CASE( "half numeric limits", "[]" ) {
using limits = std::numeric_limits<Half>;
- REQUIRE(limits::lowest() == -65504.0f);
- REQUIRE(limits::max() == 65504.0f);
- REQUIRE(limits::min() > 0);
- REQUIRE(limits::min() < 1);
- REQUIRE(limits::denorm_min() > 0);
- REQUIRE(limits::denorm_min() / 2 == 0);
- REQUIRE(limits::infinity() == std::numeric_limits<float>::infinity());
- REQUIRE(limits::quiet_NaN() != limits::quiet_NaN());
- REQUIRE(limits::signaling_NaN() != limits::signaling_NaN());
+ CATCH_REQUIRE(limits::lowest() == -65504.0f);
+ CATCH_REQUIRE(limits::max() == 65504.0f);
+ CATCH_REQUIRE(limits::min() > 0);
+ CATCH_REQUIRE(limits::min() < 1);
+ CATCH_REQUIRE(limits::denorm_min() > 0);
+ CATCH_REQUIRE(limits::denorm_min() / 2 == 0);
+ CATCH_REQUIRE(limits::infinity() == std::numeric_limits<float>::infinity());
+ CATCH_REQUIRE(limits::quiet_NaN() != limits::quiet_NaN());
+ CATCH_REQUIRE(limits::signaling_NaN() != limits::signaling_NaN());
}
// Check the declared type of members of numeric_limits<Half> matches
@@ -119,7 +119,7 @@
ASSERT_SAME_TYPE(traps);
ASSERT_SAME_TYPE(tinyness_before);
-TEST_CASE( "half common math functions test", "[]" ) {
+CATCH_TEST_CASE( "half common math functions test", "[]" ) {
float threshold = 0.00001;
assert(std::abs(std::lgamma(Half(10.0)) - std::lgamma(10.0f)) <= threshold);
assert(std::abs(std::exp(Half(1.0)) - std::exp(1.0f)) <= threshold);
diff --git a/aten/src/ATen/test/integer_divider_test.cu b/aten/src/ATen/test/integer_divider_test.cu
index 4c63ab3..d09a423 100644
--- a/aten/src/ATen/test/integer_divider_test.cu
+++ b/aten/src/ATen/test/integer_divider_test.cu
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
// Test IntegerDivider: this tests *all* 32-bit pairs (a, b) where a % b is 0 or
// (b-1), so it takes a few minutes to run.
@@ -62,18 +62,18 @@
cudaError_t err;
err = cudaMalloc(&dividersBuf_, NUM_CASES * sizeof(IntDivider<Value>));
- REQUIRE(err == cudaSuccess);
+ CATCH_REQUIRE(err == cudaSuccess);
err = cudaMalloc(&testCasesBuf_, NUM_CASES * sizeof(TestCase<Value>));
- REQUIRE(err == cudaSuccess);
+ CATCH_REQUIRE(err == cudaSuccess);
}
~IntDividerTester() {
cudaError_t err;
err = cudaFree(dividersBuf_);
- REQUIRE(err == cudaSuccess);
+ CATCH_REQUIRE(err == cudaSuccess);
err = cudaFree(testCasesBuf_);
- REQUIRE(err == cudaSuccess);
+ CATCH_REQUIRE(err == cudaSuccess);
}
void addTestCase(Value dividend, Value divisor, int steps) {
@@ -92,18 +92,18 @@
cudaError_t err;
if (testCases_.empty()) return;
- REQUIRE(!dividers_.empty());
+ CATCH_REQUIRE(!dividers_.empty());
- REQUIRE(dividers_.size() <= NUM_CASES);
- REQUIRE(testCases_.size() <= NUM_CASES);
+ CATCH_REQUIRE(dividers_.size() <= NUM_CASES);
+ CATCH_REQUIRE(testCases_.size() <= NUM_CASES);
err = cudaMemcpy(dividersBuf_, dividers_.data(),
dividers_.size() * sizeof(IntDivider<Value>),
cudaMemcpyHostToDevice);
- REQUIRE(err == cudaSuccess);
+ CATCH_REQUIRE(err == cudaSuccess);
err = cudaMemcpy(testCasesBuf_, testCases_.data(),
testCases_.size() * sizeof(TestCase<Value>),
cudaMemcpyHostToDevice);
- REQUIRE(err == cudaSuccess);
+ CATCH_REQUIRE(err == cudaSuccess);
int numCases = testCases_.size();
testIntDivider<Value><<<512, 512>>>(
@@ -180,11 +180,11 @@
tester.flush();
}
-TEST_CASE( "CUDA integer divider", "[cuda]" ) {
+CATCH_TEST_CASE( "CUDA integer divider", "[cuda]" ) {
testUint64Divider();
testUint32Divider();
cudaError_t err = cudaDeviceSynchronize();
- REQUIRE(err == cudaSuccess);
+ CATCH_REQUIRE(err == cudaSuccess);
}
diff --git a/aten/src/ATen/test/native_test.cpp b/aten/src/ATen/test/native_test.cpp
index e10de30..4c57b7d 100644
--- a/aten/src/ATen/test/native_test.cpp
+++ b/aten/src/ATen/test/native_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "test_seed.h"
@@ -9,18 +9,18 @@
using Catch::Matchers::StartsWith;
#define REQUIRE_EQUAL(t1, t2) \
- REQUIRE(t1.equal(t2));
+ CATCH_REQUIRE(t1.equal(t2));
#define REQUIRE_ALLCLOSE(t1, t2) \
- REQUIRE(t1.is_same_size(t2)); \
- REQUIRE(t1.allclose(t2));
+ CATCH_REQUIRE(t1.is_same_size(t2)); \
+ CATCH_REQUIRE(t1.allclose(t2));
#define REQUIRE_ALLCLOSE_TOLERANCES(t1, t2, atol, rtol) \
- REQUIRE(t1.is_same_size(t2)); \
- REQUIRE(t1.allclose(t2, atol, rtol));
+ CATCH_REQUIRE(t1.is_same_size(t2)); \
+ CATCH_REQUIRE(t1.allclose(t2, atol, rtol));
void requireEqualTensorList(TensorList t1, TensorList t2) {
- REQUIRE(t1.size() == t2.size());
+ CATCH_REQUIRE(t1.size() == t2.size());
for (size_t i = 0; i < t1.size(); ++i) {
REQUIRE_EQUAL(t1[ i ], t2[ i ]);
}
@@ -29,7 +29,7 @@
void test(Type & T, Type & AccT) {
auto t = randn({3, 3}, T);
- SECTION( "split: test method, type, namespace give same result" ) {
+ CATCH_SECTION( "split: test method, type, namespace give same result" ) {
auto splitMethod = t.split(1, 0);
auto splitType = T.split(t, 1, 0);
auto splitNs = at::split(t, 1, 0);
@@ -40,7 +40,7 @@
REQUIRE_EQUAL(at::cat(splitMethod, 0), t);
}
- SECTION( "chunk: test method, type, namespace give same result" ) {
+ CATCH_SECTION( "chunk: test method, type, namespace give same result" ) {
// test method, type, namespace give same result
auto chunkMethod = t.chunk(3, 0);
auto chunkType = T.chunk(t, 3, 0);
@@ -53,7 +53,7 @@
}
// stack
- SECTION( "stack" ) {
+ CATCH_SECTION( "stack" ) {
auto x = rand({2, 3, 4});
auto y = rand({2, 3, 4});
auto z = rand({2, 3, 4});
@@ -66,36 +66,36 @@
expected_size.insert(expected_size.end(), x.sizes().begin() + dim, x.sizes().end());
REQUIRE_EQUAL(res, res_neg);
- REQUIRE(res.sizes().equals(expected_size));
+ CATCH_REQUIRE(res.sizes().equals(expected_size));
REQUIRE_EQUAL(res.select(dim, 0), x);
REQUIRE_EQUAL(res.select(dim, 1), y);
REQUIRE_EQUAL(res.select(dim, 2), z);
}
}
- SECTION( "size / stride" ) {
+ CATCH_SECTION( "size / stride" ) {
auto scalar = randn({}, T);
- REQUIRE_THROWS_WITH(scalar.size(0), StartsWith("dimension specified as 0 but tensor has no dimensions"));
- REQUIRE_THROWS_WITH(scalar.size(-1), StartsWith("dimension specified as -1 but tensor has no dimensions"));
- REQUIRE_THROWS_WITH(scalar.stride(0), StartsWith("dimension specified as 0 but tensor has no dimensions"));
- REQUIRE_THROWS_WITH(scalar.stride(-1), StartsWith("dimension specified as -1 but tensor has no dimensions"));
+ CATCH_REQUIRE_THROWS_WITH(scalar.size(0), StartsWith("dimension specified as 0 but tensor has no dimensions"));
+ CATCH_REQUIRE_THROWS_WITH(scalar.size(-1), StartsWith("dimension specified as -1 but tensor has no dimensions"));
+ CATCH_REQUIRE_THROWS_WITH(scalar.stride(0), StartsWith("dimension specified as 0 but tensor has no dimensions"));
+ CATCH_REQUIRE_THROWS_WITH(scalar.stride(-1), StartsWith("dimension specified as -1 but tensor has no dimensions"));
auto empty = randn({0}, T);
- REQUIRE(empty.size(0) == 0);
- REQUIRE(empty.size(-1) == 0);
- REQUIRE(empty.stride(0) == 1);
- REQUIRE(empty.stride(-1) == 1);
+ CATCH_REQUIRE(empty.size(0) == 0);
+ CATCH_REQUIRE(empty.size(-1) == 0);
+ CATCH_REQUIRE(empty.stride(0) == 1);
+ CATCH_REQUIRE(empty.stride(-1) == 1);
}
// matmul
- SECTION( "matmul" ) {
+ CATCH_SECTION( "matmul" ) {
auto scalar = randn({}, T);
auto d1 = randn({3}, T);
auto d2 = randn({2, 3}, T);
// 0-d
- REQUIRE_THROWS_WITH(scalar.matmul(d2), Catch::StartsWith("both arguments to matmul need to be at least 1D"));
- REQUIRE_THROWS_WITH(d2.matmul(scalar), Catch::StartsWith("both arguments to matmul need to be at least 1D"));
+ CATCH_REQUIRE_THROWS_WITH(scalar.matmul(d2), Catch::StartsWith("both arguments to matmul need to be at least 1D"));
+ CATCH_REQUIRE_THROWS_WITH(d2.matmul(scalar), Catch::StartsWith("both arguments to matmul need to be at least 1D"));
// 1-d
REQUIRE_ALLCLOSE(d1.matmul(d1), d1.dot(d1));
@@ -140,11 +140,11 @@
// non-expandable case
auto d5wrong = randn({2, 4, 2, 4, 3, 2}, T);
- REQUIRE_THROWS_WITH(d5.matmul(d5wrong), Catch::Contains("must match the size"));
+ CATCH_REQUIRE_THROWS_WITH(d5.matmul(d5wrong), Catch::Contains("must match the size"));
}
// _standard_gamma_grad
- SECTION( "_standard_gamma_grad" ) {
+ CATCH_SECTION( "_standard_gamma_grad" ) {
// check empty
auto empty = ones({0}, T);
REQUIRE_EQUAL(empty, at::_standard_gamma_grad(empty, empty));
@@ -158,10 +158,10 @@
// check mixing types
auto t1 = randn({3, 4}, T);
auto t2 = randn({3, 4}, T).toType(kDouble);
- REQUIRE_THROWS_WITH(at::_standard_gamma_grad(t1, t2), Catch::StartsWith("expected scalar type"));
+ CATCH_REQUIRE_THROWS_WITH(at::_standard_gamma_grad(t1, t2), Catch::StartsWith("expected scalar type"));
}
- SECTION( "where" ) {
+ CATCH_SECTION( "where" ) {
// empty
auto empty = ones({0}, T);
auto &bT = T.toScalarType(ScalarType::Byte);
@@ -180,13 +180,13 @@
}
}
-TEST_CASE( "native test CPU", "[cpu]" ) {
+CATCH_TEST_CASE( "native test CPU", "[cpu]" ) {
manual_seed(123, at::kCPU);
test(CPU(kFloat), CPU(kDouble));
}
-TEST_CASE( "native test CUDA", "[cuda]" ) {
+CATCH_TEST_CASE( "native test CUDA", "[cuda]" ) {
manual_seed(123, at::kCUDA);
if (at::hasCUDA()) {
diff --git a/aten/src/ATen/test/scalar_tensor_test.cpp b/aten/src/ATen/test/scalar_tensor_test.cpp
index d52dc27..964f626 100644
--- a/aten/src/ATen/test/scalar_tensor_test.cpp
+++ b/aten/src/ATen/test/scalar_tensor_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "test_seed.h"
@@ -18,14 +18,14 @@
_passed = true; \
els; \
} catch (std::exception &e) { \
- REQUIRE(!_passed); \
+ CATCH_REQUIRE(!_passed); \
catc; \
} \
}
void require_equal_size_dim(const Tensor &lhs, const Tensor &rhs) {
- REQUIRE(lhs.dim() == rhs.dim());
- REQUIRE(lhs.sizes().equals(rhs.sizes()));
+ CATCH_REQUIRE(lhs.dim() == rhs.dim());
+ CATCH_REQUIRE(lhs.sizes().equals(rhs.sizes()));
}
bool should_expand(const IntList &from_size, const IntList &to_size) {
@@ -49,15 +49,15 @@
for (auto s = sizes.begin(); s != sizes.end(); ++s) {
// verify that the dim, sizes, strides, etc match what was requested.
auto t = ones(*s, T);
- REQUIRE((size_t)t.dim() == s->size());
- REQUIRE((size_t)t.ndimension() == s->size());
- REQUIRE(t.sizes().equals(*s));
- REQUIRE(t.strides().size() == s->size());
+ CATCH_REQUIRE((size_t)t.dim() == s->size());
+ CATCH_REQUIRE((size_t)t.ndimension() == s->size());
+ CATCH_REQUIRE(t.sizes().equals(*s));
+ CATCH_REQUIRE(t.strides().size() == s->size());
auto numel = std::accumulate(s->begin(), s->end(), 1, std::multiplies<int64_t>());
- REQUIRE(t.numel() == numel);
+ CATCH_REQUIRE(t.numel() == numel);
// verify we can output
std::stringstream ss;
- REQUIRE_NOTHROW(ss << t << std::endl);
+ CATCH_REQUIRE_NOTHROW(ss << t << std::endl);
// set_
auto t2 = ones(*s, T);
@@ -65,22 +65,22 @@
require_equal_size_dim(t2, ones({0}, T));
// unsqueeze
- REQUIRE(t.unsqueeze(0).dim() == t.dim() + 1);
+ CATCH_REQUIRE(t.unsqueeze(0).dim() == t.dim() + 1);
// unsqueeze_
{
auto t2 = ones(*s, T);
auto r = t2.unsqueeze_(0);
- REQUIRE(r.dim() == t.dim() + 1);
+ CATCH_REQUIRE(r.dim() == t.dim() + 1);
}
// squeeze (with dimension argument)
if (t.dim() == 0 || t.sizes()[0] == 1) {
- REQUIRE(t.squeeze(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
+ CATCH_REQUIRE(t.squeeze(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
} else {
// In PyTorch, it is a no-op to try to squeeze a dimension that has size != 1;
// in NumPy this is an error.
- REQUIRE(t.squeeze(0).dim() == t.dim());
+ CATCH_REQUIRE(t.squeeze(0).dim() == t.dim());
}
// squeeze (with no dimension argument)
@@ -99,11 +99,11 @@
// squeeze_ (with dimension argument)
auto t2 = ones(*s, T);
if (t2.dim() == 0 || t2.sizes()[0] == 1) {
- REQUIRE(t2.squeeze_(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
+ CATCH_REQUIRE(t2.squeeze_(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
} else {
// In PyTorch, it is a no-op to try to squeeze a dimension that has size != 1;
// in NumPy this is an error.
- REQUIRE(t2.squeeze_(0).dim() == t.dim());
+ CATCH_REQUIRE(t2.squeeze_(0).dim() == t.dim());
}
}
@@ -122,31 +122,31 @@
// reduce (with dimension argument and with 1 return argument)
if (t.numel() != 0) {
- REQUIRE(t.sum(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
+ CATCH_REQUIRE(t.sum(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
} else {
- REQUIRE(t.sum(0).equal(at::zeros({}, T)));
+ CATCH_REQUIRE(t.sum(0).equal(at::zeros({}, T)));
}
// reduce (with dimension argument and with 2 return arguments)
if (t.numel() != 0) {
auto ret = t.min(0);
- REQUIRE(std::get<0>(ret).dim() == std::max<int64_t>(t.dim() - 1, 0));
- REQUIRE(std::get<1>(ret).dim() == std::max<int64_t>(t.dim() - 1, 0));
+ CATCH_REQUIRE(std::get<0>(ret).dim() == std::max<int64_t>(t.dim() - 1, 0));
+ CATCH_REQUIRE(std::get<1>(ret).dim() == std::max<int64_t>(t.dim() - 1, 0));
} else {
- REQUIRE_THROWS(t.min(0));
+ _CATCH_REQUIRE_THROWS(t.min(0));
}
// simple indexing
if (t.dim() > 0 && t.numel() != 0) {
- REQUIRE(t[0].dim() == std::max<int64_t>(t.dim() - 1, 0));
+ CATCH_REQUIRE(t[0].dim() == std::max<int64_t>(t.dim() - 1, 0));
} else {
- REQUIRE_THROWS(t[0]);
+ _CATCH_REQUIRE_THROWS(t[0]);
}
// fill_ (argument to fill_ can only be a 0-dim tensor)
TRY_CATCH_ELSE(t.fill_(t.sum(0)),
- REQUIRE(t.dim() > 1),
- REQUIRE(t.dim() <= 1));
+ CATCH_REQUIRE(t.dim() > 1),
+ CATCH_REQUIRE(t.dim() <= 1));
}
for (auto lhs_it = sizes.begin(); lhs_it != sizes.end(); ++lhs_it) {
@@ -156,8 +156,8 @@
auto lhs = ones(*lhs_it, T);
auto rhs = ones(*rhs_it, T);
if(*lhs_it != *rhs_it) {
- REQUIRE(!lhs.is_same_size(rhs));
- REQUIRE(!rhs.is_same_size(lhs));
+ CATCH_REQUIRE(!lhs.is_same_size(rhs));
+ CATCH_REQUIRE(!rhs.is_same_size(lhs));
}
}
// forced size functions (resize_, resize_as, set_)
@@ -192,7 +192,7 @@
auto storage = T.storage(rhs.numel(), false);
lhs.set_(storage);
// should not be dim 0 because an empty storage is dim 1; all other storages aren't scalars
- REQUIRE(lhs.dim() != 0);
+ CATCH_REQUIRE(lhs.dim() != 0);
}
{
// with storage, offset, sizes, strides
@@ -211,8 +211,8 @@
auto rhs = ones(*rhs_it, T);
auto rhs_size = *rhs_it;
TRY_CATCH_ELSE(auto result = lhs.view(rhs_size),
- REQUIRE(lhs.numel() != rhs.numel()),
- REQUIRE(lhs.numel() == rhs.numel()); require_equal_size_dim(result, rhs););
+ CATCH_REQUIRE(lhs.numel() != rhs.numel()),
+ CATCH_REQUIRE(lhs.numel() == rhs.numel()); require_equal_size_dim(result, rhs););
}
// take
@@ -220,7 +220,7 @@
auto lhs = ones(*lhs_it, T);
auto rhs = zeros(*rhs_it, T).toType(ScalarType::Long);
TRY_CATCH_ELSE(auto result = lhs.take(rhs),
- REQUIRE(lhs.numel() == 0); REQUIRE(rhs.numel() != 0),
+ CATCH_REQUIRE(lhs.numel() == 0); CATCH_REQUIRE(rhs.numel() != 0),
require_equal_size_dim(result, rhs));
}
@@ -230,7 +230,7 @@
auto lhs = ones(*lhs_it, T);
auto rhs = ones(*rhs_it, T);
TRY_CATCH_ELSE(auto result = lhs.ger(rhs),
- REQUIRE((lhs.numel() == 0 || rhs.numel() == 0 || lhs.dim() != 1 || rhs.dim() != 1)),
+ CATCH_REQUIRE((lhs.numel() == 0 || rhs.numel() == 0 || lhs.dim() != 1 || rhs.dim() != 1)),
[&]() {
int64_t dim0 = lhs.dim() == 0 ? 1 : lhs.size(0);
int64_t dim1 = rhs.dim() == 0 ? 1 : rhs.size(0);
@@ -246,8 +246,8 @@
auto rhs_size = *rhs_it;
bool should_pass = should_expand(lhs_size, rhs_size);
TRY_CATCH_ELSE(auto result = lhs.expand(rhs_size),
- REQUIRE(!should_pass),
- REQUIRE(should_pass); require_equal_size_dim(result, rhs););
+ CATCH_REQUIRE(!should_pass),
+ CATCH_REQUIRE(should_pass); require_equal_size_dim(result, rhs););
// in-place functions (would be good if we can also do a non-broadcasting one, b/c
// broadcasting functions will always end up operating on tensors of same size;
@@ -255,21 +255,21 @@
{
bool should_pass_inplace = should_expand(rhs_size, lhs_size);
TRY_CATCH_ELSE(lhs.add_(rhs),
- REQUIRE(!should_pass_inplace),
- REQUIRE(should_pass_inplace); require_equal_size_dim(lhs, ones(*lhs_it, T)););
+ CATCH_REQUIRE(!should_pass_inplace),
+ CATCH_REQUIRE(should_pass_inplace); require_equal_size_dim(lhs, ones(*lhs_it, T)););
}
}
}
}
}
-TEST_CASE( "scalar tensor test CPU", "[cpu]" ) {
+CATCH_TEST_CASE( "scalar tensor test CPU", "[cpu]" ) {
manual_seed(123, at::kCPU);
test(CPU(kFloat));
}
-TEST_CASE( "scalar tensor test CUDA", "[cuda]" ) {
+CATCH_TEST_CASE( "scalar tensor test CUDA", "[cuda]" ) {
manual_seed(123, at::kCUDA);
if (at::hasCUDA()) {
diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp
index 72ef4e4..247830c 100644
--- a/aten/src/ATen/test/scalar_test.cpp
+++ b/aten/src/ATen/test/scalar_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include <iostream>
// define constants like M_PI and C keywords for MSVC
@@ -33,25 +33,25 @@
void test_overflow() {
auto s1 = Scalar(M_PI);
- REQUIRE(s1.toFloat() == static_cast<float>(M_PI));
+ CATCH_REQUIRE(s1.toFloat() == static_cast<float>(M_PI));
s1.toHalf();
s1 = Scalar(100000);
- REQUIRE(s1.toFloat() == 100000.0);
- REQUIRE(s1.toInt() == 100000);
+ CATCH_REQUIRE(s1.toFloat() == 100000.0);
+ CATCH_REQUIRE(s1.toInt() == 100000);
- REQUIRE_THROWS_AS(s1.toHalf(), std::domain_error);
+ CATCH_REQUIRE_THROWS_AS(s1.toHalf(), std::domain_error);
s1 = Scalar(NAN);
- REQUIRE(std::isnan(s1.toFloat()));
- REQUIRE_THROWS_AS(s1.toInt(), std::domain_error);
+ CATCH_REQUIRE(std::isnan(s1.toFloat()));
+ CATCH_REQUIRE_THROWS_AS(s1.toInt(), std::domain_error);
s1 = Scalar(INFINITY);
- REQUIRE(std::isinf(s1.toFloat()));
- REQUIRE_THROWS_AS(s1.toInt(), std::domain_error);
+ CATCH_REQUIRE(std::isinf(s1.toFloat()));
+ CATCH_REQUIRE_THROWS_AS(s1.toInt(), std::domain_error);
}
-TEST_CASE( "scalar test", "[]" ) {
+CATCH_TEST_CASE( "scalar test", "[]" ) {
manual_seed(123, at::kCPU);
manual_seed(123, at::kCUDA);
@@ -62,7 +62,7 @@
Scalar h2 = h;
cout << "H2: " << h2.toDouble() << " " << what.toFloat() << " " << bar.toDouble() << " " << what.isIntegral() << "\n";
Generator & gen = at::globalContext().defaultGenerator(at::kCPU);
- REQUIRE_NOTHROW(gen.seed());
+ CATCH_REQUIRE_NOTHROW(gen.seed());
auto && C = at::globalContext();
if(at::hasCUDA()) {
auto t2 = zeros({4,4}, at::kCUDA);
@@ -71,12 +71,12 @@
auto t = ones({4,4});
auto wha2 = zeros({4,4}).add(t).sum();
- REQUIRE( wha2.toCDouble() == 16.0 );
+ CATCH_REQUIRE( wha2.toCDouble() == 16.0 );
- REQUIRE( t.sizes()[0] == 4 );
- REQUIRE( t.sizes()[1] == 4 );
- REQUIRE( t.strides()[0] == 4 );
- REQUIRE( t.strides()[1] == 1 );
+ CATCH_REQUIRE( t.sizes()[0] == 4 );
+ CATCH_REQUIRE( t.sizes()[1] == 4 );
+ CATCH_REQUIRE( t.strides()[0] == 4 );
+ CATCH_REQUIRE( t.strides()[1] == 1 );
Type & T = CPU(Float);
Tensor x = randn({1,10}, T);
@@ -88,26 +88,26 @@
Tensor next_h = i2h.add(h2h);
next_h = next_h.tanh();
- REQUIRE_THROWS(at::_local_scalar(Tensor{}));
+ _CATCH_REQUIRE_THROWS(at::_local_scalar(Tensor{}));
test_overflow();
if(at::hasCUDA()) {
auto r = CUDA(Float).copy(next_h);
- REQUIRE(CPU(Float).copy(r).equal(next_h));
+ CATCH_REQUIRE(CPU(Float).copy(r).equal(next_h));
}
- REQUIRE_NOTHROW(randn({10,10,2}, T));
+ CATCH_REQUIRE_NOTHROW(randn({10,10,2}, T));
// check Scalar.toTensor on Scalars backed by different data types
- REQUIRE(scalar_to_tensor(bar).type().scalarType() == kDouble);
- REQUIRE(scalar_to_tensor(what).type().scalarType() == kLong);
- REQUIRE(scalar_to_tensor(ones({})._local_scalar()).type().scalarType() == kDouble);
+ CATCH_REQUIRE(scalar_to_tensor(bar).type().scalarType() == kDouble);
+ CATCH_REQUIRE(scalar_to_tensor(what).type().scalarType() == kLong);
+ CATCH_REQUIRE(scalar_to_tensor(ones({})._local_scalar()).type().scalarType() == kDouble);
if (x.type().scalarType() != ScalarType::Half) {
AT_DISPATCH_ALL_TYPES(x.type(), "foo", [&] {
scalar_t s = 1;
std::stringstream ss;
- REQUIRE_NOTHROW(ss << "hello, dispatch" << x.type().toString() << s << "\n");
+ CATCH_REQUIRE_NOTHROW(ss << "hello, dispatch" << x.type().toString() << s << "\n");
auto data = (scalar_t*)x.data_ptr();
(void)data;
});
@@ -116,10 +116,10 @@
// test direct C-scalar type conversions
{
auto x = ones({1,2}, T);
- REQUIRE_THROWS(x.toCFloat());
+ _CATCH_REQUIRE_THROWS(x.toCFloat());
}
auto float_one = ones({}, T);
- REQUIRE(float_one.toCFloat() == 1);
- REQUIRE(float_one.toCInt() == 1);
- REQUIRE((float_one.toCHalf() == 1));
+ CATCH_REQUIRE(float_one.toCFloat() == 1);
+ CATCH_REQUIRE(float_one.toCInt() == 1);
+ CATCH_REQUIRE((float_one.toCHalf() == 1));
}
diff --git a/aten/src/ATen/test/stream_test.cpp b/aten/src/ATen/test/stream_test.cpp
index 145c4f4..8dc015d 100644
--- a/aten/src/ATen/test/stream_test.cpp
+++ b/aten/src/ATen/test/stream_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/CUDAGuard.h"
@@ -14,7 +14,7 @@
/*
Tests related to ATen streams.
*/
-TEST_CASE(
+CATCH_TEST_CASE(
"Copying and Moving Streams",
"Verifies streams are live through copying and moving") {
int32_t device = -1;
@@ -29,14 +29,14 @@
copyStream = s;
- REQUIRE(copyStream.internals() == s.internals());
- REQUIRE(copyStream.device() == device);
- REQUIRE(copyStream.stream() == cuda_stream);
+ CATCH_REQUIRE(copyStream.internals() == s.internals());
+ CATCH_REQUIRE(copyStream.device() == device);
+ CATCH_REQUIRE(copyStream.stream() == cuda_stream);
}
- REQUIRE(copyStream.internals());
- REQUIRE(copyStream.device() == device);
- REQUIRE(copyStream.stream() == cuda_stream);
+ CATCH_REQUIRE(copyStream.internals());
+ CATCH_REQUIRE(copyStream.device() == device);
+ CATCH_REQUIRE(copyStream.stream() == cuda_stream);
// Tests that moving works as expected and preserves the stream
at::cuda::CUDAStream moveStream;
@@ -47,41 +47,41 @@
moveStream = std::move(s);
- REQUIRE(moveStream.device() == device);
- REQUIRE(moveStream.stream() == cuda_stream);
+ CATCH_REQUIRE(moveStream.device() == device);
+ CATCH_REQUIRE(moveStream.stream() == cuda_stream);
}
- REQUIRE(moveStream.internals());
- REQUIRE(moveStream.device() == device);
- REQUIRE(moveStream.stream() == cuda_stream);
+ CATCH_REQUIRE(moveStream.internals());
+ CATCH_REQUIRE(moveStream.device() == device);
+ CATCH_REQUIRE(moveStream.stream() == cuda_stream);
}
-TEST_CASE("Getting and Setting Streams", "Verifies streams are set properly") {
+CATCH_TEST_CASE("Getting and Setting Streams", "Verifies streams are set properly") {
at::cuda::CUDAStream myStream = at::cuda::createCUDAStream();
// Sets and gets
at::cuda::setCurrentCUDAStream(myStream);
at::cuda::CUDAStream curStream = at::cuda::getCurrentCUDAStream();
- REQUIRE(myStream == curStream);
+ CATCH_REQUIRE(myStream == curStream);
// Gets, sets, and gets default stream
at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
at::cuda::setCurrentCUDAStream(defaultStream);
curStream = at::cuda::getCurrentCUDAStream();
- REQUIRE(defaultStream != myStream);
- REQUIRE(curStream == defaultStream);
+ CATCH_REQUIRE(defaultStream != myStream);
+ CATCH_REQUIRE(curStream == defaultStream);
}
void thread_fun(at::cuda::CUDAStream& cur_thread_stream) {
auto new_stream = at::cuda::createCUDAStream();
at::cuda::setCurrentCUDAStream(new_stream);
cur_thread_stream = at::cuda::getCurrentCUDAStream();
- REQUIRE(cur_thread_stream == new_stream);
+ CATCH_REQUIRE(cur_thread_stream == new_stream);
}
-TEST_CASE(
+CATCH_TEST_CASE(
"Multithread Getting and Setting",
"Ensures streams are thread local") {
at::cuda::CUDAStream s0, s1;
@@ -94,25 +94,25 @@
at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();
at::cuda::CUDAStream default_stream = at::cuda::getDefaultCUDAStream();
- REQUIRE(cur_stream == default_stream);
- REQUIRE(cur_stream != s0);
- REQUIRE(cur_stream != s1);
- REQUIRE(s0 != s1);
+ CATCH_REQUIRE(cur_stream == default_stream);
+ CATCH_REQUIRE(cur_stream != s0);
+ CATCH_REQUIRE(cur_stream != s1);
+ CATCH_REQUIRE(s0 != s1);
}
-TEST_CASE("CUDAGuard") {
+CATCH_TEST_CASE("CUDAGuard") {
if (at::cuda::getNumGPUs() < 2) {
return;
}
// -- begin setup
- REQUIRE(at::cuda::current_device() == 0);
+ CATCH_REQUIRE(at::cuda::current_device() == 0);
std::vector<at::cuda::CUDAStream> streams0 = {
at::cuda::getDefaultCUDAStream(),
at::cuda::createCUDAStream()};
- REQUIRE(streams0[0].device() == 0);
- REQUIRE(streams0[1].device() == 0);
+ CATCH_REQUIRE(streams0[0].device() == 0);
+ CATCH_REQUIRE(streams0[1].device() == 0);
at::cuda::setCurrentCUDAStream(streams0[0]);
std::vector<at::cuda::CUDAStream> streams1;
@@ -121,47 +121,47 @@
streams1.push_back(at::cuda::getDefaultCUDAStream());
streams1.push_back(at::cuda::createCUDAStream());
}
- REQUIRE(streams1[0].device() == 1);
- REQUIRE(streams1[1].device() == 1);
+ CATCH_REQUIRE(streams1[0].device() == 1);
+ CATCH_REQUIRE(streams1[1].device() == 1);
at::cuda::setCurrentCUDAStream(streams1[0]);
- REQUIRE(at::cuda::current_device() == 0);
+ CATCH_REQUIRE(at::cuda::current_device() == 0);
// -- end setup
// Test that all original streams are recorded.
{
at::cuda::CUDAGuard guard;
- REQUIRE(guard.original_streams().empty());
+ CATCH_REQUIRE(guard.original_streams().empty());
guard.set_stream(streams0[0]);
- REQUIRE(
+ CATCH_REQUIRE(
guard.original_streams().size() == at::cuda::getNumGPUs());
- REQUIRE(guard.original_streams()[0] == streams0[0]);
- REQUIRE(guard.original_streams()[1] == streams1[0]);
+ CATCH_REQUIRE(guard.original_streams()[0] == streams0[0]);
+ CATCH_REQUIRE(guard.original_streams()[1] == streams1[0]);
}
// Setting a stream changes the current device and the stream on that device
{
at::cuda::CUDAGuard guard(streams1[1]);
- REQUIRE(guard.last_device() == 1);
- REQUIRE(at::cuda::current_device() == 1);
- REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[1]);
+ CATCH_REQUIRE(guard.last_device() == 1);
+ CATCH_REQUIRE(at::cuda::current_device() == 1);
+ CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[1]);
}
// Device and stream are now reset
- REQUIRE(at::cuda::current_device() == 0);
- REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
+ CATCH_REQUIRE(at::cuda::current_device() == 0);
+ CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
// Setting only the device changes only the current device and not the stream
{
at::cuda::CUDAGuard guard(/*device=*/1);
- REQUIRE(guard.last_device() == 1);
- REQUIRE(at::cuda::current_device() == 1);
- REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
+ CATCH_REQUIRE(guard.last_device() == 1);
+ CATCH_REQUIRE(at::cuda::current_device() == 1);
+ CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
}
- REQUIRE(at::cuda::current_device() == 0);
- REQUIRE(at::cuda::getCurrentCUDAStream(0) == streams0[0]);
+ CATCH_REQUIRE(at::cuda::current_device() == 0);
+ CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(0) == streams0[0]);
// Setting the stream first, and then the device, first changes the devices
// back, and then resets the stream on the initial device.
@@ -171,12 +171,12 @@
guard.set_device(1);
}
- REQUIRE(at::cuda::current_device() == 0);
- REQUIRE(at::cuda::getCurrentCUDAStream(0) == streams0[0]);
- REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
+ CATCH_REQUIRE(at::cuda::current_device() == 0);
+ CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(0) == streams0[0]);
+ CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
}
-TEST_CASE("CUDAGuardIsMovable") {
+CATCH_TEST_CASE("CUDAGuardIsMovable") {
if (at::cuda::getNumGPUs() < 2) {
return;
}
@@ -185,17 +185,17 @@
at::cuda::CUDAGuard first(stream);
first.set_device(1);
at::cuda::CUDAGuard second(std::move(first));
- REQUIRE(second.original_streams().size() == device_count);
- REQUIRE(second.original_device() == 0);
- REQUIRE(second.last_device() == 1);
+ CATCH_REQUIRE(second.original_streams().size() == device_count);
+ CATCH_REQUIRE(second.original_device() == 0);
+ CATCH_REQUIRE(second.last_device() == 1);
at::cuda::CUDAGuard third;
third = std::move(second);
- REQUIRE(third.original_streams().size() == device_count);
- REQUIRE(third.original_device() == 0);
- REQUIRE(third.last_device() == 1);
+ CATCH_REQUIRE(third.original_streams().size() == device_count);
+ CATCH_REQUIRE(third.original_device() == 0);
+ CATCH_REQUIRE(third.last_device() == 1);
}
-TEST_CASE("Streampool Round Robin") {
+CATCH_TEST_CASE("Streampool Round Robin") {
std::vector<at::cuda::CUDAStream> streams{};
for (int i = 0; i < 200; ++i) {
streams.emplace_back(at::cuda::detail::CUDAStream_createStream());
@@ -209,10 +209,10 @@
if (!result_pair.second) hasDuplicates = true;
}
- REQUIRE(hasDuplicates);
+ CATCH_REQUIRE(hasDuplicates);
}
-TEST_CASE("Multi-GPU") {
+CATCH_TEST_CASE("Multi-GPU") {
if (at::cuda::getNumGPUs() < 2) return;
at::cuda::CUDAStream s0 = at::cuda::createCUDAStream(true, 0);
@@ -221,17 +221,17 @@
at::cuda::setCurrentCUDAStream(s0);
at::cuda::setCurrentCUDAStream(s1);
- REQUIRE(s0 == at::cuda::getCurrentCUDAStream());
+ CATCH_REQUIRE(s0 == at::cuda::getCurrentCUDAStream());
at::DeviceGuard device_guard{1};
- REQUIRE(s1 == at::cuda::getCurrentCUDAStream());
+ CATCH_REQUIRE(s1 == at::cuda::getCurrentCUDAStream());
}
-TEST_CASE("CUDAEvent Syncs") {
+CATCH_TEST_CASE("CUDAEvent Syncs") {
const auto stream = at::cuda::createCUDAStream();
at::cuda::CUDAEvent event;
- REQUIRE(!event.happened());
+ CATCH_REQUIRE(!event.happened());
event.recordOnce(stream);
@@ -242,10 +242,10 @@
wait_stream1.synchronize_with(event);
cudaStreamSynchronize(wait_stream0);
- REQUIRE(event.happened());
+ CATCH_REQUIRE(event.happened());
}
-TEST_CASE("Cross-Device Events") {
+CATCH_TEST_CASE("Cross-Device Events") {
if (at::cuda::getNumGPUs() < 2) return;
const auto stream0 = at::cuda::createCUDAStream();
@@ -260,10 +260,10 @@
event0 = std::move(event1);
- REQUIRE(event0.device() == 1);
+ CATCH_REQUIRE(event0.device() == 1);
stream0.synchronize_with(event0);
cudaStreamSynchronize(stream0);
- REQUIRE(event0.happened());
+ CATCH_REQUIRE(event0.happened());
}
diff --git a/aten/src/ATen/test/test_parallel.cpp b/aten/src/ATen/test/test_parallel.cpp
index 5523280..8170173 100644
--- a/aten/src/ATen/test/test_parallel.cpp
+++ b/aten/src/ATen/test/test_parallel.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "ATen/DLConvertor.h"
@@ -11,7 +11,7 @@
using namespace at;
-TEST_CASE( "parallel", "[cpu]" ) {
+CATCH_TEST_CASE( "parallel", "[cpu]" ) {
manual_seed(123, at::kCPU);
set_num_threads(1);
@@ -24,5 +24,5 @@
as[0] = 1;
as[1] = 0;
as[2] = 0;
- REQUIRE(a.sum(0).equal(as));
+ CATCH_REQUIRE(a.sum(0).equal(as));
}
diff --git a/aten/src/ATen/test/undefined_tensor_test.cpp b/aten/src/ATen/test/undefined_tensor_test.cpp
index e47772a..c01dff2 100644
--- a/aten/src/ATen/test/undefined_tensor_test.cpp
+++ b/aten/src/ATen/test/undefined_tensor_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "ATen/core/UndefinedTensorImpl.h"
@@ -8,7 +8,7 @@
using namespace at;
-TEST_CASE( "undefined tensor test", "[]" ) {
+CATCH_TEST_CASE( "undefined tensor test", "[]" ) {
manual_seed(123, at::kCPU);
// mainly test ops on undefined tensors don't segfault and give a reasonable error message.
@@ -17,36 +17,36 @@
std::stringstream ss;
ss << und << std::endl;
- REQUIRE(!und.defined());
- REQUIRE(std::string("UndefinedType") == und.toString());
+ CATCH_REQUIRE(!und.defined());
+ CATCH_REQUIRE(std::string("UndefinedType") == und.toString());
- REQUIRE_THROWS(und.strides());
- REQUIRE_THROWS(und.dim());
- REQUIRE_THROWS([]() {return Tensor();}() = Scalar(5));
- REQUIRE_THROWS(und.add(und));
- REQUIRE_THROWS(und.add(ft));
- REQUIRE_THROWS(ft.add(und));
- REQUIRE_THROWS(und.add(5));
- REQUIRE_THROWS(und.mm(und));
+ _CATCH_REQUIRE_THROWS(und.strides());
+ _CATCH_REQUIRE_THROWS(und.dim());
+ _CATCH_REQUIRE_THROWS([]() {return Tensor();}() = Scalar(5));
+ _CATCH_REQUIRE_THROWS(und.add(und));
+ _CATCH_REQUIRE_THROWS(und.add(ft));
+ _CATCH_REQUIRE_THROWS(ft.add(und));
+ _CATCH_REQUIRE_THROWS(und.add(5));
+ _CATCH_REQUIRE_THROWS(und.mm(und));
und.toType(und.type());
- REQUIRE_THROWS(und.toType(ft.type()));
- REQUIRE_THROWS(ft.toType(und.type()));
+ _CATCH_REQUIRE_THROWS(und.toType(ft.type()));
+ _CATCH_REQUIRE_THROWS(ft.toType(und.type()));
und.toType(ScalarType::Undefined);
- REQUIRE_THROWS(und.toType(ScalarType::Float));
- REQUIRE_THROWS(ft.toType(ScalarType::Undefined));
+ _CATCH_REQUIRE_THROWS(und.toType(ScalarType::Float));
+ _CATCH_REQUIRE_THROWS(ft.toType(ScalarType::Undefined));
// copy_
- REQUIRE_THROWS(und.copy_(und));
- REQUIRE_THROWS(und.copy_(ft));
- REQUIRE_THROWS(ft.copy_(und));
+ _CATCH_REQUIRE_THROWS(und.copy_(und));
+ _CATCH_REQUIRE_THROWS(und.copy_(ft));
+ _CATCH_REQUIRE_THROWS(ft.copy_(und));
und.toBackend(Backend::Undefined);
- REQUIRE_THROWS(und.toBackend(Backend::CPU));
- REQUIRE_THROWS(ft.toBackend(Backend::Undefined));
+ _CATCH_REQUIRE_THROWS(und.toBackend(Backend::CPU));
+ _CATCH_REQUIRE_THROWS(ft.toBackend(Backend::Undefined));
Tensor to_move = ones({1}, CPU(kFloat));
Tensor m(std::move(to_move));
- REQUIRE(!to_move.defined());
- REQUIRE(to_move.unsafeGetTensorImpl() == UndefinedTensorImpl::singleton());
+ CATCH_REQUIRE(!to_move.defined());
+ CATCH_REQUIRE(to_move.unsafeGetTensorImpl() == UndefinedTensorImpl::singleton());
}
diff --git a/aten/src/ATen/test/weakref_test.cpp b/aten/src/ATen/test/weakref_test.cpp
index 167520b..42c9f61 100644
--- a/aten/src/ATen/test/weakref_test.cpp
+++ b/aten/src/ATen/test/weakref_test.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
@@ -10,53 +10,53 @@
using at::Tensor;
using at::WeakTensor;
-TEST_CASE( "Weak pointer tests", "" ) {
- SECTION("gets invalidated") {
+CATCH_TEST_CASE( "Weak pointer tests", "" ) {
+ CATCH_SECTION("gets invalidated") {
Tensor a = at::ones({2, 2});
WeakTensor b = a;
a.reset();
- REQUIRE_FALSE(b.lock().defined());
+ CATCH_REQUIRE_FALSE(b.lock().defined());
}
- SECTION("can successfully lock") {
+ CATCH_SECTION("can successfully lock") {
Tensor a = at::ones({2, 2});
WeakTensor b = a;
auto c = b.lock();
- REQUIRE(c.defined());
+ CATCH_REQUIRE(c.defined());
a.reset();
- REQUIRE(b.lock().defined());
+ CATCH_REQUIRE(b.lock().defined());
c.reset();
- REQUIRE_FALSE(b.lock().defined());
+ CATCH_REQUIRE_FALSE(b.lock().defined());
}
- SECTION("updates refcounts correctly") {
+ CATCH_SECTION("updates refcounts correctly") {
Tensor a = at::ones({2, 2});
- REQUIRE(a.use_count() == 1);
- REQUIRE(a.weak_use_count() == 1);
+ CATCH_REQUIRE(a.use_count() == 1);
+ CATCH_REQUIRE(a.weak_use_count() == 1);
{
WeakTensor b = a;
- REQUIRE(a.use_count() == 1);
- REQUIRE(a.weak_use_count() == 2);
+ CATCH_REQUIRE(a.use_count() == 1);
+ CATCH_REQUIRE(a.weak_use_count() == 2);
}
- REQUIRE(a.use_count() == 1);
- REQUIRE(a.weak_use_count() == 1);
+ CATCH_REQUIRE(a.use_count() == 1);
+ CATCH_REQUIRE(a.weak_use_count() == 1);
{
WeakTensor b = a;
- REQUIRE(a.use_count() == 1);
+ CATCH_REQUIRE(a.use_count() == 1);
auto locked = b.lock();
- REQUIRE(locked.defined());
- REQUIRE(a.use_count() == 2);
+ CATCH_REQUIRE(locked.defined());
+ CATCH_REQUIRE(a.use_count() == 2);
}
- REQUIRE(a.use_count() == 1);
- REQUIRE(a.weak_use_count() == 1);
+ CATCH_REQUIRE(a.use_count() == 1);
+ CATCH_REQUIRE(a.weak_use_count() == 1);
{
WeakTensor b = a;
- REQUIRE(a.use_count() == 1);
- REQUIRE(a.weak_use_count() == 2);
+ CATCH_REQUIRE(a.use_count() == 1);
+ CATCH_REQUIRE(a.weak_use_count() == 2);
a.reset();
- REQUIRE(b.use_count() == 0);
- REQUIRE(b.weak_use_count() == 1);
+ CATCH_REQUIRE(b.use_count() == 0);
+ CATCH_REQUIRE(b.weak_use_count() == 1);
}
}
}
diff --git a/aten/src/ATen/test/wrapdim_test.cpp b/aten/src/ATen/test/wrapdim_test.cpp
index 8e813bc..f76dac2 100644
--- a/aten/src/ATen/test/wrapdim_test.cpp
+++ b/aten/src/ATen/test/wrapdim_test.cpp
@@ -1,43 +1,43 @@
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include "ATen/ATen.h"
#include "test_seed.h"
using namespace at;
-TEST_CASE( "wrapdim test", "[]" ) {
+CATCH_TEST_CASE( "wrapdim test", "[]" ) {
manual_seed(123, at::kCPU);
Type & T = CPU(kFloat);
- SECTION( "simple case" ) {
+ CATCH_SECTION( "simple case" ) {
auto a = randn({2, 3, 4, 5}, T);
- REQUIRE(a.prod(-4).equal(a.prod(0)));
- REQUIRE(a.prod(3).equal(a.prod(-1)));
+ CATCH_REQUIRE(a.prod(-4).equal(a.prod(0)));
+ CATCH_REQUIRE(a.prod(3).equal(a.prod(-1)));
}
- SECTION( "expression specification" ) {
+ CATCH_SECTION( "expression specification" ) {
auto a = randn({2, 3, 4, 5}, T);
- REQUIRE(a.unsqueeze(-5).equal(a.unsqueeze(0)));
- REQUIRE(a.unsqueeze(4).equal(a.unsqueeze(-1)));
+ CATCH_REQUIRE(a.unsqueeze(-5).equal(a.unsqueeze(0)));
+ CATCH_REQUIRE(a.unsqueeze(4).equal(a.unsqueeze(-1)));
// can unsqueeze scalar
auto b = randn(1, T);
b.unsafeGetTensorImpl()->maybe_zero_dim(true);
- REQUIRE(b.unsqueeze(0).equal(b.unsqueeze(-1)));
+ CATCH_REQUIRE(b.unsqueeze(0).equal(b.unsqueeze(-1)));
}
- SECTION( "empty tensor" ) {
+ CATCH_SECTION( "empty tensor" ) {
auto a = randn(0, T);
- REQUIRE(a.prod(0).equal(at::ones({}, T)));
+ CATCH_REQUIRE(a.prod(0).equal(at::ones({}, T)));
}
- SECTION( "scalar vs 1-dim, 1-size" ) {
+ CATCH_SECTION( "scalar vs 1-dim, 1-size" ) {
auto a = randn(1, T);
- REQUIRE(a.prod(0).equal(a.prod(-1)));
+ CATCH_REQUIRE(a.prod(0).equal(a.prod(-1)));
a.unsafeGetTensorImpl()->maybe_zero_dim(true);
- REQUIRE(a.dim() == 0);
- REQUIRE(a.prod(0).equal(a.prod(-1)));
+ CATCH_REQUIRE(a.dim() == 0);
+ CATCH_REQUIRE(a.prod(0).equal(a.prod(-1)));
}
}
diff --git a/test/cpp/api/any.cpp b/test/cpp/api/any.cpp
index 9368d4d..18db2f5 100644
--- a/test/cpp/api/any.cpp
+++ b/test/cpp/api/any.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/nn/modules/any.h>
#include <torch/torch.h>
@@ -13,39 +13,39 @@
using Catch::Contains;
using Catch::StartsWith;
-TEST_CASE("any-module") {
+CATCH_TEST_CASE("any-module") {
torch::manual_seed(0);
- SECTION("int()") {
+ CATCH_SECTION("int()") {
struct M : torch::nn::Module {
int forward() {
return 123;
}
};
AnyModule any(M{});
- REQUIRE(any.forward<int>() == 123);
+ CATCH_REQUIRE(any.forward<int>() == 123);
}
- SECTION("int(int)") {
+ CATCH_SECTION("int(int)") {
struct M : torch::nn::Module {
int forward(int x) {
return x;
}
};
AnyModule any(M{});
- REQUIRE(any.forward<int>(5) == 5);
+ CATCH_REQUIRE(any.forward<int>(5) == 5);
}
- SECTION("const char*(const char*)") {
+ CATCH_SECTION("const char*(const char*)") {
struct M : torch::nn::Module {
const char* forward(const char* x) {
return x;
}
};
AnyModule any(M{});
- REQUIRE(any.forward<const char*>("hello") == std::string("hello"));
+ CATCH_REQUIRE(any.forward<const char*>("hello") == std::string("hello"));
}
- SECTION("string(int, const double)") {
+ CATCH_SECTION("string(int, const double)") {
struct M : torch::nn::Module {
std::string forward(int x, const double f) {
return std::to_string(static_cast<int>(x + f));
@@ -53,10 +53,10 @@
};
AnyModule any(M{});
int x = 4;
- REQUIRE(any.forward<std::string>(x, 3.14) == std::string("7"));
+ CATCH_REQUIRE(any.forward<std::string>(x, 3.14) == std::string("7"));
}
- SECTION("Tensor(string, const string&, string&&)") {
+ CATCH_SECTION("Tensor(string, const string&, string&&)") {
struct M : torch::nn::Module {
torch::Tensor forward(
std::string a,
@@ -67,42 +67,42 @@
}
};
AnyModule any(M{});
- REQUIRE(
+ CATCH_REQUIRE(
any.forward(
std::string("a"), std::string("ab"), std::string("abc"))
.sum()
.toCInt() == 6);
}
- SECTION("wrong argument type") {
+ CATCH_SECTION("wrong argument type") {
struct M : torch::nn::Module {
int forward(float x) {
return x;
}
};
AnyModule any(M{});
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
any.forward(5.0),
StartsWith("Expected argument #0 to be of type float, "
"but received value of type double"));
}
- SECTION("wrong number of arguments") {
+ CATCH_SECTION("wrong number of arguments") {
struct M : torch::nn::Module {
int forward(int a, int b) {
return a + b;
}
};
AnyModule any(M{});
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
any.forward(),
Contains("M's forward() method expects 2 arguments, but received 0"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
any.forward(5),
Contains("M's forward() method expects 2 arguments, but received 1"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
any.forward(1, 2, 3),
Contains("M's forward() method expects 2 arguments, but received 3"));
}
- SECTION("get()") {
+ CATCH_SECTION("get()") {
struct M : torch::nn::Module {
explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
int value;
@@ -112,16 +112,16 @@
};
AnyModule any(M{5});
- SECTION("good cast") {
- REQUIRE(any.get<M>().value == 5);
+ CATCH_SECTION("good cast") {
+ CATCH_REQUIRE(any.get<M>().value == 5);
}
- SECTION("bad cast") {
+ CATCH_SECTION("bad cast") {
struct N : torch::nn::Module {};
- REQUIRE_THROWS_WITH(any.get<N>(), StartsWith("Attempted to cast module"));
+ CATCH_REQUIRE_THROWS_WITH(any.get<N>(), StartsWith("Attempted to cast module"));
}
}
- SECTION("ptr()") {
+ CATCH_SECTION("ptr()") {
struct M : torch::nn::Module {
explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
int value;
@@ -131,24 +131,24 @@
};
AnyModule any(M{5});
- SECTION("base class cast") {
+ CATCH_SECTION("base class cast") {
auto ptr = any.ptr();
- REQUIRE(ptr != nullptr);
- REQUIRE(ptr->name() == "M");
+ CATCH_REQUIRE(ptr != nullptr);
+ CATCH_REQUIRE(ptr->name() == "M");
}
- SECTION("good downcast") {
+ CATCH_SECTION("good downcast") {
auto ptr = any.ptr<M>();
- REQUIRE(ptr != nullptr);
- REQUIRE(ptr->value == 5);
+ CATCH_REQUIRE(ptr != nullptr);
+ CATCH_REQUIRE(ptr->value == 5);
}
- SECTION("bad downcast") {
+ CATCH_SECTION("bad downcast") {
struct N : torch::nn::Module {};
- REQUIRE_THROWS_WITH(any.ptr<N>(), StartsWith("Attempted to cast module"));
+ CATCH_REQUIRE_THROWS_WITH(any.ptr<N>(), StartsWith("Attempted to cast module"));
}
}
- SECTION("default state is empty") {
+ CATCH_SECTION("default state is empty") {
struct M : torch::nn::Module {
explicit M(int value_) : value(value_) {}
int value;
@@ -157,33 +157,33 @@
}
};
AnyModule any;
- REQUIRE(any.is_empty());
+ CATCH_REQUIRE(any.is_empty());
any = std::make_shared<M>(5);
- REQUIRE(!any.is_empty());
- REQUIRE(any.get<M>().value == 5);
+ CATCH_REQUIRE(!any.is_empty());
+ CATCH_REQUIRE(any.get<M>().value == 5);
}
- SECTION("all methods throw for empty AnyModule") {
+ CATCH_SECTION("all methods throw for empty AnyModule") {
struct M : torch::nn::Module {
int forward(int x) {
return x;
}
};
AnyModule any;
- REQUIRE(any.is_empty());
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE(any.is_empty());
+ CATCH_REQUIRE_THROWS_WITH(
any.get<M>(), StartsWith("Cannot call get() on an empty AnyModule"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
any.ptr<M>(), StartsWith("Cannot call ptr() on an empty AnyModule"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
any.ptr(), StartsWith("Cannot call ptr() on an empty AnyModule"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
any.type_info(),
StartsWith("Cannot call type_info() on an empty AnyModule"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
any.forward<int>(5),
StartsWith("Cannot call forward() on an empty AnyModule"));
}
- SECTION("can move assign different modules") {
+ CATCH_SECTION("can move assign different modules") {
struct M : torch::nn::Module {
std::string forward(int x) {
return std::to_string(x);
@@ -195,15 +195,15 @@
}
};
AnyModule any;
- REQUIRE(any.is_empty());
+ CATCH_REQUIRE(any.is_empty());
any = std::make_shared<M>();
- REQUIRE(!any.is_empty());
- REQUIRE(any.forward<std::string>(5) == "5");
+ CATCH_REQUIRE(!any.is_empty());
+ CATCH_REQUIRE(any.forward<std::string>(5) == "5");
any = std::make_shared<N>();
- REQUIRE(!any.is_empty());
- REQUIRE(any.forward<int>(5.0f) == 8);
+ CATCH_REQUIRE(!any.is_empty());
+ CATCH_REQUIRE(any.forward<int>(5.0f) == 8);
}
- SECTION("constructs from ModuleHolder") {
+ CATCH_SECTION("constructs from ModuleHolder") {
struct MImpl : torch::nn::Module {
explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
int value;
@@ -218,14 +218,14 @@
};
AnyModule any(M{5});
- REQUIRE(any.get<MImpl>().value == 5);
- REQUIRE(any.get<M>()->value == 5);
+ CATCH_REQUIRE(any.get<MImpl>().value == 5);
+ CATCH_REQUIRE(any.get<M>()->value == 5);
AnyModule module(Linear(3, 4));
std::shared_ptr<Module> ptr = module.ptr();
Linear linear(module.get<Linear>());
}
- SECTION("converts autograd::Variable to torch::Tensor correctly") {
+ CATCH_SECTION("converts autograd::Variable to torch::Tensor correctly") {
struct M : torch::nn::Module {
torch::Tensor forward(torch::Tensor input) {
return input;
@@ -236,12 +236,12 @@
// torch::Tensor before being passed to the function (to avoid a type
// mismatch).
AnyModule any(M{});
- REQUIRE(
+ CATCH_REQUIRE(
any.forward(torch::autograd::Variable(torch::ones(5)))
.sum()
.toCFloat() == 5);
// at::Tensors that are not variables work too.
- REQUIRE(any.forward(at::ones(5)).sum().toCFloat() == 5);
+ CATCH_REQUIRE(any.forward(at::ones(5)).sum().toCFloat() == 5);
}
}
}
@@ -263,92 +263,92 @@
} // namespace nn
} // namespace torch
-TEST_CASE("any-value") {
+CATCH_TEST_CASE("any-value") {
torch::manual_seed(0);
- SECTION("gets the correct value for the right type") {
- SECTION("int") {
+ CATCH_SECTION("gets the correct value for the right type") {
+ CATCH_SECTION("int") {
auto value = make_value(5);
// const and non-const types have the same typeid()
- REQUIRE(value.try_get<int>() != nullptr);
- REQUIRE(value.try_get<const int>() != nullptr);
- REQUIRE(value.get<int>() == 5);
+ CATCH_REQUIRE(value.try_get<int>() != nullptr);
+ CATCH_REQUIRE(value.try_get<const int>() != nullptr);
+ CATCH_REQUIRE(value.get<int>() == 5);
}
- SECTION("const int") {
+ CATCH_SECTION("const int") {
auto value = make_value(5);
- REQUIRE(value.try_get<const int>() != nullptr);
- REQUIRE(value.try_get<int>() != nullptr);
- REQUIRE(value.get<const int>() == 5);
+ CATCH_REQUIRE(value.try_get<const int>() != nullptr);
+ CATCH_REQUIRE(value.try_get<int>() != nullptr);
+ CATCH_REQUIRE(value.get<const int>() == 5);
}
- SECTION("const char*") {
+ CATCH_SECTION("const char*") {
auto value = make_value("hello");
- REQUIRE(value.try_get<const char*>() != nullptr);
- REQUIRE(value.get<const char*>() == std::string("hello"));
+ CATCH_REQUIRE(value.try_get<const char*>() != nullptr);
+ CATCH_REQUIRE(value.get<const char*>() == std::string("hello"));
}
- SECTION("std::string") {
+ CATCH_SECTION("std::string") {
auto value = make_value(std::string("hello"));
- REQUIRE(value.try_get<std::string>() != nullptr);
- REQUIRE(value.get<std::string>() == "hello");
+ CATCH_REQUIRE(value.try_get<std::string>() != nullptr);
+ CATCH_REQUIRE(value.get<std::string>() == "hello");
}
- SECTION("pointers") {
+ CATCH_SECTION("pointers") {
std::string s("hello");
std::string* p = &s;
auto value = make_value(p);
- REQUIRE(value.try_get<std::string*>() != nullptr);
- REQUIRE(*value.get<std::string*>() == "hello");
+ CATCH_REQUIRE(value.try_get<std::string*>() != nullptr);
+ CATCH_REQUIRE(*value.get<std::string*>() == "hello");
}
- SECTION("references") {
+ CATCH_SECTION("references") {
std::string s("hello");
const std::string& t = s;
auto value = make_value(t);
- REQUIRE(value.try_get<std::string>() != nullptr);
- REQUIRE(value.get<std::string>() == "hello");
+ CATCH_REQUIRE(value.try_get<std::string>() != nullptr);
+ CATCH_REQUIRE(value.get<std::string>() == "hello");
}
}
- SECTION("try_get returns nullptr for the wrong type") {
+ CATCH_SECTION("try_get returns nullptr for the wrong type") {
auto value = make_value(5);
- REQUIRE(value.try_get<int>() != nullptr);
- REQUIRE(value.try_get<float>() == nullptr);
- REQUIRE(value.try_get<long>() == nullptr);
- REQUIRE(value.try_get<std::string>() == nullptr);
+ CATCH_REQUIRE(value.try_get<int>() != nullptr);
+ CATCH_REQUIRE(value.try_get<float>() == nullptr);
+ CATCH_REQUIRE(value.try_get<long>() == nullptr);
+ CATCH_REQUIRE(value.try_get<std::string>() == nullptr);
}
- SECTION("get throws for the wrong type") {
+ CATCH_SECTION("get throws for the wrong type") {
auto value = make_value(5);
- REQUIRE(value.try_get<int>() != nullptr);
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE(value.try_get<int>() != nullptr);
+ CATCH_REQUIRE_THROWS_WITH(
value.get<float>(),
StartsWith("Attempted to cast Value to float, "
"but its actual type is int"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
value.get<long>(),
StartsWith("Attempted to cast Value to long, "
"but its actual type is int"));
}
- SECTION("move is allowed") {
+ CATCH_SECTION("move is allowed") {
auto value = make_value(5);
- SECTION("construction") {
+ CATCH_SECTION("construction") {
auto copy = make_value(std::move(value));
- REQUIRE(copy.try_get<int>() != nullptr);
- REQUIRE(copy.get<int>() == 5);
+ CATCH_REQUIRE(copy.try_get<int>() != nullptr);
+ CATCH_REQUIRE(copy.get<int>() == 5);
}
- SECTION("assignment") {
+ CATCH_SECTION("assignment") {
auto copy = make_value(10);
copy = std::move(value);
- REQUIRE(copy.try_get<int>() != nullptr);
- REQUIRE(copy.get<int>() == 5);
+ CATCH_REQUIRE(copy.try_get<int>() != nullptr);
+ CATCH_REQUIRE(copy.get<int>() == 5);
}
}
- SECTION("type_info is correct") {
- SECTION("int") {
+ CATCH_SECTION("type_info is correct") {
+ CATCH_SECTION("int") {
auto value = make_value(5);
- REQUIRE(value.type_info().hash_code() == typeid(int).hash_code());
+ CATCH_REQUIRE(value.type_info().hash_code() == typeid(int).hash_code());
}
- SECTION("const char") {
+ CATCH_SECTION("const char") {
auto value = make_value("hello");
- REQUIRE(value.type_info().hash_code() == typeid(const char*).hash_code());
+ CATCH_REQUIRE(value.type_info().hash_code() == typeid(const char*).hash_code());
}
- SECTION("std::string") {
+ CATCH_SECTION("std::string") {
auto value = make_value(std::string("hello"));
- REQUIRE(value.type_info().hash_code() == typeid(std::string).hash_code());
+ CATCH_REQUIRE(value.type_info().hash_code() == typeid(std::string).hash_code());
}
}
}
diff --git a/test/cpp/api/catch_utils.hpp b/test/cpp/api/catch_utils.hpp
new file mode 100644
index 0000000..b9b0a87
--- /dev/null
+++ b/test/cpp/api/catch_utils.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+#define CATCH_CONFIG_PREFIX_ALL
+#include <catch.hpp>
+
+// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes a warning;
+// define our own version that doesn't warn.
+#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
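For reference, a minimal sketch (not part of this patch) of what a test file looks like once it includes catch_utils.hpp: defining CATCH_CONFIG_PREFIX_ALL before catch.hpp makes Catch emit only the CATCH_-prefixed macro names, so the unprefixed TEST_CASE/REQUIRE no longer collide with Caffe2's macros, and _CATCH_REQUIRE_THROWS is the warning-free throws check defined above. The file name and the specific assertions are illustrative, modeled on the tests touched in this diff.

// example_test.cpp -- hypothetical illustration, not part of this diff.
#define CATCH_CONFIG_MAIN
#include "catch_utils.hpp"  // defines CATCH_CONFIG_PREFIX_ALL, then includes catch.hpp

#include "ATen/ATen.h"

CATCH_TEST_CASE("prefixed Catch macros", "[example]") {
  at::Tensor und;                     // default-constructed (undefined) tensor
  CATCH_REQUIRE(!und.defined());      // plain REQUIRE is no longer defined, so no clash with Caffe2
  _CATCH_REQUIRE_THROWS(und.add(5));  // warning-free throws check supplied by catch_utils.hpp

  CATCH_SECTION("defined tensors behave as before") {
    at::Tensor t = at::ones({2, 2});
    CATCH_REQUIRE(t.sum().toCFloat() == 4);
  }
}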
diff --git a/test/cpp/api/cursor.cpp b/test/cpp/api/cursor.cpp
index 5c99866..e08bd78 100644
--- a/test/cpp/api/cursor.cpp
+++ b/test/cpp/api/cursor.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/nn/cursor.h>
#include <torch/nn/module.h>
@@ -58,158 +58,158 @@
std::vector<std::shared_ptr<Module>> m;
};
-TEST_CASE("cursor/module") {
+CATCH_TEST_CASE("cursor/module") {
torch::manual_seed(0);
- SECTION("Works for flat models (depth = 1)") {
+ CATCH_SECTION("Works for flat models (depth = 1)") {
Container model(TestModule(1), TestModule(2), TestModule(3));
auto cursor = model.modules();
- SECTION("Iterates in the correct order") {
+ CATCH_SECTION("Iterates in the correct order") {
auto iterator = cursor.begin();
- REQUIRE(&iterator->value == &model[0]);
- REQUIRE(&(++iterator)->value == &model[1]);
- REQUIRE(&(++iterator)->value == &model[2]);
- REQUIRE(++iterator == cursor.end());
+ CATCH_REQUIRE(&iterator->value == &model[0]);
+ CATCH_REQUIRE(&(++iterator)->value == &model[1]);
+ CATCH_REQUIRE(&(++iterator)->value == &model[2]);
+ CATCH_REQUIRE(++iterator == cursor.end());
}
- SECTION("names are flat") {
+ CATCH_SECTION("names are flat") {
auto iterator = cursor.begin();
- REQUIRE(iterator->key == "0");
- REQUIRE((++iterator)->key == "1");
- REQUIRE((++iterator)->key == "2");
+ CATCH_REQUIRE(iterator->key == "0");
+ CATCH_REQUIRE((++iterator)->key == "1");
+ CATCH_REQUIRE((++iterator)->key == "2");
}
- SECTION("Apply works") {
+ CATCH_SECTION("Apply works") {
size_t count = 0;
cursor.apply([&count, &model](Module& module) {
- REQUIRE(&module == &model[count]);
+ CATCH_REQUIRE(&module == &model[count]);
count += 1;
});
- REQUIRE(count == 3);
+ CATCH_REQUIRE(count == 3);
}
- SECTION("Apply_items works") {
+ CATCH_SECTION("Apply_items works") {
size_t count = 0;
cursor.apply_items(
[&count, &model](const std::string& key, Module& module) {
- REQUIRE(&module == &model[count]);
+ CATCH_REQUIRE(&module == &model[count]);
count += 1;
});
- REQUIRE(count == 3);
+ CATCH_REQUIRE(count == 3);
}
- SECTION("Map works") {
+ CATCH_SECTION("Map works") {
std::vector<Module*> vector(3);
cursor.map(vector.begin(), [](Module& module) { return &module; });
- REQUIRE(vector[0] == &model[0]);
- REQUIRE(vector[1] == &model[1]);
- REQUIRE(vector[2] == &model[2]);
+ CATCH_REQUIRE(vector[0] == &model[0]);
+ CATCH_REQUIRE(vector[1] == &model[1]);
+ CATCH_REQUIRE(vector[2] == &model[2]);
std::list<Module*> list;
cursor.map(std::inserter(list, list.end()), [](Module& module) {
return &module;
});
- REQUIRE(list.size() == 3);
+ CATCH_REQUIRE(list.size() == 3);
auto iterator = list.begin();
- REQUIRE(*iterator++ == &model[0]);
- REQUIRE(*iterator++ == &model[1]);
- REQUIRE(*iterator++ == &model[2]);
- REQUIRE(iterator == list.end());
+ CATCH_REQUIRE(*iterator++ == &model[0]);
+ CATCH_REQUIRE(*iterator++ == &model[1]);
+ CATCH_REQUIRE(*iterator++ == &model[2]);
+ CATCH_REQUIRE(iterator == list.end());
}
- SECTION("Map_items works") {
+ CATCH_SECTION("Map_items works") {
std::map<std::string, Module*> output;
cursor.map_items(
std::inserter(output, output.end()),
[](const std::string& key, Module& module) {
return std::make_pair(key, &module);
});
- REQUIRE(output.size() == 3);
- REQUIRE(output.count("0"));
- REQUIRE(output.count("1"));
- REQUIRE(output.count("2"));
- REQUIRE(output["0"] == &model[0]);
- REQUIRE(output["1"] == &model[1]);
- REQUIRE(output["2"] == &model[2]);
+ CATCH_REQUIRE(output.size() == 3);
+ CATCH_REQUIRE(output.count("0"));
+ CATCH_REQUIRE(output.count("1"));
+ CATCH_REQUIRE(output.count("2"));
+ CATCH_REQUIRE(output["0"] == &model[0]);
+ CATCH_REQUIRE(output["1"] == &model[1]);
+ CATCH_REQUIRE(output["2"] == &model[2]);
}
- SECTION("Count works for flat models") {
- REQUIRE(cursor.size() == model.m.size());
+ CATCH_SECTION("Count works for flat models") {
+ CATCH_REQUIRE(cursor.size() == model.m.size());
}
- SECTION("find() finds the correct modules when given a valid key") {
- REQUIRE(cursor.find("0") == &model[0]);
- REQUIRE(cursor.find("1") == &model[1]);
- REQUIRE(cursor.find("2") == &model[2]);
+ CATCH_SECTION("find() finds the correct modules when given a valid key") {
+ CATCH_REQUIRE(cursor.find("0") == &model[0]);
+ CATCH_REQUIRE(cursor.find("1") == &model[1]);
+ CATCH_REQUIRE(cursor.find("2") == &model[2]);
}
- SECTION("find() returns nullptr when given an invalid key") {
- REQUIRE(cursor.find("foo") == nullptr);
- REQUIRE(cursor.find("bar") == nullptr);
+ CATCH_SECTION("find() returns nullptr when given an invalid key") {
+ CATCH_REQUIRE(cursor.find("foo") == nullptr);
+ CATCH_REQUIRE(cursor.find("bar") == nullptr);
}
- SECTION("at(key) returns the correct modules when given a valid key") {
- REQUIRE(&cursor.at("0") == &model[0]);
- REQUIRE(&cursor.at("1") == &model[1]);
- REQUIRE(&cursor.at("2") == &model[2]);
+ CATCH_SECTION("at(key) returns the correct modules when given a valid key") {
+ CATCH_REQUIRE(&cursor.at("0") == &model[0]);
+ CATCH_REQUIRE(&cursor.at("1") == &model[1]);
+ CATCH_REQUIRE(&cursor.at("2") == &model[2]);
}
- SECTION("at(key) throws when given an invalid key") {
- REQUIRE_THROWS_WITH(cursor.at("foo"), StartsWith("No such key: 'foo'"));
- REQUIRE_THROWS_WITH(cursor.at("bar"), StartsWith("No such key: 'bar'"));
+ CATCH_SECTION("at(key) throws when given an invalid key") {
+ CATCH_REQUIRE_THROWS_WITH(cursor.at("foo"), StartsWith("No such key: 'foo'"));
+ CATCH_REQUIRE_THROWS_WITH(cursor.at("bar"), StartsWith("No such key: 'bar'"));
}
- SECTION(
+ CATCH_SECTION(
"operator[key] returns the correct modules when given a valid key") {
- REQUIRE(&cursor["0"] == &model[0]);
- REQUIRE(&cursor["1"] == &model[1]);
- REQUIRE(&cursor["2"] == &model[2]);
+ CATCH_REQUIRE(&cursor["0"] == &model[0]);
+ CATCH_REQUIRE(&cursor["1"] == &model[1]);
+ CATCH_REQUIRE(&cursor["2"] == &model[2]);
}
- SECTION("operator[key] throws when given an invalid key") {
- REQUIRE_THROWS_WITH(cursor["foo"], StartsWith("No such key: 'foo'"));
- REQUIRE_THROWS_WITH(cursor["bar"], StartsWith("No such key: 'bar'"));
+ CATCH_SECTION("operator[key] throws when given an invalid key") {
+ CATCH_REQUIRE_THROWS_WITH(cursor["foo"], StartsWith("No such key: 'foo'"));
+ CATCH_REQUIRE_THROWS_WITH(cursor["bar"], StartsWith("No such key: 'bar'"));
}
- SECTION("at(index) returns the correct modules when given a valid index") {
- REQUIRE(&cursor.at(0).value == &model[0]);
- REQUIRE(&cursor.at(1).value == &model[1]);
- REQUIRE(&cursor.at(2).value == &model[2]);
+ CATCH_SECTION("at(index) returns the correct modules when given a valid index") {
+ CATCH_REQUIRE(&cursor.at(0).value == &model[0]);
+ CATCH_REQUIRE(&cursor.at(1).value == &model[1]);
+ CATCH_REQUIRE(&cursor.at(2).value == &model[2]);
}
- SECTION("at(index) throws when given an invalid index") {
- REQUIRE_THROWS_WITH(
+ CATCH_SECTION("at(index) throws when given an invalid index") {
+ CATCH_REQUIRE_THROWS_WITH(
cursor.at(5),
StartsWith("Index 5 is out of range for cursor of size 3"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
cursor.at(123),
StartsWith("Index 123 is out of range for cursor of size 3"));
}
- SECTION(
+ CATCH_SECTION(
"operator[index] returns the correct modules when given a valid index") {
- REQUIRE(&cursor[0].value == &model[0]);
- REQUIRE(&cursor[1].value == &model[1]);
- REQUIRE(&cursor[2].value == &model[2]);
+ CATCH_REQUIRE(&cursor[0].value == &model[0]);
+ CATCH_REQUIRE(&cursor[1].value == &model[1]);
+ CATCH_REQUIRE(&cursor[2].value == &model[2]);
}
- SECTION("operator[index] throws when given an invalid key") {
- REQUIRE_THROWS_WITH(
+ CATCH_SECTION("operator[index] throws when given an invalid key") {
+ CATCH_REQUIRE_THROWS_WITH(
cursor[5],
StartsWith("Index 5 is out of range for cursor of size 3"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
cursor[123],
StartsWith("Index 123 is out of range for cursor of size 3"));
}
- SECTION("contains() is correct") {
- REQUIRE(cursor.contains("0"));
- REQUIRE(cursor.contains("1"));
- REQUIRE(cursor.contains("2"));
+ CATCH_SECTION("contains() is correct") {
+ CATCH_REQUIRE(cursor.contains("0"));
+ CATCH_REQUIRE(cursor.contains("1"));
+ CATCH_REQUIRE(cursor.contains("2"));
}
}
- SECTION("Works for deeper hierarchies (depth > 1)") {
+ CATCH_SECTION("Works for deeper hierarchies (depth > 1)") {
// clang-format off
Container model(
Container(
@@ -227,106 +227,106 @@
auto cursor = model.modules();
// This is sufficient for the hierarchical case
// (other tests build on top)
- SECTION("Iterates in the correct order") {
+ CATCH_SECTION("Iterates in the correct order") {
auto iterator = cursor.begin();
- REQUIRE(&iterator->value == &model[0]);
+ CATCH_REQUIRE(&iterator->value == &model[0]);
auto* seq = dynamic_cast<Container*>(&model[0]);
- REQUIRE(seq != nullptr);
- REQUIRE(&(++iterator)->value == &(*seq)[0]);
- REQUIRE(&(++iterator)->value == &(*seq)[1]);
+ CATCH_REQUIRE(seq != nullptr);
+ CATCH_REQUIRE(&(++iterator)->value == &(*seq)[0]);
+ CATCH_REQUIRE(&(++iterator)->value == &(*seq)[1]);
- REQUIRE(&(++iterator)->value == &model[1]);
- REQUIRE(&(++iterator)->value == &model[2]);
+ CATCH_REQUIRE(&(++iterator)->value == &model[1]);
+ CATCH_REQUIRE(&(++iterator)->value == &model[2]);
seq = dynamic_cast<Container*>(&model[2]);
- REQUIRE(seq != nullptr);
- REQUIRE(&(++iterator)->value == &(*seq)[0]);
- REQUIRE(&(++iterator)->value == &(*seq)[1]);
+ CATCH_REQUIRE(seq != nullptr);
+ CATCH_REQUIRE(&(++iterator)->value == &(*seq)[0]);
+ CATCH_REQUIRE(&(++iterator)->value == &(*seq)[1]);
seq = dynamic_cast<Container*>(&(*seq)[1]);
- REQUIRE(seq != nullptr);
- REQUIRE(&(++iterator)->value == &(*seq)[0]);
- REQUIRE(&(++iterator)->value == &(*seq)[1]);
+ CATCH_REQUIRE(seq != nullptr);
+ CATCH_REQUIRE(&(++iterator)->value == &(*seq)[0]);
+ CATCH_REQUIRE(&(++iterator)->value == &(*seq)[1]);
}
- SECTION("children() returns only the first level of submodules") {
+ CATCH_SECTION("children() returns only the first level of submodules") {
auto children = model.children();
- REQUIRE(children.size() == 3);
- REQUIRE(&children.at("0") == &model[0]);
- REQUIRE(&children.at("1") == &model[1]);
- REQUIRE(&children.at("2") == &model[2]);
- REQUIRE(!children.contains("0.0"));
+ CATCH_REQUIRE(children.size() == 3);
+ CATCH_REQUIRE(&children.at("0") == &model[0]);
+ CATCH_REQUIRE(&children.at("1") == &model[1]);
+ CATCH_REQUIRE(&children.at("2") == &model[2]);
+ CATCH_REQUIRE(!children.contains("0.0"));
size_t count = 0;
for (auto& child : children) {
- REQUIRE(child.key == std::to_string(count));
- REQUIRE(&child.value == &model[count]);
+ CATCH_REQUIRE(child.key == std::to_string(count));
+ CATCH_REQUIRE(&child.value == &model[count]);
count += 1;
}
}
}
}
-TEST_CASE("cursor/parameter") {
+CATCH_TEST_CASE("cursor/parameter") {
torch::manual_seed(0);
- SECTION("Works for single models") {
+ CATCH_SECTION("Works for single models") {
TestModule model(1);
auto cursor = model.parameters();
- SECTION("Iterates in the correct order") {
+ CATCH_SECTION("Iterates in the correct order") {
auto iterator = cursor.begin();
- REQUIRE(iterator->value.equal(model.tensor1));
- REQUIRE((++iterator)->value.equal(model.tensor2));
+ CATCH_REQUIRE(iterator->value.equal(model.tensor1));
+ CATCH_REQUIRE((++iterator)->value.equal(model.tensor2));
}
}
- SECTION("Works for flat models (depth = 1)") {
+ CATCH_SECTION("Works for flat models (depth = 1)") {
auto first = std::make_shared<TestModule>(1);
auto second = std::make_shared<TestModule>(2);
Container model(first, second);
auto cursor = model.parameters();
- SECTION("Iterates in the correct order") {
+ CATCH_SECTION("Iterates in the correct order") {
auto iterator = cursor.begin();
- REQUIRE(iterator->value.equal(first->tensor1));
- REQUIRE((++iterator)->value.equal(first->tensor2));
- REQUIRE((++iterator)->value.equal(second->tensor1));
- REQUIRE((++iterator)->value.equal(second->tensor2));
+ CATCH_REQUIRE(iterator->value.equal(first->tensor1));
+ CATCH_REQUIRE((++iterator)->value.equal(first->tensor2));
+ CATCH_REQUIRE((++iterator)->value.equal(second->tensor1));
+ CATCH_REQUIRE((++iterator)->value.equal(second->tensor2));
}
- SECTION("Apply_items works") {
+ CATCH_SECTION("Apply_items works") {
size_t count = 0;
cursor.apply_items([&count, &model, &first, &second](
const std::string& key, torch::Tensor& tensor) {
switch (count) {
case 0: {
- REQUIRE(tensor.equal(first->tensor1));
+ CATCH_REQUIRE(tensor.equal(first->tensor1));
break;
}
case 1: {
- REQUIRE(tensor.equal(first->tensor2));
+ CATCH_REQUIRE(tensor.equal(first->tensor2));
break;
}
case 2: {
- REQUIRE(tensor.equal(second->tensor1));
+ CATCH_REQUIRE(tensor.equal(second->tensor1));
break;
}
case 3: {
- REQUIRE(tensor.equal(second->tensor2));
+ CATCH_REQUIRE(tensor.equal(second->tensor2));
break;
}
}
count += 1;
});
- REQUIRE(count == 4);
+ CATCH_REQUIRE(count == 4);
}
// Other tests are correct based on correct iteration behavior and apply
// working.
}
- SECTION("Works for deeper hierarchies (depth > 1)") {
+ CATCH_SECTION("Works for deeper hierarchies (depth > 1)") {
std::vector<std::shared_ptr<TestModule>> modules;
for (size_t i = 1; i <= 6; ++i) {
modules.push_back(std::make_shared<TestModule>(i));
@@ -346,36 +346,36 @@
// clang-format on
auto cursor = model.parameters();
- SECTION("Iterates in the correct order") {
+ CATCH_SECTION("Iterates in the correct order") {
auto iterator = cursor.begin();
- REQUIRE(iterator->value.equal(modules[0]->tensor1));
- REQUIRE((++iterator)->value.equal(modules[0]->tensor2));
+ CATCH_REQUIRE(iterator->value.equal(modules[0]->tensor1));
+ CATCH_REQUIRE((++iterator)->value.equal(modules[0]->tensor2));
for (size_t index = 1; index < 6; ++index) {
- REQUIRE((++iterator)->value.equal(modules[index]->tensor1));
- REQUIRE((++iterator)->value.equal(modules[index]->tensor2));
+ CATCH_REQUIRE((++iterator)->value.equal(modules[index]->tensor1));
+ CATCH_REQUIRE((++iterator)->value.equal(modules[index]->tensor2));
}
}
- SECTION("names are hierarchical") {
+ CATCH_SECTION("names are hierarchical") {
auto iterator = cursor.begin();
- REQUIRE(iterator->key == "0.0.tensor1");
- REQUIRE((++iterator)->key == "0.0.tensor2");
- REQUIRE((++iterator)->key == "0.1.tensor1");
- REQUIRE((++iterator)->key == "0.1.tensor2");
- REQUIRE((++iterator)->key == "1.tensor1");
- REQUIRE((++iterator)->key == "1.tensor2");
- REQUIRE((++iterator)->key == "2.0.tensor1");
- REQUIRE((++iterator)->key == "2.0.tensor2");
- REQUIRE((++iterator)->key == "2.1.0.tensor1");
- REQUIRE((++iterator)->key == "2.1.0.tensor2");
- REQUIRE((++iterator)->key == "2.1.1.tensor1");
- REQUIRE((++iterator)->key == "2.1.1.tensor2");
- REQUIRE(++iterator == cursor.end());
+ CATCH_REQUIRE(iterator->key == "0.0.tensor1");
+ CATCH_REQUIRE((++iterator)->key == "0.0.tensor2");
+ CATCH_REQUIRE((++iterator)->key == "0.1.tensor1");
+ CATCH_REQUIRE((++iterator)->key == "0.1.tensor2");
+ CATCH_REQUIRE((++iterator)->key == "1.tensor1");
+ CATCH_REQUIRE((++iterator)->key == "1.tensor2");
+ CATCH_REQUIRE((++iterator)->key == "2.0.tensor1");
+ CATCH_REQUIRE((++iterator)->key == "2.0.tensor2");
+ CATCH_REQUIRE((++iterator)->key == "2.1.0.tensor1");
+ CATCH_REQUIRE((++iterator)->key == "2.1.0.tensor2");
+ CATCH_REQUIRE((++iterator)->key == "2.1.1.tensor1");
+ CATCH_REQUIRE((++iterator)->key == "2.1.1.tensor2");
+ CATCH_REQUIRE(++iterator == cursor.end());
}
}
}
-TEST_CASE("cursor/non-const-to-const-conversion") {
+CATCH_TEST_CASE("cursor/non-const-to-const-conversion") {
torch::manual_seed(0);
auto first = std::make_shared<TestModule>(1);
auto second = std::make_shared<TestModule>(2);
@@ -404,11 +404,11 @@
}
}
-TEST_CASE("cursor/can-invoke-const-method-on-const-cursor") {
+CATCH_TEST_CASE("cursor/can-invoke-const-method-on-const-cursor") {
torch::manual_seed(0);
TestModule model(1);
/// This will only compile if `Cursor` has the appropriate const methods.
const auto cursor = model.parameters();
- REQUIRE(cursor.contains("tensor1"));
+ CATCH_REQUIRE(cursor.contains("tensor1"));
}
diff --git a/test/cpp/api/integration.cpp b/test/cpp/api/integration.cpp
index 8d75319..972223a 100644
--- a/test/cpp/api/integration.cpp
+++ b/test/cpp/api/integration.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/modules/conv.h>
@@ -230,7 +230,7 @@
return correct.sum().toCFloat() > telabel.size(0) * 0.8;
}
-TEST_CASE("integration/cartpole") {
+CATCH_TEST_CASE("integration/cartpole") {
torch::manual_seed(0);
std::cerr << "Training episodic policy gradient with a critic for up to 3000"
" episodes, rest your eyes for a bit!\n";
@@ -326,11 +326,11 @@
if (running_reward > 150) {
break;
}
- REQUIRE(episode < 3000);
+ CATCH_REQUIRE(episode < 3000);
}
}
-TEST_CASE("integration/mnist", "[cuda]") {
+CATCH_TEST_CASE("integration/mnist", "[cuda]") {
torch::manual_seed(0);
auto model = std::make_shared<SimpleContainer>();
auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
@@ -357,7 +357,7 @@
auto optimizer = torch::optim::SGD(
model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));
- REQUIRE(test_mnist(
+ CATCH_REQUIRE(test_mnist(
32, // batch_size
3, // num_epochs
true, // useGPU
@@ -366,7 +366,7 @@
optimizer));
}
-TEST_CASE("integration/mnist/batchnorm", "[cuda]") {
+CATCH_TEST_CASE("integration/mnist/batchnorm", "[cuda]") {
torch::manual_seed(0);
auto model = std::make_shared<SimpleContainer>();
auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
@@ -393,7 +393,7 @@
auto optimizer = torch::optim::SGD(
model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));
- REQUIRE(test_mnist(
+ CATCH_REQUIRE(test_mnist(
32, // batch_size
3, // num_epochs
true, // useGPU
diff --git a/test/cpp/api/jit.cpp b/test/cpp/api/jit.cpp
index c46868c..b477b11 100644
--- a/test/cpp/api/jit.cpp
+++ b/test/cpp/api/jit.cpp
@@ -1,12 +1,12 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/jit.h>
#include <torch/tensor.h>
#include <string>
-TEST_CASE("torch script") {
- SECTION("multiple functions") {
+CATCH_TEST_CASE("torch script") {
+ CATCH_SECTION("multiple functions") {
auto module = torch::jit::compile(R"JIT(
def test_mul(a, b):
return a * b
@@ -21,11 +21,11 @@
auto a = torch::ones(1);
auto b = torch::ones(1);
- REQUIRE(1 == module->run_method("test_mul", a, b).toTensor().toCLong());
+ CATCH_REQUIRE(1 == module->run_method("test_mul", a, b).toTensor().toCLong());
- REQUIRE(2 == module->run_method("test_relu", a, b).toTensor().toCLong());
+ CATCH_REQUIRE(2 == module->run_method("test_relu", a, b).toTensor().toCLong());
- REQUIRE(
+ CATCH_REQUIRE(
0x200 == module->run_method("test_while", a, b).toTensor().toCLong());
}
}
diff --git a/test/cpp/api/main.cpp b/test/cpp/api/main.cpp
index 4b1aaba..92ea356 100644
--- a/test/cpp/api/main.cpp
+++ b/test/cpp/api/main.cpp
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_RUNNER
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/cuda.h>
diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp
index 6d065bf..8ced0e0 100644
--- a/test/cpp/api/misc.cpp
+++ b/test/cpp/api/misc.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/detail/ordered_dict.h>
#include <torch/expanding_array.h>
@@ -18,7 +18,7 @@
using Catch::StartsWith;
-TEST_CASE("NoGrad") {
+CATCH_TEST_CASE("NoGrad") {
torch::manual_seed(0);
torch::NoGradGuard guard;
Linear model(5, 2);
@@ -27,88 +27,88 @@
torch::Tensor s = y.sum();
s.backward();
- REQUIRE(!model->parameters()["weight"].grad().defined());
+ CATCH_REQUIRE(!model->parameters()["weight"].grad().defined());
}
-TEST_CASE("autograd") {
+CATCH_TEST_CASE("autograd") {
torch::manual_seed(0);
auto x = torch::randn({3, 3}, torch::requires_grad());
auto y = torch::randn({3, 3});
auto z = x * y;
- SECTION("derivatives of zero-dim tensors") {
+ CATCH_SECTION("derivatives of zero-dim tensors") {
z.sum().backward();
- REQUIRE(x.grad().allclose(y));
+ CATCH_REQUIRE(x.grad().allclose(y));
}
- SECTION("derivatives of tensors") {
+ CATCH_SECTION("derivatives of tensors") {
z.backward();
- REQUIRE(x.grad().allclose(y));
+ CATCH_REQUIRE(x.grad().allclose(y));
}
- SECTION("custom gradient inputs") {
+ CATCH_SECTION("custom gradient inputs") {
z.sum().backward(torch::ones({}) * 2);
- REQUIRE(x.grad().allclose(y * 2));
+ CATCH_REQUIRE(x.grad().allclose(y * 2));
}
// Assume everything else is safe from PyTorch tests.
}
-TEST_CASE("nn::init") {
+CATCH_TEST_CASE("nn::init") {
auto tensor = torch::empty({3, 4}, torch::requires_grad());
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
tensor.fill_(1),
StartsWith("a leaf Variable that requires grad "
"has been used in an in-place operation"));
- REQUIRE(torch::nn::init::ones_(tensor).sum().toCInt() == 12);
+ CATCH_REQUIRE(torch::nn::init::ones_(tensor).sum().toCInt() == 12);
}
-TEST_CASE("expanding-array") {
+CATCH_TEST_CASE("expanding-array") {
torch::manual_seed(0);
- SECTION("successful construction") {
- SECTION("initializer_list") {
+ CATCH_SECTION("successful construction") {
+ CATCH_SECTION("initializer_list") {
torch::ExpandingArray<5> e({1, 2, 3, 4, 5});
- REQUIRE(e.size() == 5);
+ CATCH_REQUIRE(e.size() == 5);
for (size_t i = 0; i < e.size(); ++i) {
- REQUIRE((*e)[i] == i + 1);
+ CATCH_REQUIRE((*e)[i] == i + 1);
}
}
- SECTION("vector") {
+ CATCH_SECTION("vector") {
torch::ExpandingArray<5> e(std::vector<int64_t>{1, 2, 3, 4, 5});
- REQUIRE(e.size() == 5);
+ CATCH_REQUIRE(e.size() == 5);
for (size_t i = 0; i < e.size(); ++i) {
- REQUIRE((*e)[i] == i + 1);
+ CATCH_REQUIRE((*e)[i] == i + 1);
}
}
- SECTION("array") {
+ CATCH_SECTION("array") {
torch::ExpandingArray<5> e(std::array<int64_t, 5>({1, 2, 3, 4, 5}));
- REQUIRE(e.size() == 5);
+ CATCH_REQUIRE(e.size() == 5);
for (size_t i = 0; i < e.size(); ++i) {
- REQUIRE((*e)[i] == i + 1);
+ CATCH_REQUIRE((*e)[i] == i + 1);
}
}
- SECTION("single value") {
+ CATCH_SECTION("single value") {
torch::ExpandingArray<5> e(5);
- REQUIRE(e.size() == 5);
+ CATCH_REQUIRE(e.size() == 5);
for (size_t i = 0; i < e.size(); ++i) {
- REQUIRE((*e)[i] == 5);
+ CATCH_REQUIRE((*e)[i] == 5);
}
}
}
- SECTION("throws for incorrect size on construction") {
- SECTION("initializer_list") {
- REQUIRE_THROWS_WITH(
+ CATCH_SECTION("throws for incorrect size on construction") {
+ CATCH_SECTION("initializer_list") {
+ CATCH_REQUIRE_THROWS_WITH(
torch::ExpandingArray<5>({1, 2, 3, 4, 5, 6, 7}),
StartsWith("Expected 5 values, but instead got 7"));
}
- SECTION("vector") {
- REQUIRE_THROWS_WITH(
+ CATCH_SECTION("vector") {
+ CATCH_REQUIRE_THROWS_WITH(
torch::ExpandingArray<5>(std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7})),
StartsWith("Expected 5 values, but instead got 7"));
}
}
}
-TEST_CASE("make_unique") {
+CATCH_TEST_CASE("make_unique") {
struct Test {
explicit Test(const int& x) : lvalue_(x) {}
explicit Test(int&& x) : rvalue_(x) {}
@@ -117,216 +117,216 @@
at::optional<int> rvalue_;
};
- SECTION("forwards rvalues correctly") {
+ CATCH_SECTION("forwards rvalues correctly") {
auto ptr = torch::make_unique<Test>(123);
- REQUIRE(!ptr->lvalue_.has_value());
- REQUIRE(ptr->rvalue_.has_value());
- REQUIRE(*ptr->rvalue_ == 123);
+ CATCH_REQUIRE(!ptr->lvalue_.has_value());
+ CATCH_REQUIRE(ptr->rvalue_.has_value());
+ CATCH_REQUIRE(*ptr->rvalue_ == 123);
}
- SECTION("forwards lvalues correctly") {
+ CATCH_SECTION("forwards lvalues correctly") {
int x = 5;
auto ptr = torch::make_unique<Test>(x);
- REQUIRE(ptr->lvalue_.has_value());
- REQUIRE(*ptr->lvalue_ == 5);
- REQUIRE(!ptr->rvalue_.has_value());
+ CATCH_REQUIRE(ptr->lvalue_.has_value());
+ CATCH_REQUIRE(*ptr->lvalue_ == 5);
+ CATCH_REQUIRE(!ptr->rvalue_.has_value());
}
- SECTION("Can construct unique_ptr of array") {
+ CATCH_SECTION("Can construct unique_ptr of array") {
auto ptr = torch::make_unique<int[]>(3);
// Value initialization is required by the standard.
- REQUIRE(ptr[0] == 0);
- REQUIRE(ptr[1] == 0);
- REQUIRE(ptr[2] == 0);
+ CATCH_REQUIRE(ptr[0] == 0);
+ CATCH_REQUIRE(ptr[1] == 0);
+ CATCH_REQUIRE(ptr[2] == 0);
}
}
-TEST_CASE("ordered-dict") {
- SECTION("is empty after default construction") {
+CATCH_TEST_CASE("ordered-dict") {
+ CATCH_SECTION("is empty after default construction") {
OrderedDict<int> dict;
- REQUIRE(dict.subject() == "Key");
- REQUIRE(dict.is_empty());
- REQUIRE(dict.size() == 0);
+ CATCH_REQUIRE(dict.subject() == "Key");
+ CATCH_REQUIRE(dict.is_empty());
+ CATCH_REQUIRE(dict.size() == 0);
}
- SECTION("insert inserts elements when they are not yet present") {
+ CATCH_SECTION("insert inserts elements when they are not yet present") {
OrderedDict<int> dict;
dict.insert("a", 1);
dict.insert("b", 2);
- REQUIRE(dict.size() == 2);
+ CATCH_REQUIRE(dict.size() == 2);
}
- SECTION("get returns values when present") {
+ CATCH_SECTION("get returns values when present") {
OrderedDict<int> dict;
dict.insert("a", 1);
dict.insert("b", 2);
- REQUIRE(dict.get("a") == 1);
- REQUIRE(dict.get("b") == 2);
+ CATCH_REQUIRE(dict.get("a") == 1);
+ CATCH_REQUIRE(dict.get("b") == 2);
}
- SECTION("get throws when passed keys that are not present") {
+ CATCH_SECTION("get throws when passed keys that are not present") {
OrderedDict<int> dict;
dict.insert("a", 1);
dict.insert("b", 2);
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
dict.get("foo"), StartsWith("Key 'foo' is not defined"));
- REQUIRE_THROWS_WITH(dict.get(""), StartsWith("Key '' is not defined"));
+ CATCH_REQUIRE_THROWS_WITH(dict.get(""), StartsWith("Key '' is not defined"));
}
- SECTION("can initialize from list") {
+ CATCH_SECTION("can initialize from list") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
- REQUIRE(dict.size() == 2);
- REQUIRE(dict.get("a") == 1);
- REQUIRE(dict.get("b") == 2);
+ CATCH_REQUIRE(dict.size() == 2);
+ CATCH_REQUIRE(dict.get("a") == 1);
+ CATCH_REQUIRE(dict.get("b") == 2);
}
- SECTION("insert throws when passed elements that are present") {
+ CATCH_SECTION("insert throws when passed elements that are present") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
dict.insert("a", 1), StartsWith("Key 'a' already defined"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
dict.insert("b", 1), StartsWith("Key 'b' already defined"));
}
- SECTION("front() returns the first item") {
+ CATCH_SECTION("front() returns the first item") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
- REQUIRE(dict.front().key == "a");
- REQUIRE(dict.front().value == 1);
+ CATCH_REQUIRE(dict.front().key == "a");
+ CATCH_REQUIRE(dict.front().value == 1);
}
- SECTION("back() returns the last item") {
+ CATCH_SECTION("back() returns the last item") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
- REQUIRE(dict.back().key == "b");
- REQUIRE(dict.back().value == 2);
+ CATCH_REQUIRE(dict.back().key == "b");
+ CATCH_REQUIRE(dict.back().value == 2);
}
- SECTION("find returns pointers to values when present") {
+ CATCH_SECTION("find returns pointers to values when present") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
- REQUIRE(dict.find("a") != nullptr);
- REQUIRE(*dict.find("a") == 1);
- REQUIRE(dict.find("b") != nullptr);
- REQUIRE(*dict.find("b") == 2);
+ CATCH_REQUIRE(dict.find("a") != nullptr);
+ CATCH_REQUIRE(*dict.find("a") == 1);
+ CATCH_REQUIRE(dict.find("b") != nullptr);
+ CATCH_REQUIRE(*dict.find("b") == 2);
}
- SECTION("find returns null pointers when passed keys that are not present") {
+ CATCH_SECTION("find returns null pointers when passed keys that are not present") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
- REQUIRE(dict.find("bar") == nullptr);
- REQUIRE(dict.find("") == nullptr);
+ CATCH_REQUIRE(dict.find("bar") == nullptr);
+ CATCH_REQUIRE(dict.find("") == nullptr);
}
- SECTION("operator[] returns values when passed keys that are present") {
+ CATCH_SECTION("operator[] returns values when passed keys that are present") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
- REQUIRE(dict["a"] == 1);
- REQUIRE(dict["b"] == 2);
+ CATCH_REQUIRE(dict["a"] == 1);
+ CATCH_REQUIRE(dict["b"] == 2);
}
- SECTION("operator[] returns items positionally when passed integers") {
+ CATCH_SECTION("operator[] returns items positionally when passed integers") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
- REQUIRE(dict[0].key == "a");
- REQUIRE(dict[0].value == 1);
- REQUIRE(dict[1].key == "b");
- REQUIRE(dict[1].value == 2);
+ CATCH_REQUIRE(dict[0].key == "a");
+ CATCH_REQUIRE(dict[0].value == 1);
+ CATCH_REQUIRE(dict[1].key == "b");
+ CATCH_REQUIRE(dict[1].value == 2);
}
- SECTION("operator[] throws when passed keys that are not present") {
+ CATCH_SECTION("operator[] throws when passed keys that are not present") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
dict.get("foo"), StartsWith("Key 'foo' is not defined"));
- REQUIRE_THROWS_WITH(dict.get(""), StartsWith("Key '' is not defined"));
+ CATCH_REQUIRE_THROWS_WITH(dict.get(""), StartsWith("Key '' is not defined"));
}
- SECTION("update inserts all items from another OrderedDict") {
+ CATCH_SECTION("update inserts all items from another OrderedDict") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
OrderedDict<int> dict2 = {{"c", 3}};
dict2.update(dict);
- REQUIRE(dict2.size() == 3);
- REQUIRE(dict2.find("a") != nullptr);
- REQUIRE(dict2.find("b") != nullptr);
- REQUIRE(dict2.find("c") != nullptr);
+ CATCH_REQUIRE(dict2.size() == 3);
+ CATCH_REQUIRE(dict2.find("a") != nullptr);
+ CATCH_REQUIRE(dict2.find("b") != nullptr);
+ CATCH_REQUIRE(dict2.find("c") != nullptr);
}
- SECTION("update also checks for duplicates") {
+ CATCH_SECTION("update also checks for duplicates") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
OrderedDict<int> dict2 = {{"a", 1}};
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
dict2.update(dict), StartsWith("Key 'a' already defined"));
}
- SECTION("Can iterate items") {
+ CATCH_SECTION("Can iterate items") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
auto iterator = dict.begin();
- REQUIRE(iterator != dict.end());
- REQUIRE(iterator->key == "a");
- REQUIRE(iterator->value == 1);
+ CATCH_REQUIRE(iterator != dict.end());
+ CATCH_REQUIRE(iterator->key == "a");
+ CATCH_REQUIRE(iterator->value == 1);
++iterator;
- REQUIRE(iterator != dict.end());
- REQUIRE(iterator->key == "b");
- REQUIRE(iterator->value == 2);
+ CATCH_REQUIRE(iterator != dict.end());
+ CATCH_REQUIRE(iterator->key == "b");
+ CATCH_REQUIRE(iterator->value == 2);
++iterator;
- REQUIRE(iterator == dict.end());
+ CATCH_REQUIRE(iterator == dict.end());
}
- SECTION("clear makes the dict empty") {
+ CATCH_SECTION("clear makes the dict empty") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
- REQUIRE(!dict.is_empty());
+ CATCH_REQUIRE(!dict.is_empty());
dict.clear();
- REQUIRE(dict.is_empty());
+ CATCH_REQUIRE(dict.is_empty());
}
- SECTION("can copy construct") {
+ CATCH_SECTION("can copy construct") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
OrderedDict<int> copy = dict;
- REQUIRE(copy.size() == 2);
- REQUIRE(*copy[0] == 1);
- REQUIRE(*copy[1] == 2);
+ CATCH_REQUIRE(copy.size() == 2);
+ CATCH_REQUIRE(*copy[0] == 1);
+ CATCH_REQUIRE(*copy[1] == 2);
}
- SECTION("can copy assign") {
+ CATCH_SECTION("can copy assign") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
OrderedDict<int> copy = {{"c", 1}};
- REQUIRE(copy.find("c") != nullptr);
+ CATCH_REQUIRE(copy.find("c") != nullptr);
copy = dict;
- REQUIRE(copy.size() == 2);
- REQUIRE(*copy[0] == 1);
- REQUIRE(*copy[1] == 2);
- REQUIRE(copy.find("c") == nullptr);
+ CATCH_REQUIRE(copy.size() == 2);
+ CATCH_REQUIRE(*copy[0] == 1);
+ CATCH_REQUIRE(*copy[1] == 2);
+ CATCH_REQUIRE(copy.find("c") == nullptr);
}
- SECTION("can move construct") {
+ CATCH_SECTION("can move construct") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
OrderedDict<int> copy = std::move(dict);
- REQUIRE(copy.size() == 2);
- REQUIRE(*copy[0] == 1);
- REQUIRE(*copy[1] == 2);
+ CATCH_REQUIRE(copy.size() == 2);
+ CATCH_REQUIRE(*copy[0] == 1);
+ CATCH_REQUIRE(*copy[1] == 2);
}
- SECTION("can move assign") {
+ CATCH_SECTION("can move assign") {
OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
OrderedDict<int> copy = {{"c", 1}};
- REQUIRE(copy.find("c") != nullptr);
+ CATCH_REQUIRE(copy.find("c") != nullptr);
copy = std::move(dict);
- REQUIRE(copy.size() == 2);
- REQUIRE(*copy[0] == 1);
- REQUIRE(*copy[1] == 2);
- REQUIRE(copy.find("c") == nullptr);
+ CATCH_REQUIRE(copy.size() == 2);
+ CATCH_REQUIRE(*copy[0] == 1);
+ CATCH_REQUIRE(*copy[1] == 2);
+ CATCH_REQUIRE(copy.find("c") == nullptr);
}
- SECTION("can insert with braces") {
+ CATCH_SECTION("can insert with braces") {
OrderedDict<std::pair<int, int>> dict;
dict.insert("a", {1, 2});
- REQUIRE(!dict.is_empty());
- REQUIRE(dict["a"].first == 1);
- REQUIRE(dict["a"].second == 2);
+ CATCH_REQUIRE(!dict.is_empty());
+ CATCH_REQUIRE(dict["a"].first == 1);
+ CATCH_REQUIRE(dict["a"].second == 2);
}
- SECTION("Error messages include the what") {
+ CATCH_SECTION("Error messages include the what") {
OrderedDict<int> dict("Penguin");
- REQUIRE(dict.subject() == "Penguin");
+ CATCH_REQUIRE(dict.subject() == "Penguin");
dict.insert("a", 1);
- REQUIRE(!dict.is_empty());
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE(!dict.is_empty());
+ CATCH_REQUIRE_THROWS_WITH(
dict.get("b"), StartsWith("Penguin 'b' is not defined"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
dict.insert("a", 1), StartsWith("Penguin 'a' already defined"));
}
}
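Note: catch_utils.hpp itself does not appear in the hunks shown here. A minimal sketch of what such a wrapper header might contain, assuming it simply turns on Catch's CATCH_CONFIG_PREFIX_ALL option (which exposes every macro under a CATCH_-prefixed name such as CATCH_TEST_CASE, CATCH_SECTION, CATCH_REQUIRE, CATCH_CHECK, CATCH_REQUIRE_NOTHROW, and CATCH_REQUIRE_THROWS_WITH):

// catch_utils.hpp -- hypothetical sketch, not necessarily the file added by this PR
#pragma once

// Prefix all Catch macros with CATCH_ so that short names such as
// TEST_CASE and REQUIRE cannot collide with macros defined by other
// headers pulled into the same translation unit.
#define CATCH_CONFIG_PREFIX_ALL

#include "catch.hpp"

With the prefix option enabled, the hunks below amount to a mechanical rename of the macro invocations; the test logic itself is unchanged.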
diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp
index d4049d6..2b9d0ad 100644
--- a/test/cpp/api/module.cpp
+++ b/test/cpp/api/module.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/nn/module.h>
#include <torch/nn/modules/linear.h>
@@ -22,21 +22,21 @@
};
} // namespace test
-TEST_CASE("module/training-mode") {
+CATCH_TEST_CASE("module/training-mode") {
torch::manual_seed(0);
Linear module(3, 4);
- REQUIRE(module->is_training());
- SECTION("Enable eval mode") {
+ CATCH_REQUIRE(module->is_training());
+ CATCH_SECTION("Enable eval mode") {
module->eval();
- REQUIRE(!module->is_training());
+ CATCH_REQUIRE(!module->is_training());
}
- SECTION("Enable train mode") {
+ CATCH_SECTION("Enable train mode") {
module->train();
- REQUIRE(module->is_training());
+ CATCH_REQUIRE(module->is_training());
}
}
-TEST_CASE("module/zero-grad") {
+CATCH_TEST_CASE("module/zero-grad") {
torch::manual_seed(0);
Linear module(3, 4);
auto weight = torch::ones({8, 3}, torch::requires_grad());
@@ -44,18 +44,18 @@
loss.backward();
for (auto& parameter : module->parameters()) {
auto grad = parameter->grad();
- REQUIRE(grad.defined());
- REQUIRE(grad.sum().toCFloat() != 0);
+ CATCH_REQUIRE(grad.defined());
+ CATCH_REQUIRE(grad.sum().toCFloat() != 0);
}
module->zero_grad();
for (auto& parameter : module->parameters()) {
auto grad = parameter->grad();
- REQUIRE(grad.defined());
- REQUIRE(grad.sum().toCFloat() == 0);
+ CATCH_REQUIRE(grad.defined());
+ CATCH_REQUIRE(grad.sum().toCFloat() == 0);
}
}
-TEST_CASE("module/zero-grad-with-undefined") {
+CATCH_TEST_CASE("module/zero-grad-with-undefined") {
struct TestModule : torch::nn::Module {
TestModule() {
x = register_parameter("x", torch::ones(5, at::requires_grad()));
@@ -68,120 +68,120 @@
auto z = module.x * 2;
z.sum().backward();
- REQUIRE(module.x.grad().defined());
- REQUIRE(!module.y.grad().defined());
+ CATCH_REQUIRE(module.x.grad().defined());
+ CATCH_REQUIRE(!module.y.grad().defined());
module.zero_grad();
- REQUIRE(module.x.grad().defined());
- REQUIRE(!module.y.grad().defined());
+ CATCH_REQUIRE(module.x.grad().defined());
+ CATCH_REQUIRE(!module.y.grad().defined());
- REQUIRE(module.x.grad().sum().toCFloat() == 0);
+ CATCH_REQUIRE(module.x.grad().sum().toCFloat() == 0);
}
-TEST_CASE("module/name") {
+CATCH_TEST_CASE("module/name") {
// CHECK instead of REQUIRE because demangling may fail.
AGIUnit agi;
// Call it twice just to make sure there are no bugs in the lazy
// initialization semantics.
- CHECK(agi.name() == "AGIUnit");
- CHECK(agi.name() == "AGIUnit");
- SECTION("correctly demangled") {
- CHECK(test::AGIUnit().name() == "test::AGIUnit");
- CHECK(test::AGIUnit2().name() == "Foo");
+ CATCH_CHECK(agi.name() == "AGIUnit");
+ CATCH_CHECK(agi.name() == "AGIUnit");
+ CATCH_SECTION("correctly demangled") {
+ CATCH_CHECK(test::AGIUnit().name() == "test::AGIUnit");
+ CATCH_CHECK(test::AGIUnit2().name() == "Foo");
}
}
-TEST_CASE("module/as") {
+CATCH_TEST_CASE("module/as") {
Linear module(3, 4);
- REQUIRE(module->as<Linear>() == module.get());
- REQUIRE(module->as<LinearImpl>() == module.get());
- REQUIRE(module->as<Module>() == module.get());
- REQUIRE(module->as<AGIUnit>() == nullptr);
+ CATCH_REQUIRE(module->as<Linear>() == module.get());
+ CATCH_REQUIRE(module->as<LinearImpl>() == module.get());
+ CATCH_REQUIRE(module->as<Module>() == module.get());
+ CATCH_REQUIRE(module->as<AGIUnit>() == nullptr);
std::shared_ptr<Module> raw = module.ptr();
- REQUIRE(raw->as<Linear>() == module.get());
- REQUIRE(raw->as<LinearImpl>() == module.get());
- REQUIRE(raw->as<Module>() == module.get());
- REQUIRE(raw->as<AGIUnit>() == nullptr);
+ CATCH_REQUIRE(raw->as<Linear>() == module.get());
+ CATCH_REQUIRE(raw->as<LinearImpl>() == module.get());
+ CATCH_REQUIRE(raw->as<Module>() == module.get());
+ CATCH_REQUIRE(raw->as<AGIUnit>() == nullptr);
Module& raw_ref = *raw.get();
- REQUIRE(raw_ref.as<Linear>() == module.get());
- REQUIRE(raw_ref.as<LinearImpl>() == module.get());
- REQUIRE(raw_ref.as<Module>() == module.get());
- REQUIRE(raw_ref.as<AGIUnit>() == nullptr);
+ CATCH_REQUIRE(raw_ref.as<Linear>() == module.get());
+ CATCH_REQUIRE(raw_ref.as<LinearImpl>() == module.get());
+ CATCH_REQUIRE(raw_ref.as<Module>() == module.get());
+ CATCH_REQUIRE(raw_ref.as<AGIUnit>() == nullptr);
if (auto* linear = raw_ref.as<Linear>()) {
- REQUIRE(linear->weight.ndimension() == 2);
+ CATCH_REQUIRE(linear->weight.ndimension() == 2);
}
AGIUnit unit;
- REQUIRE(unit.as<Linear>() == nullptr);
- REQUIRE(unit.as<LinearImpl>() == nullptr);
- REQUIRE(unit.as<AGIUnit>() == &unit);
+ CATCH_REQUIRE(unit.as<Linear>() == nullptr);
+ CATCH_REQUIRE(unit.as<LinearImpl>() == nullptr);
+ CATCH_REQUIRE(unit.as<AGIUnit>() == &unit);
}
-TEST_CASE("module/conversions", "[multi-cuda]") {
+CATCH_TEST_CASE("module/conversions", "[multi-cuda]") {
torch::manual_seed(0);
Linear module(128, 64);
- SECTION("starts as float on CPU") {
+ CATCH_SECTION("starts as float on CPU") {
for (auto& parameter : module->parameters()) {
- REQUIRE(parameter->device() == torch::Device(torch::kCPU));
- REQUIRE(parameter->dtype() == torch::kFloat32);
+ CATCH_REQUIRE(parameter->device() == torch::Device(torch::kCPU));
+ CATCH_REQUIRE(parameter->dtype() == torch::kFloat32);
}
}
- SECTION("to(CUDA)") {
+ CATCH_SECTION("to(CUDA)") {
module->to({torch::kCUDA, 0});
for (auto& parameter : module->parameters()) {
- REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
- REQUIRE(parameter->device().index() == 0);
+ CATCH_REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
+ CATCH_REQUIRE(parameter->device().index() == 0);
}
module->to({at::kCUDA, 1});
for (auto& parameter : module->parameters()) {
- REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
- REQUIRE(parameter->device().index() == 1);
+ CATCH_REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
+ CATCH_REQUIRE(parameter->device().index() == 1);
}
}
- SECTION("to(CPU)") {
+ CATCH_SECTION("to(CPU)") {
module->to(torch::Device(torch::kCPU));
for (auto& parameter : module->parameters()) {
- REQUIRE(parameter->device().type() == torch::Device::Type::CPU);
+ CATCH_REQUIRE(parameter->device().type() == torch::Device::Type::CPU);
}
}
- SECTION("to(Int32)") {
+ CATCH_SECTION("to(Int32)") {
module->to(torch::kInt32);
for (auto& parameter : module->parameters()) {
- REQUIRE(parameter->dtype() == torch::kInt32);
+ CATCH_REQUIRE(parameter->dtype() == torch::kInt32);
}
}
- SECTION("to(Float64)") {
+ CATCH_SECTION("to(Float64)") {
module->to(torch::kFloat64);
for (auto& parameter : module->parameters()) {
- REQUIRE(parameter->dtype() == torch::kFloat64);
+ CATCH_REQUIRE(parameter->dtype() == torch::kFloat64);
}
}
- SECTION("to(CUDA, Byte)") {
+ CATCH_SECTION("to(CUDA, Byte)") {
module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
for (auto& parameter : module->parameters()) {
- REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
- REQUIRE(parameter->device().index() == 1);
+ CATCH_REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
+ CATCH_REQUIRE(parameter->device().index() == 1);
}
for (auto& parameter : module->parameters()) {
- REQUIRE(parameter->dtype() == torch::kUInt8);
+ CATCH_REQUIRE(parameter->dtype() == torch::kUInt8);
}
}
}
-TEST_CASE("module/clone") {
+CATCH_TEST_CASE("module/clone") {
torch::manual_seed(0);
- SECTION(
+ CATCH_SECTION(
"a module that does not override clone() throws when clone() is called") {
struct UnCloneable : Module {};
UnCloneable module;
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
module.clone(), StartsWith("clone() has not been implemented"));
}
- SECTION(
+ CATCH_SECTION(
"a module that overrides clone() does not throw when clone() is called ") {
struct Cloneable : Module {
std::shared_ptr<Module> clone(
@@ -190,10 +190,10 @@
}
};
Cloneable module;
- REQUIRE_NOTHROW(module.clone());
+ CATCH_REQUIRE_NOTHROW(module.clone());
}
- SECTION("Cloning creates distinct parameters") {
+ CATCH_SECTION("Cloning creates distinct parameters") {
struct TestModule : public Cloneable<TestModule> {
TestModule() {
reset();
@@ -216,32 +216,32 @@
auto module2 = module->clone();
auto params1 = module->parameters();
auto params2 = module2->parameters();
- REQUIRE(params1.size() == 6);
- REQUIRE(params2.size() == 6);
+ CATCH_REQUIRE(params1.size() == 6);
+ CATCH_REQUIRE(params2.size() == 6);
for (auto& param : params1) {
- REQUIRE(!pointer_equal(param.value, params2[param.key]));
- REQUIRE(param->allclose(params2[param.key]));
+ CATCH_REQUIRE(!pointer_equal(param.value, params2[param.key]));
+ CATCH_REQUIRE(param->allclose(params2[param.key]));
param->add_(2);
}
for (auto& param : params1) {
- REQUIRE(!param->allclose(params2[param.key]));
+ CATCH_REQUIRE(!param->allclose(params2[param.key]));
}
auto buffers1 = module->buffers();
auto buffers2 = module2->buffers();
- REQUIRE(buffers1.size() == 1);
- REQUIRE(buffers2.size() == 1);
+ CATCH_REQUIRE(buffers1.size() == 1);
+ CATCH_REQUIRE(buffers2.size() == 1);
for (auto& buffer : buffers1) {
- REQUIRE(!pointer_equal(buffer.value, buffers2[buffer.key]));
- REQUIRE(buffer->allclose(buffers2[buffer.key]));
+ CATCH_REQUIRE(!pointer_equal(buffer.value, buffers2[buffer.key]));
+ CATCH_REQUIRE(buffer->allclose(buffers2[buffer.key]));
buffer->add_(2);
}
for (auto& buffer : buffers1) {
- REQUIRE(!buffer->allclose(buffers2[buffer.key]));
+ CATCH_REQUIRE(!buffer->allclose(buffers2[buffer.key]));
}
}
- SECTION("Cloning preserves external references") {
+ CATCH_SECTION("Cloning preserves external references") {
struct TestModule : public Cloneable<TestModule> {
TestModule() {
reset();
@@ -256,19 +256,19 @@
torch::NoGradGuard no_grad;
module->weight += 1;
}
- REQUIRE(pointer_equal(module->weight, module->parameters()["weight"]));
- REQUIRE(module->weight.allclose(module->parameters()["weight"]));
+ CATCH_REQUIRE(pointer_equal(module->weight, module->parameters()["weight"]));
+ CATCH_REQUIRE(module->weight.allclose(module->parameters()["weight"]));
auto module2 = std::dynamic_pointer_cast<TestModule>(
std::shared_ptr<Module>(module->clone()));
- REQUIRE(!pointer_equal(module2->weight, module->weight));
- REQUIRE(pointer_equal(module2->weight, module2->parameters()["weight"]));
- REQUIRE(module2->weight.allclose(module2->parameters()["weight"]));
- REQUIRE(module2->weight.allclose(module->weight));
- REQUIRE(!pointer_equal(module2->weight, module->parameters()["weight"]));
+ CATCH_REQUIRE(!pointer_equal(module2->weight, module->weight));
+ CATCH_REQUIRE(pointer_equal(module2->weight, module2->parameters()["weight"]));
+ CATCH_REQUIRE(module2->weight.allclose(module2->parameters()["weight"]));
+ CATCH_REQUIRE(module2->weight.allclose(module->weight));
+ CATCH_REQUIRE(!pointer_equal(module2->weight, module->parameters()["weight"]));
}
- SECTION("Cloning copies the values of variables of submodules") {
+ CATCH_SECTION("Cloning copies the values of variables of submodules") {
struct TestModule : public Cloneable<TestModule> {
TestModule() {
reset();
@@ -299,16 +299,16 @@
auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());
- REQUIRE(!pointer_equal(b->module->weight, a->module->weight));
- REQUIRE(
+ CATCH_REQUIRE(!pointer_equal(b->module->weight, a->module->weight));
+ CATCH_REQUIRE(
pointer_equal(b->module->weight, b->module->parameters()["weight"]));
- REQUIRE(b->module->parameters()["weight"].allclose(a->module->weight));
- REQUIRE(b->module->weight.allclose(a->module->weight));
- REQUIRE(b->module->value == a->module->value);
+ CATCH_REQUIRE(b->module->parameters()["weight"].allclose(a->module->weight));
+ CATCH_REQUIRE(b->module->weight.allclose(a->module->weight));
+ CATCH_REQUIRE(b->module->value == a->module->value);
}
}
-TEST_CASE("module/clone-to-device", "[cuda]") {
+CATCH_TEST_CASE("module/clone-to-device", "[cuda]") {
struct TestModule : public Cloneable<TestModule> {
TestModule() {
reset();
@@ -324,7 +324,7 @@
torch::Tensor buffer;
};
- SECTION("Cloning preserves the device of parameters/buffers") {
+ CATCH_SECTION("Cloning preserves the device of parameters/buffers") {
TestModule m;
torch::Device device(torch::kCUDA, 0);
@@ -332,33 +332,33 @@
auto clone = m.clone();
for (const auto& parameter : clone->parameters()) {
- REQUIRE(parameter->device().type() == device.type());
- REQUIRE(parameter->device().index() == device.index());
+ CATCH_REQUIRE(parameter->device().type() == device.type());
+ CATCH_REQUIRE(parameter->device().index() == device.index());
}
for (const auto& buffer : clone->buffers()) {
- REQUIRE(buffer->device().type() == device.type());
- REQUIRE(buffer->device().index() == device.index());
+ CATCH_REQUIRE(buffer->device().type() == device.type());
+ CATCH_REQUIRE(buffer->device().index() == device.index());
}
}
- SECTION(
+ CATCH_SECTION(
"Cloning to a particular device places all parameters/buffers there") {
TestModule m;
torch::Device device(torch::kCUDA, 1);
// everything is on CPU here
auto clone = m.clone(device);
for (const auto& parameter : clone->parameters()) {
- REQUIRE(parameter->device().type() == device.type());
- REQUIRE(parameter->device().index() == device.index());
+ CATCH_REQUIRE(parameter->device().type() == device.type());
+ CATCH_REQUIRE(parameter->device().index() == device.index());
}
for (const auto& buffer : clone->buffers()) {
- REQUIRE(buffer->device().type() == device.type());
- REQUIRE(buffer->device().index() == device.index());
+ CATCH_REQUIRE(buffer->device().type() == device.type());
+ CATCH_REQUIRE(buffer->device().index() == device.index());
}
}
}
-TEST_CASE("module/parameters") {
+CATCH_TEST_CASE("module/parameters") {
torch::manual_seed(0);
struct TestModule : Module {
TestModule() {
@@ -372,19 +372,19 @@
TestModule module;
- SECTION("has correct number of parameters") {
- REQUIRE(module.parameters().size() == 3);
+ CATCH_SECTION("has correct number of parameters") {
+ CATCH_REQUIRE(module.parameters().size() == 3);
}
- SECTION("contains parameters with the correct name") {
+ CATCH_SECTION("contains parameters with the correct name") {
auto parameters = module.parameters();
- REQUIRE(parameters.contains("a"));
- REQUIRE(parameters.contains("b"));
- REQUIRE(parameters.contains("c"));
+ CATCH_REQUIRE(parameters.contains("a"));
+ CATCH_REQUIRE(parameters.contains("b"));
+ CATCH_REQUIRE(parameters.contains("c"));
}
}
-TEST_CASE("module/buffers") {
+CATCH_TEST_CASE("module/buffers") {
torch::manual_seed(0);
struct TestModule : Module {
TestModule() {
@@ -398,19 +398,19 @@
TestModule module;
- SECTION("has correct number of buffers") {
- REQUIRE(module.buffers().size() == 3);
+ CATCH_SECTION("has correct number of buffers") {
+ CATCH_REQUIRE(module.buffers().size() == 3);
}
- SECTION("contains buffers with the correct name") {
+ CATCH_SECTION("contains buffers with the correct name") {
auto buffers = module.buffers();
- REQUIRE(buffers.contains("a"));
- REQUIRE(buffers.contains("b"));
- REQUIRE(buffers.contains("c"));
+ CATCH_REQUIRE(buffers.contains("a"));
+ CATCH_REQUIRE(buffers.contains("b"));
+ CATCH_REQUIRE(buffers.contains("c"));
}
}
-TEST_CASE("module/default-constructor") {
+CATCH_TEST_CASE("module/default-constructor") {
struct AImpl : torch::nn::Module {
AImpl() : x_(123) {}
AImpl(int x) : x_(x) {}
@@ -420,20 +420,20 @@
{
A a;
- REQUIRE(a);
- REQUIRE(!a.is_empty());
- REQUIRE(a->x_ == 123);
+ CATCH_REQUIRE(a);
+ CATCH_REQUIRE(!a.is_empty());
+ CATCH_REQUIRE(a->x_ == 123);
}
{
A a(5);
- REQUIRE(a);
- REQUIRE(!a.is_empty());
- REQUIRE(a->x_ == 5);
+ CATCH_REQUIRE(a);
+ CATCH_REQUIRE(!a.is_empty());
+ CATCH_REQUIRE(a->x_ == 5);
}
{
A a = nullptr;
- REQUIRE(!a);
- REQUIRE(a.is_empty());
- REQUIRE_THROWS_WITH(a->x_, StartsWith("Accessing empty ModuleHolder"));
+ CATCH_REQUIRE(!a);
+ CATCH_REQUIRE(a.is_empty());
+ CATCH_REQUIRE_THROWS_WITH(a->x_, StartsWith("Accessing empty ModuleHolder"));
}
}
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 928a39f..7d4f9ab 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/nn/module.h>
#include <torch/nn/modules/batchnorm.h>
@@ -39,92 +39,92 @@
std::shared_ptr<TestModel> t;
};
-TEST_CASE("modules") {
+CATCH_TEST_CASE("modules") {
torch::manual_seed(0);
- SECTION("conv") {
- SECTION("1d") {
+ CATCH_SECTION("conv") {
+ CATCH_SECTION("1d") {
Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5}, torch::requires_grad());
auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
- REQUIRE(y.ndimension() == 3);
- REQUIRE(s.ndimension() == 0);
+ CATCH_REQUIRE(y.ndimension() == 3);
+ CATCH_REQUIRE(s.ndimension() == 0);
for (auto i = 0; i < 3; i++) {
- REQUIRE(y.size(i) == 2);
+ CATCH_REQUIRE(y.size(i) == 2);
}
- REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3);
+ CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3);
}
- SECTION("2d") {
- SECTION("even") {
+ CATCH_SECTION("2d") {
+ CATCH_SECTION("even") {
Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5, 5}, torch::requires_grad());
auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
- REQUIRE(y.ndimension() == 4);
- REQUIRE(s.ndimension() == 0);
+ CATCH_REQUIRE(y.ndimension() == 4);
+ CATCH_REQUIRE(s.ndimension() == 0);
for (auto i = 0; i < 4; i++) {
- REQUIRE(y.size(i) == 2);
+ CATCH_REQUIRE(y.size(i) == 2);
}
- REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3);
+ CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3);
}
- SECTION("uneven") {
+ CATCH_SECTION("uneven") {
Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({2, 2}));
auto x = torch::randn({2, 3, 5, 4}, torch::requires_grad());
auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
- REQUIRE(y.ndimension() == 4);
- REQUIRE(s.ndimension() == 0);
+ CATCH_REQUIRE(y.ndimension() == 4);
+ CATCH_REQUIRE(s.ndimension() == 0);
for (auto i = 0; i < 4; i++) {
- REQUIRE(y.size(i) == 2);
+ CATCH_REQUIRE(y.size(i) == 2);
}
- REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 2);
+ CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 2);
}
}
- SECTION("3d") {
+ CATCH_SECTION("3d") {
Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5, 5, 5}, torch::requires_grad());
auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
- REQUIRE(y.ndimension() == 5);
- REQUIRE(s.ndimension() == 0);
+ CATCH_REQUIRE(y.ndimension() == 5);
+ CATCH_REQUIRE(s.ndimension() == 0);
for (auto i = 0; i < 5; i++) {
- REQUIRE(y.size(i) == 2);
+ CATCH_REQUIRE(y.size(i) == 2);
}
- REQUIRE(
+ CATCH_REQUIRE(
model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3 * 3);
}
}
- SECTION("linear") {
- SECTION("basic1") {
+ CATCH_SECTION("linear") {
+ CATCH_SECTION("basic1") {
Linear model(5, 2);
auto x = torch::randn({10, 5}, torch::requires_grad());
auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
- REQUIRE(y.ndimension() == 2);
- REQUIRE(s.ndimension() == 0);
- REQUIRE(y.size(0) == 10);
- REQUIRE(y.size(1) == 2);
+ CATCH_REQUIRE(y.ndimension() == 2);
+ CATCH_REQUIRE(s.ndimension() == 0);
+ CATCH_REQUIRE(y.size(0) == 10);
+ CATCH_REQUIRE(y.size(1) == 2);
- REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
+ CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
}
}
- SECTION("simple") {
+ CATCH_SECTION("simple") {
auto model = std::make_shared<SimpleContainer>();
auto l1 = model->add(Linear(10, 3), "l1");
auto l2 = model->add(Linear(3, 5), "l2");
@@ -136,20 +136,20 @@
x = l3->forward(x).clamp_min(0);
x.backward();
- REQUIRE(x.ndimension() == 2);
- REQUIRE(x.size(0) == 1000);
- REQUIRE(x.size(1) == 100);
- REQUIRE(x.min().toCFloat() == 0);
+ CATCH_REQUIRE(x.ndimension() == 2);
+ CATCH_REQUIRE(x.size(0) == 1000);
+ CATCH_REQUIRE(x.size(1) == 100);
+ CATCH_REQUIRE(x.min().toCFloat() == 0);
}
- SECTION("embedding") {
- SECTION("basic") {
+ CATCH_SECTION("embedding") {
+ CATCH_SECTION("basic") {
const int64_t dict_size = 10;
Embedding model(dict_size, 2);
- REQUIRE(model->parameters().contains("weight"));
- REQUIRE(model->weight.ndimension() == 2);
- REQUIRE(model->weight.size(0) == dict_size);
- REQUIRE(model->weight.size(1) == 2);
+ CATCH_REQUIRE(model->parameters().contains("weight"));
+ CATCH_REQUIRE(model->weight.ndimension() == 2);
+ CATCH_REQUIRE(model->weight.size(0) == dict_size);
+ CATCH_REQUIRE(model->weight.size(1) == 2);
// Cannot get gradients to change indices (input) - only for embedding
// params
@@ -158,65 +158,65 @@
torch::Tensor s = y.sum();
s.backward();
- REQUIRE(y.ndimension() == 2);
- REQUIRE(s.ndimension() == 0);
- REQUIRE(y.size(0) == 10);
- REQUIRE(y.size(1) == 2);
+ CATCH_REQUIRE(y.ndimension() == 2);
+ CATCH_REQUIRE(s.ndimension() == 0);
+ CATCH_REQUIRE(y.size(0) == 10);
+ CATCH_REQUIRE(y.size(1) == 2);
- REQUIRE(model->parameters()["weight"].grad().numel() == 2 * dict_size);
+ CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * dict_size);
}
- SECTION("list") {
+ CATCH_SECTION("list") {
Embedding model(6, 4);
auto x = torch::full({2, 3}, 5, torch::kInt64);
auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
- REQUIRE(y.ndimension() == 3);
- REQUIRE(y.size(0) == 2);
- REQUIRE(y.size(1) == 3);
- REQUIRE(y.size(2) == 4);
+ CATCH_REQUIRE(y.ndimension() == 3);
+ CATCH_REQUIRE(y.size(0) == 2);
+ CATCH_REQUIRE(y.size(1) == 3);
+ CATCH_REQUIRE(y.size(2) == 4);
}
}
- SECTION("dropout") {
+ CATCH_SECTION("dropout") {
Dropout dropout(0.5);
torch::Tensor x = torch::ones(100, torch::requires_grad());
torch::Tensor y = dropout->forward(x);
y.backward();
- REQUIRE(y.ndimension() == 1);
- REQUIRE(y.size(0) == 100);
- REQUIRE(y.sum().toCFloat() < 130); // Probably
- REQUIRE(y.sum().toCFloat() > 70); // Probably
+ CATCH_REQUIRE(y.ndimension() == 1);
+ CATCH_REQUIRE(y.size(0) == 100);
+ CATCH_REQUIRE(y.sum().toCFloat() < 130); // Probably
+ CATCH_REQUIRE(y.sum().toCFloat() > 70); // Probably
dropout->eval();
y = dropout->forward(x);
- REQUIRE(y.sum().toCFloat() == 100);
+ CATCH_REQUIRE(y.sum().toCFloat() == 100);
}
- SECTION("param") {
+ CATCH_SECTION("param") {
auto model = std::make_shared<NestedModel>();
auto parameters = model->parameters();
- REQUIRE(parameters["param"].size(0) == 3);
- REQUIRE(parameters["param"].size(1) == 2);
- REQUIRE(parameters["param"].size(2) == 21);
- REQUIRE(parameters["l1.bias"].size(0) == 20);
- REQUIRE(parameters["l1.weight"].size(0) == 20);
- REQUIRE(parameters["l1.weight"].size(1) == 5);
- REQUIRE(parameters["test.l1.bias"].size(0) == 3);
- REQUIRE(parameters["test.l1.weight"].size(0) == 3);
- REQUIRE(parameters["test.l1.weight"].size(1) == 10);
- REQUIRE(parameters["test.l2.bias"].size(0) == 5);
- REQUIRE(parameters["test.l2.weight"].size(0) == 5);
- REQUIRE(parameters["test.l2.weight"].size(1) == 3);
- REQUIRE(parameters["test.l3.bias"].size(0) == 100);
- REQUIRE(parameters["test.l3.weight"].size(0) == 100);
- REQUIRE(parameters["test.l3.weight"].size(1) == 5);
+ CATCH_REQUIRE(parameters["param"].size(0) == 3);
+ CATCH_REQUIRE(parameters["param"].size(1) == 2);
+ CATCH_REQUIRE(parameters["param"].size(2) == 21);
+ CATCH_REQUIRE(parameters["l1.bias"].size(0) == 20);
+ CATCH_REQUIRE(parameters["l1.weight"].size(0) == 20);
+ CATCH_REQUIRE(parameters["l1.weight"].size(1) == 5);
+ CATCH_REQUIRE(parameters["test.l1.bias"].size(0) == 3);
+ CATCH_REQUIRE(parameters["test.l1.weight"].size(0) == 3);
+ CATCH_REQUIRE(parameters["test.l1.weight"].size(1) == 10);
+ CATCH_REQUIRE(parameters["test.l2.bias"].size(0) == 5);
+ CATCH_REQUIRE(parameters["test.l2.weight"].size(0) == 5);
+ CATCH_REQUIRE(parameters["test.l2.weight"].size(1) == 3);
+ CATCH_REQUIRE(parameters["test.l3.bias"].size(0) == 100);
+ CATCH_REQUIRE(parameters["test.l3.weight"].size(0) == 100);
+ CATCH_REQUIRE(parameters["test.l3.weight"].size(1) == 5);
}
- SECTION("functional") {
+ CATCH_SECTION("functional") {
{
bool was_called = false;
auto functional = Functional([&was_called](torch::Tensor input) {
@@ -224,63 +224,63 @@
return input;
});
auto output = functional->forward(torch::ones(5, torch::requires_grad()));
- REQUIRE(was_called);
- REQUIRE(output.equal(torch::ones(5, torch::requires_grad())));
+ CATCH_REQUIRE(was_called);
+ CATCH_REQUIRE(output.equal(torch::ones(5, torch::requires_grad())));
was_called = false;
// Use the call operator overload here.
output = functional(torch::ones(5, torch::requires_grad()));
- REQUIRE(was_called);
- REQUIRE(output.equal(torch::ones(5, torch::requires_grad())));
+ CATCH_REQUIRE(was_called);
+ CATCH_REQUIRE(output.equal(torch::ones(5, torch::requires_grad())));
}
{
auto functional = Functional(torch::relu);
- REQUIRE(functional(torch::ones({})).toCFloat() == 1);
- REQUIRE(functional(torch::ones({})).toCFloat() == 1);
- REQUIRE(functional(torch::ones({}) * -1).toCFloat() == 0);
+ CATCH_REQUIRE(functional(torch::ones({})).toCFloat() == 1);
+ CATCH_REQUIRE(functional(torch::ones({})).toCFloat() == 1);
+ CATCH_REQUIRE(functional(torch::ones({}) * -1).toCFloat() == 0);
}
{
auto functional =
Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
- REQUIRE(functional(torch::ones({})).toCFloat() == 0);
+ CATCH_REQUIRE(functional(torch::ones({})).toCFloat() == 0);
}
}
- SECTION("batchnorm") {
+ CATCH_SECTION("batchnorm") {
{
BatchNorm bn(5);
// Is stateful by default.
- REQUIRE(bn->options.stateful());
+ CATCH_REQUIRE(bn->options.stateful());
- REQUIRE(bn->running_mean.defined());
- REQUIRE(bn->running_mean.dim() == 1);
- REQUIRE(bn->running_mean.size(0) == 5);
+ CATCH_REQUIRE(bn->running_mean.defined());
+ CATCH_REQUIRE(bn->running_mean.dim() == 1);
+ CATCH_REQUIRE(bn->running_mean.size(0) == 5);
- REQUIRE(bn->running_variance.defined());
- REQUIRE(bn->running_variance.dim() == 1);
- REQUIRE(bn->running_variance.size(0) == 5);
+ CATCH_REQUIRE(bn->running_variance.defined());
+ CATCH_REQUIRE(bn->running_variance.dim() == 1);
+ CATCH_REQUIRE(bn->running_variance.size(0) == 5);
// Is affine by default.
- REQUIRE(bn->options.affine());
+ CATCH_REQUIRE(bn->options.affine());
- REQUIRE(bn->weight.defined());
- REQUIRE(bn->weight.dim() == 1);
- REQUIRE(bn->weight.size(0) == 5);
+ CATCH_REQUIRE(bn->weight.defined());
+ CATCH_REQUIRE(bn->weight.dim() == 1);
+ CATCH_REQUIRE(bn->weight.size(0) == 5);
- REQUIRE(bn->bias.defined());
- REQUIRE(bn->bias.dim() == 1);
- REQUIRE(bn->bias.size(0) == 5);
+ CATCH_REQUIRE(bn->bias.defined());
+ CATCH_REQUIRE(bn->bias.dim() == 1);
+ CATCH_REQUIRE(bn->bias.size(0) == 5);
}
{
BatchNorm bn(BatchNormOptions(5).stateful(false).affine(false));
- REQUIRE(!bn->running_mean.defined());
- REQUIRE(!bn->running_variance.defined());
- REQUIRE(!bn->weight.defined());
- REQUIRE(!bn->bias.defined());
+ CATCH_REQUIRE(!bn->running_mean.defined());
+ CATCH_REQUIRE(!bn->running_variance.defined());
+ CATCH_REQUIRE(!bn->weight.defined());
+ CATCH_REQUIRE(!bn->bias.defined());
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
bn->forward(torch::ones({2, 5})),
StartsWith("Calling BatchNorm::forward is only permitted "
"when the 'stateful' option is true (was false). "
@@ -298,14 +298,14 @@
auto output = bn->pure_forward(input, mean, variance);
auto expected =
(input - mean) / torch::sqrt(variance + bn->options.eps());
- REQUIRE(output.allclose(expected));
+ CATCH_REQUIRE(output.allclose(expected));
}
}
}
-TEST_CASE("modules_cuda", "[cuda]") {
+CATCH_TEST_CASE("modules_cuda", "[cuda]") {
torch::manual_seed(0);
- SECTION("1") {
+ CATCH_SECTION("1") {
Linear model(5, 2);
model->to(torch::kCUDA);
auto x =
@@ -314,15 +314,15 @@
torch::Tensor s = y.sum();
s.backward();
- REQUIRE(y.ndimension() == 2);
- REQUIRE(s.ndimension() == 0);
- REQUIRE(y.size(0) == 10);
- REQUIRE(y.size(1) == 2);
+ CATCH_REQUIRE(y.ndimension() == 2);
+ CATCH_REQUIRE(s.ndimension() == 0);
+ CATCH_REQUIRE(y.size(0) == 10);
+ CATCH_REQUIRE(y.size(1) == 2);
- REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
+ CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
}
- SECTION("2") {
+ CATCH_SECTION("2") {
Linear model(5, 2);
model->to(torch::kCUDA);
model->to(torch::kCPU);
@@ -331,11 +331,11 @@
torch::Tensor s = y.sum();
s.backward();
- REQUIRE(y.ndimension() == 2);
- REQUIRE(s.ndimension() == 0);
- REQUIRE(y.size(0) == 10);
- REQUIRE(y.size(1) == 2);
+ CATCH_REQUIRE(y.ndimension() == 2);
+ CATCH_REQUIRE(s.ndimension() == 0);
+ CATCH_REQUIRE(y.size(0) == 10);
+ CATCH_REQUIRE(y.size(1) == 2);
- REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
+ CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
}
}
diff --git a/test/cpp/api/optim.cpp b/test/cpp/api/optim.cpp
index ab27818..4cb398d 100644
--- a/test/cpp/api/optim.cpp
+++ b/test/cpp/api/optim.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/nn/module.h>
#include <torch/nn/modules/functional.h>
@@ -118,24 +118,24 @@
optimizer.step();
if (i % kSampleEvery == 0) {
- REQUIRE(
+ CATCH_REQUIRE(
expected_parameters.at(i / kSampleEvery).size() == parameters.size());
for (size_t p = 0; p < parameters.size(); ++p) {
- REQUIRE(parameters.at(p)->defined());
+ CATCH_REQUIRE(parameters.at(p)->defined());
auto computed = parameters.at(p)->flatten();
auto expected = expected_parameters.at(i / kSampleEvery).at(p);
if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
std::cout << "Iteration " << i << ": " << computed
<< " != " << expected << " (parameter " << p << ")"
<< std::endl;
- REQUIRE(false);
+ CATCH_REQUIRE(false);
}
}
}
}
}
-TEST_CASE("Optim/BasicInterface") {
+CATCH_TEST_CASE("Optim/BasicInterface") {
struct MyOptimizer : Optimizer {
using Optimizer::Optimizer;
void step() override {}
@@ -144,139 +144,139 @@
torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
{
MyOptimizer optimizer(parameters);
- REQUIRE(optimizer.size() == parameters.size());
+ CATCH_REQUIRE(optimizer.size() == parameters.size());
}
{
MyOptimizer optimizer;
- REQUIRE(optimizer.size() == 0);
+ CATCH_REQUIRE(optimizer.size() == 0);
optimizer.add_parameters(parameters);
- REQUIRE(optimizer.size() == parameters.size());
+ CATCH_REQUIRE(optimizer.size() == parameters.size());
for (size_t p = 0; p < parameters.size(); ++p) {
- REQUIRE(optimizer.parameters()[p].allclose(parameters[p]));
+ CATCH_REQUIRE(optimizer.parameters()[p].allclose(parameters[p]));
}
}
{
Linear linear(3, 4);
MyOptimizer optimizer(linear->parameters());
- REQUIRE(optimizer.size() == linear->parameters().size());
+ CATCH_REQUIRE(optimizer.size() == linear->parameters().size());
}
}
-TEST_CASE("Optim/XORConvergence/SGD") {
- REQUIRE(test_optimizer_xor<SGD>(
+CATCH_TEST_CASE("Optim/XORConvergence/SGD") {
+ CATCH_REQUIRE(test_optimizer_xor<SGD>(
SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
}
-TEST_CASE("Optim/XORConvergence/Adagrad") {
- REQUIRE(test_optimizer_xor<Adagrad>(
+CATCH_TEST_CASE("Optim/XORConvergence/Adagrad") {
+ CATCH_REQUIRE(test_optimizer_xor<Adagrad>(
AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
}
-TEST_CASE("Optim/XORConvergence/RMSprop") {
- REQUIRE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
+CATCH_TEST_CASE("Optim/XORConvergence/RMSprop") {
+ CATCH_REQUIRE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
}
-TEST_CASE("Optim/XORConvergence/RMSpropWithMomentum") {
- REQUIRE(test_optimizer_xor<RMSprop>(
+CATCH_TEST_CASE("Optim/XORConvergence/RMSpropWithMomentum") {
+ CATCH_REQUIRE(test_optimizer_xor<RMSprop>(
RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
}
-TEST_CASE("Optim/XORConvergence/Adam") {
- REQUIRE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
+CATCH_TEST_CASE("Optim/XORConvergence/Adam") {
+ CATCH_REQUIRE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
}
-TEST_CASE("Optim/XORConvergence/AdamWithAmsgrad") {
- REQUIRE(test_optimizer_xor<Adam>(
+CATCH_TEST_CASE("Optim/XORConvergence/AdamWithAmsgrad") {
+ CATCH_REQUIRE(test_optimizer_xor<Adam>(
AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
}
-TEST_CASE("Optim/ProducesPyTorchValues/Adam") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/Adam") {
check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam);
}
-TEST_CASE("Optim/ProducesPyTorchValues/AdamWithWeightDecay") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdamWithWeightDecay") {
check_exact_values<Adam>(
AdamOptions(1.0).weight_decay(1e-2),
expected_parameters::Adam_with_weight_decay);
}
-TEST_CASE("Optim/ProducesPyTorchValues/AdamWithWeightDecayAndAMSGrad") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdamWithWeightDecayAndAMSGrad") {
check_exact_values<Adam>(
AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
expected_parameters::Adam_with_weight_decay_and_amsgrad);
}
-TEST_CASE("Optim/ProducesPyTorchValues/Adagrad") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/Adagrad") {
check_exact_values<Adagrad>(
AdagradOptions(1.0), expected_parameters::Adagrad);
}
-TEST_CASE("Optim/ProducesPyTorchValues/AdagradWithWeightDecay") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdagradWithWeightDecay") {
check_exact_values<Adagrad>(
AdagradOptions(1.0).weight_decay(1e-2),
expected_parameters::Adagrad_with_weight_decay);
}
-TEST_CASE("Optim/ProducesPyTorchValues/AdagradWithWeightDecayAndLRDecay") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdagradWithWeightDecayAndLRDecay") {
check_exact_values<Adagrad>(
AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
expected_parameters::Adagrad_with_weight_decay_and_lr_decay);
}
-TEST_CASE("Optim/ProducesPyTorchValues/RMSprop") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/RMSprop") {
check_exact_values<RMSprop>(
RMSpropOptions(0.1), expected_parameters::RMSprop);
}
-TEST_CASE("Optim/ProducesPyTorchValues/RMSpropWithWeightDecay") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/RMSpropWithWeightDecay") {
check_exact_values<RMSprop>(
RMSpropOptions(0.1).weight_decay(1e-2),
expected_parameters::RMSprop_with_weight_decay);
}
-TEST_CASE("Optim/ProducesPyTorchValues/RMSpropWithWeightDecayAndCentered") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/RMSpropWithWeightDecayAndCentered") {
check_exact_values<RMSprop>(
RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
expected_parameters::RMSprop_with_weight_decay_and_centered);
}
-TEST_CASE(
+CATCH_TEST_CASE(
"Optim/ProducesPyTorchValues/RMSpropWithWeightDecayAndCenteredAndMomentum") {
check_exact_values<RMSprop>(
RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
expected_parameters::RMSprop_with_weight_decay_and_centered_and_momentum);
}
-TEST_CASE("Optim/ProducesPyTorchValues/SGD") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGD") {
check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD);
}
-TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecay") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecay") {
check_exact_values<SGD>(
SGDOptions(0.1).weight_decay(1e-2),
expected_parameters::SGD_with_weight_decay);
}
-TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecayAndMomentum") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecayAndMomentum") {
check_exact_values<SGD>(
SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
expected_parameters::SGD_with_weight_decay_and_momentum);
}
-TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecayAndNesterovMomentum") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecayAndNesterovMomentum") {
check_exact_values<SGD>(
SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
expected_parameters::SGD_with_weight_decay_and_nesterov_momentum);
}
-TEST_CASE("Optim/ZeroGrad") {
+CATCH_TEST_CASE("Optim/ZeroGrad") {
torch::manual_seed(0);
Linear model(2, 8);
SGD optimizer(model->parameters(), 0.1);
for (const auto& parameter : model->parameters()) {
- REQUIRE(!parameter->grad().defined());
+ CATCH_REQUIRE(!parameter->grad().defined());
}
auto output = model->forward(torch::ones({5, 2}));
@@ -284,19 +284,19 @@
loss.backward();
for (const auto& parameter : model->parameters()) {
- REQUIRE(parameter->grad().defined());
- REQUIRE(parameter->grad().sum().toCFloat() > 0);
+ CATCH_REQUIRE(parameter->grad().defined());
+ CATCH_REQUIRE(parameter->grad().sum().toCFloat() > 0);
}
optimizer.zero_grad();
for (const auto& parameter : model->parameters()) {
- REQUIRE(parameter->grad().defined());
- REQUIRE(parameter->grad().sum().toCFloat() == 0);
+ CATCH_REQUIRE(parameter->grad().defined());
+ CATCH_REQUIRE(parameter->grad().sum().toCFloat() == 0);
}
}
-TEST_CASE("Optim/ExternalVectorOfParameters") {
+CATCH_TEST_CASE("Optim/ExternalVectorOfParameters") {
torch::manual_seed(0);
std::vector<torch::Tensor> parameters = {
@@ -313,12 +313,12 @@
optimizer.step();
- REQUIRE(parameters[0].allclose(original_parameters[0] - 1.0));
- REQUIRE(parameters[1].allclose(original_parameters[1] - 1.0));
- REQUIRE(parameters[2].allclose(original_parameters[2] - 1.0));
+ CATCH_REQUIRE(parameters[0].allclose(original_parameters[0] - 1.0));
+ CATCH_REQUIRE(parameters[1].allclose(original_parameters[1] - 1.0));
+ CATCH_REQUIRE(parameters[2].allclose(original_parameters[2] - 1.0));
}
-TEST_CASE("Optim/AddParameter/LBFGS") {
+CATCH_TEST_CASE("Optim/AddParameter/LBFGS") {
torch::manual_seed(0);
std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
diff --git a/test/cpp/api/parallel.cpp b/test/cpp/api/parallel.cpp
index a151758..33e3a16 100644
--- a/test/cpp/api/parallel.cpp
+++ b/test/cpp/api/parallel.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/csrc/autograd/functions/comm.h>
#include <torch/nn/module.h>
@@ -19,92 +19,92 @@
#ifdef USE_CUDA
-TEST_CASE("Parallel/DifferentiableScatter", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/DifferentiableScatter", "[multi-cuda]") {
Scatter scatter(
{torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
auto input = torch::ones(10, torch::requires_grad(true));
auto output = scatter.apply({input});
- REQUIRE(output.size() == 2);
- REQUIRE(output[0].size(0) == 5);
- REQUIRE(output[1].size(0) == 5);
+ CATCH_REQUIRE(output.size() == 2);
+ CATCH_REQUIRE(output[0].size(0) == 5);
+ CATCH_REQUIRE(output[1].size(0) == 5);
- REQUIRE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)})
+ CATCH_REQUIRE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)})
.allclose(input));
auto sum = output[0].to({torch::kCUDA, 1}) + output[1];
sum.backward();
- REQUIRE(input.grad().defined());
- REQUIRE(input.grad().device().is_cpu());
- REQUIRE(input.grad().sum().toCInt() == 10);
+ CATCH_REQUIRE(input.grad().defined());
+ CATCH_REQUIRE(input.grad().device().is_cpu());
+ CATCH_REQUIRE(input.grad().sum().toCInt() == 10);
}
-TEST_CASE("Parallel/DifferentiableGather", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/DifferentiableGather", "[multi-cuda]") {
Gather gather(torch::Device(torch::kCUDA, 1));
auto a = torch::ones(5, torch::requires_grad(true).device({torch::kCUDA, 0}));
auto b = torch::ones(5, torch::requires_grad(true).device({torch::kCUDA, 1}));
auto outputs = gather.apply({a, b});
- REQUIRE(outputs.size() == 1);
+ CATCH_REQUIRE(outputs.size() == 1);
auto& output = outputs.front();
- REQUIRE(output.size(0) == 10);
- REQUIRE(output.device() == torch::Device(torch::kCUDA, 1));
+ CATCH_REQUIRE(output.size(0) == 10);
+ CATCH_REQUIRE(output.device() == torch::Device(torch::kCUDA, 1));
auto chunks = output.chunk(2);
- REQUIRE(chunks[0].to({torch::kCUDA, 0}).allclose(a));
- REQUIRE(chunks[1].allclose(b));
+ CATCH_REQUIRE(chunks[0].to({torch::kCUDA, 0}).allclose(a));
+ CATCH_REQUIRE(chunks[1].allclose(b));
output.backward();
- REQUIRE(a.grad().defined());
- REQUIRE(a.grad().device() == torch::Device(torch::kCUDA, 0));
- REQUIRE(a.grad().sum().toCInt() == 5);
+ CATCH_REQUIRE(a.grad().defined());
+ CATCH_REQUIRE(a.grad().device() == torch::Device(torch::kCUDA, 0));
+ CATCH_REQUIRE(a.grad().sum().toCInt() == 5);
- REQUIRE(b.grad().defined());
- REQUIRE(b.grad().device() == torch::Device(torch::kCUDA, 1));
- REQUIRE(b.grad().sum().toCInt() == 5);
+ CATCH_REQUIRE(b.grad().defined());
+ CATCH_REQUIRE(b.grad().device() == torch::Device(torch::kCUDA, 1));
+ CATCH_REQUIRE(b.grad().sum().toCInt() == 5);
}
-TEST_CASE("Parallel/Replicate", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/Replicate", "[multi-cuda]") {
Linear linear(3, 4);
auto replicas = parallel::replicate(
linear, {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
- REQUIRE(replicas.size() == 2);
+ CATCH_REQUIRE(replicas.size() == 2);
auto original_parameters = linear->parameters();
auto replica1_parameters = replicas[0]->parameters();
for (auto& parameter : replica1_parameters) {
- REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 0));
+ CATCH_REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 0));
}
replicas[0]->to(torch::kCPU);
- REQUIRE(replica1_parameters.size() == original_parameters.size());
+ CATCH_REQUIRE(replica1_parameters.size() == original_parameters.size());
for (size_t i = 0; i < original_parameters.size(); ++i) {
- REQUIRE(replica1_parameters[i]->allclose(*original_parameters[i]));
- REQUIRE(
+ CATCH_REQUIRE(replica1_parameters[i]->allclose(*original_parameters[i]));
+ CATCH_REQUIRE(
replica1_parameters[i].data<float>() !=
original_parameters[i].data<float>());
}
auto replica2_parameters = replicas[1]->parameters();
for (auto& parameter : replica2_parameters) {
- REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 1));
+ CATCH_REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 1));
}
replicas[1]->to(torch::kCPU);
- REQUIRE(replica2_parameters.size() == original_parameters.size());
+ CATCH_REQUIRE(replica2_parameters.size() == original_parameters.size());
for (size_t i = 0; i < original_parameters.size(); ++i) {
- REQUIRE(replica2_parameters[i]->allclose(*original_parameters[i]));
- REQUIRE(
+ CATCH_REQUIRE(replica2_parameters[i]->allclose(*original_parameters[i]));
+ CATCH_REQUIRE(
replica2_parameters[i].data<float>() !=
original_parameters[i].data<float>());
}
}
-TEST_CASE("Parallel/ParallelApply", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/ParallelApply", "[multi-cuda]") {
Linear a(3, 4);
Linear b(std::static_pointer_cast<LinearImpl>(a->clone()));
@@ -121,17 +121,17 @@
auto outputs = parallel::parallel_apply(modules, inputs);
- REQUIRE(outputs.size() == 3);
- REQUIRE(outputs[0].device().is_cpu());
+ CATCH_REQUIRE(outputs.size() == 3);
+ CATCH_REQUIRE(outputs[0].device().is_cpu());
- REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0));
- REQUIRE(outputs[1].to(torch::kCPU).allclose(outputs[0]));
+ CATCH_REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0));
+ CATCH_REQUIRE(outputs[1].to(torch::kCPU).allclose(outputs[0]));
- REQUIRE(outputs[2].device() == torch::Device(torch::kCUDA, 1));
- REQUIRE(outputs[2].to(torch::kCPU).allclose(outputs[0]));
+ CATCH_REQUIRE(outputs[2].device() == torch::Device(torch::kCUDA, 1));
+ CATCH_REQUIRE(outputs[2].to(torch::kCPU).allclose(outputs[0]));
}
-TEST_CASE("Parallel/ParallelApplyWithDifferentOutputDevice", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/ParallelApplyWithDifferentOutputDevice", "[multi-cuda]") {
struct M : torch::nn::Module {
torch::Tensor forward(torch::Tensor input) {
return torch::ones({5}, torch::dtype(torch::kInt32));
@@ -147,17 +147,17 @@
auto outputs = parallel::parallel_apply(modules, inputs, devices);
- REQUIRE(outputs.size() == 3);
- REQUIRE(outputs[0].device().is_cuda());
- REQUIRE(outputs[0].device() == torch::Device(torch::kCUDA, 1));
+ CATCH_REQUIRE(outputs.size() == 3);
+ CATCH_REQUIRE(outputs[0].device().is_cuda());
+ CATCH_REQUIRE(outputs[0].device() == torch::Device(torch::kCUDA, 1));
- REQUIRE(outputs[1].device().is_cuda());
- REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0));
+ CATCH_REQUIRE(outputs[1].device().is_cuda());
+ CATCH_REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0));
- REQUIRE(outputs[2].device().is_cpu());
+ CATCH_REQUIRE(outputs[2].device().is_cpu());
}
-TEST_CASE("Parallel/ParallelApplyRethrowsException", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/ParallelApplyRethrowsException", "[multi-cuda]") {
struct M : torch::nn::Cloneable<M> {
void reset() override {}
torch::Tensor forward(torch::Tensor input) {
@@ -167,11 +167,11 @@
auto m = std::make_shared<M>();
auto input = torch::ones({10, 3});
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
parallel::data_parallel(m, input), StartsWith("Badness!"));
}
-TEST_CASE(
+CATCH_TEST_CASE(
"Parallel/DataParallelPlacesTheOutputOnTheRequestedDevice",
"[multi-cuda]") {
struct M : torch::nn::Cloneable<M> {
@@ -192,9 +192,9 @@
input,
/*devices=*/at::nullopt,
/*output_device=*/torch::Device(torch::kCUDA, 1));
- REQUIRE(output.defined());
- REQUIRE(output.device().is_cuda());
- REQUIRE(output.device().index() == 1);
+ CATCH_REQUIRE(output.defined());
+ CATCH_REQUIRE(output.device().is_cuda());
+ CATCH_REQUIRE(output.device().index() == 1);
}
{
// Verify for the single-device case (where we don't scatter/gather).
@@ -203,16 +203,16 @@
input,
/*devices=*/std::vector<torch::Device>{torch::Device(torch::kCUDA, 0)},
/*output_device=*/torch::Device(torch::kCUDA, 1));
- REQUIRE(m->intermediate_tensor.defined());
- REQUIRE(m->intermediate_tensor.device().is_cuda());
- REQUIRE(m->intermediate_tensor.device().index() == 0);
- REQUIRE(output.defined());
- REQUIRE(output.device().is_cuda());
- REQUIRE(output.device().index() == 1);
+ CATCH_REQUIRE(m->intermediate_tensor.defined());
+ CATCH_REQUIRE(m->intermediate_tensor.device().is_cuda());
+ CATCH_REQUIRE(m->intermediate_tensor.device().index() == 0);
+ CATCH_REQUIRE(output.defined());
+ CATCH_REQUIRE(output.device().is_cuda());
+ CATCH_REQUIRE(output.device().index() == 1);
}
}
-TEST_CASE("Parallel/DataParallelUsesAllAvailableCUDADevices", "[cuda]") {
+CATCH_TEST_CASE("Parallel/DataParallelUsesAllAvailableCUDADevices", "[cuda]") {
struct M : torch::nn::Cloneable<M> {
void reset() override {}
torch::Tensor forward(torch::Tensor input) {
@@ -225,9 +225,9 @@
auto output = parallel::data_parallel(m, input);
const auto device_count = torch::cuda::device_count();
- REQUIRE(output.numel() == device_count);
+ CATCH_REQUIRE(output.numel() == device_count);
for (size_t i = 0; i < device_count; ++i) {
- REQUIRE(output[i].toCInt() == i);
+ CATCH_REQUIRE(output[i].toCInt() == i);
}
}
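Note on the "[cuda]" and "[multi-cuda]" markers used above: these are Catch test tags, so a test run can include or exclude the GPU tests from the command line with a tag expression (for example passing "[cuda]" to run only the CUDA-tagged tests, or "~[cuda]" to skip them); the renamed CATCH_TEST_CASE macro keeps this tagging behavior unchanged.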
diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp
index 9668572..a307851 100644
--- a/test/cpp/api/rnn.cpp
+++ b/test/cpp/api/rnn.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/rnn.h>
@@ -71,22 +71,22 @@
  // Expect the LSTM to have 64 outputs and 3 layers, with an input of
  // 10 time steps and a batch size of 16 (10 x 16 x n)
- REQUIRE(output.output.ndimension() == 3);
- REQUIRE(output.output.size(0) == 10);
- REQUIRE(output.output.size(1) == 16);
- REQUIRE(output.output.size(2) == 64);
+ CATCH_REQUIRE(output.output.ndimension() == 3);
+ CATCH_REQUIRE(output.output.size(0) == 10);
+ CATCH_REQUIRE(output.output.size(1) == 16);
+ CATCH_REQUIRE(output.output.size(2) == 64);
- REQUIRE(output.state.ndimension() == 4);
- REQUIRE(output.state.size(0) == 2); // (hx, cx)
- REQUIRE(output.state.size(1) == 3); // layers
- REQUIRE(output.state.size(2) == 16); // Batchsize
- REQUIRE(output.state.size(3) == 64); // 64 hidden dims
+ CATCH_REQUIRE(output.state.ndimension() == 4);
+ CATCH_REQUIRE(output.state.size(0) == 2); // (hx, cx)
+ CATCH_REQUIRE(output.state.size(1) == 3); // layers
+ CATCH_REQUIRE(output.state.size(2) == 16); // Batchsize
+ CATCH_REQUIRE(output.state.size(3) == 64); // 64 hidden dims
  // The hidden state should contain non-zero values
- REQUIRE(output.state.norm().toCFloat() > 0);
+ CATCH_REQUIRE(output.state.norm().toCFloat() > 0);
}
-TEST_CASE("RNN/CheckOutputSizes") {
+CATCH_TEST_CASE("RNN/CheckOutputSizes") {
torch::manual_seed(0);
LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
// Input size is: sequence length, batch size, input size
@@ -104,10 +104,10 @@
torch::Tensor diff = next.state - output.state;
  // The hidden state should have changed
- REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
+ CATCH_REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
}
-TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
+CATCH_TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
torch::manual_seed(0);
// Make sure the outputs match pytorch outputs
LSTM model(2, 2);
@@ -127,10 +127,10 @@
}
auto out = model->forward(x);
- REQUIRE(out.output.ndimension() == 3);
- REQUIRE(out.output.size(0) == 3);
- REQUIRE(out.output.size(1) == 4);
- REQUIRE(out.output.size(2) == 2);
+ CATCH_REQUIRE(out.output.ndimension() == 3);
+ CATCH_REQUIRE(out.output.size(0) == 3);
+ CATCH_REQUIRE(out.output.size(1) == 4);
+ CATCH_REQUIRE(out.output.size(2) == 2);
auto flat = out.output.view(3 * 4 * 2);
float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
@@ -138,14 +138,14 @@
0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
for (size_t i = 0; i < 3 * 4 * 2; i++) {
- REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
+ CATCH_REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
}
- REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
- REQUIRE(out.state.size(0) == 2);
- REQUIRE(out.state.size(1) == 1);
- REQUIRE(out.state.size(2) == 4);
- REQUIRE(out.state.size(3) == 2);
+ CATCH_REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
+ CATCH_REQUIRE(out.state.size(0) == 2);
+ CATCH_REQUIRE(out.state.size(1) == 1);
+ CATCH_REQUIRE(out.state.size(2) == 4);
+ CATCH_REQUIRE(out.state.size(3) == 2);
flat = out.state.view(16);
float h_out[] = {0.7889,
0.9003,
@@ -164,33 +164,33 @@
1.0931,
1.4911};
for (size_t i = 0; i < 16; i++) {
- REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
+ CATCH_REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
}
}
-TEST_CASE("RNN/integration/LSTM") {
- REQUIRE(test_RNN_xor<LSTM>(
+CATCH_TEST_CASE("RNN/integration/LSTM") {
+ CATCH_REQUIRE(test_RNN_xor<LSTM>(
[](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }));
}
-TEST_CASE("RNN/integration/GRU") {
- REQUIRE(
+CATCH_TEST_CASE("RNN/integration/GRU") {
+ CATCH_REQUIRE(
test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).layers(2)); }));
}
-TEST_CASE("RNN/integration/RNN") {
- SECTION("relu") {
- REQUIRE(test_RNN_xor<RNN>(
+CATCH_TEST_CASE("RNN/integration/RNN") {
+ CATCH_SECTION("relu") {
+ CATCH_REQUIRE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }));
}
- SECTION("tanh") {
- REQUIRE(test_RNN_xor<RNN>(
+ CATCH_SECTION("tanh") {
+ CATCH_REQUIRE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }));
}
}
-TEST_CASE("rnn_cuda", "[cuda]") {
- SECTION("sizes") {
+CATCH_TEST_CASE("rnn_cuda", "[cuda]") {
+ CATCH_SECTION("sizes") {
torch::manual_seed(0);
LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
model->to(torch::kCUDA);
@@ -209,26 +209,26 @@
torch::Tensor diff = next.state - output.state;
  // The hidden state should have changed
- REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
+ CATCH_REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
}
- SECTION("lstm") {
- REQUIRE(test_RNN_xor<LSTM>(
+ CATCH_SECTION("lstm") {
+ CATCH_REQUIRE(test_RNN_xor<LSTM>(
[](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }, true));
}
- SECTION("gru") {
- REQUIRE(test_RNN_xor<GRU>(
+ CATCH_SECTION("gru") {
+ CATCH_REQUIRE(test_RNN_xor<GRU>(
[](int s) { return GRU(GRUOptions(s, s).layers(2)); }, true));
}
- SECTION("rnn") {
- SECTION("relu") {
- REQUIRE(test_RNN_xor<RNN>(
+ CATCH_SECTION("rnn") {
+ CATCH_SECTION("relu") {
+ CATCH_REQUIRE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }, true));
}
- SECTION("tanh") {
- REQUIRE(test_RNN_xor<RNN>(
+ CATCH_SECTION("tanh") {
+ CATCH_REQUIRE(test_RNN_xor<RNN>(
[](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
}
}
diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp
index aef1332..777d6e2 100644
--- a/test/cpp/api/sequential.cpp
+++ b/test/cpp/api/sequential.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/nn/modules.h>
#include <torch/nn/modules/batchnorm.h>
@@ -21,7 +21,7 @@
using Catch::StartsWith;
-TEST_CASE("Sequential/ConstructsFromSharedPointer") {
+CATCH_TEST_CASE("Sequential/ConstructsFromSharedPointer") {
struct M : torch::nn::Module {
explicit M(int value_) : value(value_) {}
int value;
@@ -31,10 +31,10 @@
};
Sequential sequential(
std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
- REQUIRE(sequential->size() == 3);
+ CATCH_REQUIRE(sequential->size() == 3);
}
-TEST_CASE("Sequential/ConstructsFromConcreteType") {
+CATCH_TEST_CASE("Sequential/ConstructsFromConcreteType") {
struct M : torch::nn::Module {
explicit M(int value_) : value(value_) {}
int value;
@@ -44,9 +44,9 @@
};
Sequential sequential(M(1), M(2), M(3));
- REQUIRE(sequential->size() == 3);
+ CATCH_REQUIRE(sequential->size() == 3);
}
-TEST_CASE("Sequential/ConstructsFromModuleHolder") {
+CATCH_TEST_CASE("Sequential/ConstructsFromModuleHolder") {
struct MImpl : torch::nn::Module {
explicit MImpl(int value_) : value(value_) {}
int forward() {
@@ -61,10 +61,10 @@
};
Sequential sequential(M(1), M(2), M(3));
- REQUIRE(sequential->size() == 3);
+ CATCH_REQUIRE(sequential->size() == 3);
}
-TEST_CASE("Sequential/PushBackAddsAnElement") {
+CATCH_TEST_CASE("Sequential/PushBackAddsAnElement") {
struct M : torch::nn::Module {
explicit M(int value_) : value(value_) {}
int forward() {
@@ -73,17 +73,17 @@
int value;
};
Sequential sequential;
- REQUIRE(sequential->size() == 0);
- REQUIRE(sequential->is_empty());
+ CATCH_REQUIRE(sequential->size() == 0);
+ CATCH_REQUIRE(sequential->is_empty());
sequential->push_back(Linear(3, 4));
- REQUIRE(sequential->size() == 1);
+ CATCH_REQUIRE(sequential->size() == 1);
sequential->push_back(std::make_shared<M>(1));
- REQUIRE(sequential->size() == 2);
+ CATCH_REQUIRE(sequential->size() == 2);
sequential->push_back(M(2));
- REQUIRE(sequential->size() == 3);
+ CATCH_REQUIRE(sequential->size() == 3);
}
-TEST_CASE("Sequential/AccessWithAt") {
+CATCH_TEST_CASE("Sequential/AccessWithAt") {
struct M : torch::nn::Module {
explicit M(int value_) : value(value_) {}
int forward() {
@@ -98,22 +98,22 @@
for (auto& module : modules) {
sequential->push_back(module);
}
- REQUIRE(sequential->size() == 3);
+ CATCH_REQUIRE(sequential->size() == 3);
// returns the correct module for a given index
for (size_t i = 0; i < modules.size(); ++i) {
- REQUIRE(&sequential->at<M>(i) == modules[i].get());
+ CATCH_REQUIRE(&sequential->at<M>(i) == modules[i].get());
}
// throws for a bad index
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
sequential->at<M>(modules.size() + 1), StartsWith("Index out of range"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
sequential->at<M>(modules.size() + 1000000),
StartsWith("Index out of range"));
}
-TEST_CASE("Sequential/AccessWithPtr") {
+CATCH_TEST_CASE("Sequential/AccessWithPtr") {
struct M : torch::nn::Module {
explicit M(int value_) : value(value_) {}
int forward() {
@@ -128,46 +128,46 @@
for (auto& module : modules) {
sequential->push_back(module);
}
- REQUIRE(sequential->size() == 3);
+ CATCH_REQUIRE(sequential->size() == 3);
// returns the correct module for a given index
for (size_t i = 0; i < modules.size(); ++i) {
- REQUIRE(sequential->ptr(i).get() == modules[i].get());
- REQUIRE(sequential[i].get() == modules[i].get());
- REQUIRE(sequential->ptr<M>(i).get() == modules[i].get());
+ CATCH_REQUIRE(sequential->ptr(i).get() == modules[i].get());
+ CATCH_REQUIRE(sequential[i].get() == modules[i].get());
+ CATCH_REQUIRE(sequential->ptr<M>(i).get() == modules[i].get());
}
// throws for a bad index
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
sequential->ptr(modules.size() + 1), StartsWith("Index out of range"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
sequential->ptr(modules.size() + 1000000),
StartsWith("Index out of range"));
}
-TEST_CASE("Sequential/CallingForwardOnEmptySequentialIsDisallowed") {
+CATCH_TEST_CASE("Sequential/CallingForwardOnEmptySequentialIsDisallowed") {
Sequential empty;
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
empty->forward<int>(),
StartsWith("Cannot call forward() on an empty Sequential"));
}
-TEST_CASE("Sequential/CallingForwardChainsCorrectly") {
+CATCH_TEST_CASE("Sequential/CallingForwardChainsCorrectly") {
struct MockModule : torch::nn::Module {
explicit MockModule(int value) : expected(value) {}
int expected;
int forward(int value) {
- REQUIRE(value == expected);
+ CATCH_REQUIRE(value == expected);
return value + 1;
}
};
Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});
- REQUIRE(sequential->forward<int>(1) == 4);
+ CATCH_REQUIRE(sequential->forward<int>(1) == 4);
}
-TEST_CASE("Sequential/CallingForwardWithTheWrongReturnTypeThrows") {
+CATCH_TEST_CASE("Sequential/CallingForwardWithTheWrongReturnTypeThrows") {
struct M : public torch::nn::Module {
int forward() {
return 5;
@@ -175,14 +175,14 @@
};
Sequential sequential(M{});
- REQUIRE(sequential->forward<int>() == 5);
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE(sequential->forward<int>() == 5);
+ CATCH_REQUIRE_THROWS_WITH(
sequential->forward<float>(),
StartsWith("The type of the return value "
"is int, but you asked for type float"));
}
-TEST_CASE("Sequential/TheReturnTypeOfForwardDefaultsToTensor") {
+CATCH_TEST_CASE("Sequential/TheReturnTypeOfForwardDefaultsToTensor") {
struct M : public torch::nn::Module {
torch::Tensor forward(torch::Tensor v) {
return v;
@@ -191,21 +191,21 @@
Sequential sequential(M{});
auto variable = torch::ones({3, 3}, torch::requires_grad());
- REQUIRE(sequential->forward(variable).equal(variable));
+ CATCH_REQUIRE(sequential->forward(variable).equal(variable));
}
-TEST_CASE("Sequential/ForwardReturnsTheLastValue") {
+CATCH_TEST_CASE("Sequential/ForwardReturnsTheLastValue") {
torch::manual_seed(0);
Sequential sequential(Linear(10, 3), Linear(3, 5), Linear(5, 100));
auto x = torch::randn({1000, 10}, torch::requires_grad());
auto y = sequential->forward(x);
- REQUIRE(y.ndimension() == 2);
- REQUIRE(y.size(0) == 1000);
- REQUIRE(y.size(1) == 100);
+ CATCH_REQUIRE(y.ndimension() == 2);
+ CATCH_REQUIRE(y.size(0) == 1000);
+ CATCH_REQUIRE(y.size(1) == 100);
}
-TEST_CASE("Sequential/SanityCheckForHoldingStandardModules") {
+CATCH_TEST_CASE("Sequential/SanityCheckForHoldingStandardModules") {
Sequential sequential(
Linear(10, 3),
Conv2d(1, 2, 3),
@@ -215,7 +215,7 @@
LSTM(4, 5));
}
-TEST_CASE("Sequential/ExtendPushesModulesFromOtherSequential") {
+CATCH_TEST_CASE("Sequential/ExtendPushesModulesFromOtherSequential") {
struct A : torch::nn::Module {
int forward(int x) {
return x;
@@ -240,34 +240,34 @@
Sequential b(C{}, D{});
a->extend(*b);
- REQUIRE(a->size() == 4);
- REQUIRE(a[0]->as<A>());
- REQUIRE(a[1]->as<B>());
- REQUIRE(a[2]->as<C>());
- REQUIRE(a[3]->as<D>());
+ CATCH_REQUIRE(a->size() == 4);
+ CATCH_REQUIRE(a[0]->as<A>());
+ CATCH_REQUIRE(a[1]->as<B>());
+ CATCH_REQUIRE(a[2]->as<C>());
+ CATCH_REQUIRE(a[3]->as<D>());
- REQUIRE(b->size() == 2);
- REQUIRE(b[0]->as<C>());
- REQUIRE(b[1]->as<D>());
+ CATCH_REQUIRE(b->size() == 2);
+ CATCH_REQUIRE(b[0]->as<C>());
+ CATCH_REQUIRE(b[1]->as<D>());
std::vector<std::shared_ptr<A>> c = {std::make_shared<A>(),
std::make_shared<A>()};
b->extend(c);
- REQUIRE(b->size() == 4);
- REQUIRE(b[0]->as<C>());
- REQUIRE(b[1]->as<D>());
- REQUIRE(b[2]->as<A>());
- REQUIRE(b[3]->as<A>());
+ CATCH_REQUIRE(b->size() == 4);
+ CATCH_REQUIRE(b[0]->as<C>());
+ CATCH_REQUIRE(b[1]->as<D>());
+ CATCH_REQUIRE(b[2]->as<A>());
+ CATCH_REQUIRE(b[3]->as<A>());
}
-TEST_CASE("Sequential/HasReferenceSemantics") {
+CATCH_TEST_CASE("Sequential/HasReferenceSemantics") {
Sequential first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
Sequential second(first);
- REQUIRE(first.get() == second.get());
- REQUIRE(first->size() == second->size());
- REQUIRE(std::equal(
+ CATCH_REQUIRE(first.get() == second.get());
+ CATCH_REQUIRE(first->size() == second->size());
+ CATCH_REQUIRE(std::equal(
first->begin(),
first->end(),
second->begin(),
@@ -276,17 +276,17 @@
}));
}
-TEST_CASE("Sequential/IsCloneable") {
+CATCH_TEST_CASE("Sequential/IsCloneable") {
Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
Sequential clone =
std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
- REQUIRE(sequential->size() == clone->size());
+ CATCH_REQUIRE(sequential->size() == clone->size());
for (size_t i = 0; i < sequential->size(); ++i) {
// The modules should be the same kind (type).
- REQUIRE(sequential[i]->name() == clone[i]->name());
+ CATCH_REQUIRE(sequential[i]->name() == clone[i]->name());
// But not pointer-equal (distinct objects).
- REQUIRE(sequential[i] != clone[i]);
+ CATCH_REQUIRE(sequential[i] != clone[i]);
}
// Verify that the clone is deep, i.e. parameters of modules are cloned too.
@@ -295,38 +295,38 @@
auto params1 = sequential->parameters();
auto params2 = clone->parameters();
- REQUIRE(params1.size() == params2.size());
+ CATCH_REQUIRE(params1.size() == params2.size());
for (auto& param : params1) {
- REQUIRE(!pointer_equal(param.value, params2[param.key]));
- REQUIRE(param->device() == params2[param.key].device());
- REQUIRE(param->allclose(params2[param.key]));
+ CATCH_REQUIRE(!pointer_equal(param.value, params2[param.key]));
+ CATCH_REQUIRE(param->device() == params2[param.key].device());
+ CATCH_REQUIRE(param->allclose(params2[param.key]));
param->add_(2);
}
for (auto& param : params1) {
- REQUIRE(!param->allclose(params2[param.key]));
+ CATCH_REQUIRE(!param->allclose(params2[param.key]));
}
}
-TEST_CASE("Sequential/RegistersElementsAsSubmodules") {
+CATCH_TEST_CASE("Sequential/RegistersElementsAsSubmodules") {
Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), FeatureDropout(0.5));
auto modules = sequential->modules();
- REQUIRE(modules.size() == sequential->children().size());
+ CATCH_REQUIRE(modules.size() == sequential->children().size());
- REQUIRE(modules[0]->as<Linear>());
- REQUIRE(modules[1]->as<Conv2d>());
- REQUIRE(modules[2]->as<FeatureDropout>());
+ CATCH_REQUIRE(modules[0]->as<Linear>());
+ CATCH_REQUIRE(modules[1]->as<Conv2d>());
+ CATCH_REQUIRE(modules[2]->as<FeatureDropout>());
}
-TEST_CASE("Sequential/CloneToDevice", "[cuda]") {
+CATCH_TEST_CASE("Sequential/CloneToDevice", "[cuda]") {
Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
torch::Device device(torch::kCUDA, 0);
Sequential clone =
std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(device));
for (const auto& p : clone->parameters()) {
- REQUIRE(p->device() == device);
+ CATCH_REQUIRE(p->device() == device);
}
for (const auto& b : clone->buffers()) {
- REQUIRE(b->device() == device);
+ CATCH_REQUIRE(b->device() == device);
}
}
diff --git a/test/cpp/api/serialization.cpp b/test/cpp/api/serialization.cpp
index 3541089..fda133b 100644
--- a/test/cpp/api/serialization.cpp
+++ b/test/cpp/api/serialization.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
@@ -30,12 +30,12 @@
}
} // namespace
-TEST_CASE("serialization") {
+CATCH_TEST_CASE("serialization") {
torch::manual_seed(0);
- SECTION("undefined") {
+ CATCH_SECTION("undefined") {
auto x = torch::Tensor();
- REQUIRE(!x.defined());
+ CATCH_REQUIRE(!x.defined());
auto y = torch::randn({5});
@@ -43,10 +43,10 @@
torch::save(ss, &x);
torch::load(ss, &y);
- REQUIRE(!y.defined());
+ CATCH_REQUIRE(!y.defined());
}
- SECTION("cputypes") {
+ CATCH_SECTION("cputypes") {
for (int i = 0; i < static_cast<int>(torch::Dtype::NumOptions); i++) {
if (i == static_cast<int>(torch::Dtype::Half)) {
// XXX can't serialize half tensors at the moment since contiguous() is
@@ -69,17 +69,17 @@
torch::save(ss, &x);
torch::load(ss, &y);
- REQUIRE(y.defined());
- REQUIRE(x.sizes().vec() == y.sizes().vec());
+ CATCH_REQUIRE(y.defined());
+ CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
if (torch::isIntegralType(static_cast<torch::Dtype>(i))) {
- REQUIRE(x.equal(y));
+ CATCH_REQUIRE(x.equal(y));
} else {
- REQUIRE(x.allclose(y));
+ CATCH_REQUIRE(x.allclose(y));
}
}
}
- SECTION("binary") {
+ CATCH_SECTION("binary") {
auto x = torch::randn({5, 5});
auto y = torch::Tensor();
@@ -93,11 +93,11 @@
archive(y);
}
- REQUIRE(y.defined());
- REQUIRE(x.sizes().vec() == y.sizes().vec());
- REQUIRE(x.allclose(y));
+ CATCH_REQUIRE(y.defined());
+ CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
+ CATCH_REQUIRE(x.allclose(y));
}
- SECTION("portable_binary") {
+ CATCH_SECTION("portable_binary") {
auto x = torch::randn({5, 5});
auto y = torch::Tensor();
@@ -111,12 +111,12 @@
archive(y);
}
- REQUIRE(y.defined());
- REQUIRE(x.sizes().vec() == y.sizes().vec());
- REQUIRE(x.allclose(y));
+ CATCH_REQUIRE(y.defined());
+ CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
+ CATCH_REQUIRE(x.allclose(y));
}
- SECTION("resized") {
+ CATCH_SECTION("resized") {
auto x = torch::randn({11, 5});
x.resize_({5, 5});
auto y = torch::Tensor();
@@ -131,11 +131,11 @@
archive(y);
}
- REQUIRE(y.defined());
- REQUIRE(x.sizes().vec() == y.sizes().vec());
- REQUIRE(x.allclose(y));
+ CATCH_REQUIRE(y.defined());
+ CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
+ CATCH_REQUIRE(x.allclose(y));
}
- SECTION("sliced") {
+ CATCH_SECTION("sliced") {
auto x = torch::randn({11, 5});
x = x.slice(0, 1, 3);
auto y = torch::Tensor();
@@ -150,12 +150,12 @@
archive(y);
}
- REQUIRE(y.defined());
- REQUIRE(x.sizes().vec() == y.sizes().vec());
- REQUIRE(x.allclose(y));
+ CATCH_REQUIRE(y.defined());
+ CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
+ CATCH_REQUIRE(x.allclose(y));
}
- SECTION("noncontig") {
+ CATCH_SECTION("noncontig") {
auto x = torch::randn({11, 5});
x = x.slice(1, 1, 4);
auto y = torch::Tensor();
@@ -170,12 +170,12 @@
archive(y);
}
- REQUIRE(y.defined());
- REQUIRE(x.sizes().vec() == y.sizes().vec());
- REQUIRE(x.allclose(y));
+ CATCH_REQUIRE(y.defined());
+ CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
+ CATCH_REQUIRE(x.allclose(y));
}
- SECTION("xor") {
+ CATCH_SECTION("xor") {
// We better be able to save and load a XOR model!
auto getLoss = [](Sequential model, uint32_t batch_size) {
auto inputs = torch::empty({batch_size, 2});
@@ -207,7 +207,7 @@
optimizer.step();
running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
- REQUIRE(epoch < 3000);
+ CATCH_REQUIRE(epoch < 3000);
epoch++;
}
@@ -216,10 +216,10 @@
torch::load(ss, model2);
auto loss = getLoss(model2, 100);
- REQUIRE(loss.toCFloat() < 0.1);
+ CATCH_REQUIRE(loss.toCFloat() < 0.1);
}
- SECTION("optim") {
+ CATCH_SECTION("optim") {
auto model1 = Linear(5, 2);
auto model2 = Linear(5, 2);
auto model3 = Linear(5, 2);
@@ -235,8 +235,8 @@
auto param2 = model2->parameters();
auto param3 = model3->parameters();
for (const auto& p : param1) {
- REQUIRE(param1[p.key].allclose(param2[p.key]));
- REQUIRE(param2[p.key].allclose(param3[p.key]));
+ CATCH_REQUIRE(param1[p.key].allclose(param2[p.key]));
+ CATCH_REQUIRE(param2[p.key].allclose(param3[p.key]));
}
// Make some optimizers with momentum (and thus state)
@@ -281,13 +281,13 @@
for (const auto& p : param1) {
const auto& name = p.key;
// Model 1 and 3 should be the same
- REQUIRE(param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
- REQUIRE(param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
+ CATCH_REQUIRE(param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
+ CATCH_REQUIRE(param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
}
}
}
-TEST_CASE("serialization_cuda", "[cuda]") {
+CATCH_TEST_CASE("serialization_cuda", "[cuda]") {
torch::manual_seed(0);
// We better be able to save and load a XOR model!
auto getLoss = [](Sequential model, uint32_t batch_size) {
@@ -318,7 +318,7 @@
optimizer.step();
running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
- REQUIRE(epoch < 3000);
+ CATCH_REQUIRE(epoch < 3000);
epoch++;
}
@@ -327,7 +327,7 @@
torch::load(ss, model2);
auto loss = getLoss(model2, 100);
- REQUIRE(loss.toCFloat() < 0.1);
+ CATCH_REQUIRE(loss.toCFloat() < 0.1);
model2->to(torch::kCUDA);
ss.clear();
@@ -335,5 +335,5 @@
torch::load(ss, model3);
loss = getLoss(model3, 100);
- REQUIRE(loss.toCFloat() < 0.1);
+ CATCH_REQUIRE(loss.toCFloat() < 0.1);
}
diff --git a/test/cpp/api/tensor.cpp b/test/cpp/api/tensor.cpp
index f08a30d..5760556 100644
--- a/test/cpp/api/tensor.cpp
+++ b/test/cpp/api/tensor.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <torch/tensor.h>
@@ -19,12 +19,12 @@
}
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
- REQUIRE(tensor.device().type() == at::Device((device_), (index_)).type()); \
- REQUIRE(tensor.device().index() == at::Device((device_), (index_)).index()); \
- REQUIRE(tensor.dtype() == (type_)); \
- REQUIRE(tensor.layout() == (layout_))
+ CATCH_REQUIRE(tensor.device().type() == at::Device((device_), (index_)).type()); \
+ CATCH_REQUIRE(tensor.device().index() == at::Device((device_), (index_)).index()); \
+ CATCH_REQUIRE(tensor.dtype() == (type_)); \
+ CATCH_REQUIRE(tensor.layout() == (layout_))
-TEST_CASE("Tensor/ToDtype") {
+CATCH_TEST_CASE("Tensor/ToDtype") {
auto tensor = at::empty({3, 4});
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
@@ -39,7 +39,7 @@
}
// Not currently supported.
-// TEST_CASE("Tensor/ToLayout") {
+// CATCH_TEST_CASE("Tensor/ToLayout") {
// auto tensor = at::empty({3, 4});
// REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
//
@@ -50,7 +50,7 @@
// REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
// }
-TEST_CASE("Tensor/ToDevice", "[cuda]") {
+CATCH_TEST_CASE("Tensor/ToDevice", "[cuda]") {
auto tensor = at::empty({3, 4});
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
@@ -67,7 +67,7 @@
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
}
-TEST_CASE("Tensor/ToDeviceAndDtype", "[cuda]") {
+CATCH_TEST_CASE("Tensor/ToDeviceAndDtype", "[cuda]") {
auto tensor = at::empty({3, 4});
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
@@ -75,119 +75,119 @@
REQUIRE_TENSOR_OPTIONS(at::kCUDA, 1, at::kInt, at::kStrided);
}
-TEST_CASE("Tensor/ToOptionsRespectsRequiresGrad") {
+CATCH_TEST_CASE("Tensor/ToOptionsRespectsRequiresGrad") {
{
auto tensor = torch::empty({3, 4}, at::requires_grad());
- REQUIRE(tensor.requires_grad());
+ CATCH_REQUIRE(tensor.requires_grad());
tensor = tensor.to(at::kDouble);
- REQUIRE(tensor.requires_grad());
+ CATCH_REQUIRE(tensor.requires_grad());
}
{
auto tensor = torch::empty({3, 4});
- REQUIRE(!tensor.requires_grad());
+ CATCH_REQUIRE(!tensor.requires_grad());
tensor = tensor.to(at::kDouble);
- REQUIRE(!tensor.requires_grad());
+ CATCH_REQUIRE(!tensor.requires_grad());
}
}
-TEST_CASE("Tensor/ToDoesNotCopyWhenOptionsAreAllTheSame") {
+CATCH_TEST_CASE("Tensor/ToDoesNotCopyWhenOptionsAreAllTheSame") {
auto tensor = at::empty({3, 4}, at::kFloat);
auto hopefully_not_copy = tensor.to(at::kFloat);
- REQUIRE(hopefully_not_copy.data<float>() == tensor.data<float>());
+ CATCH_REQUIRE(hopefully_not_copy.data<float>() == tensor.data<float>());
}
-TEST_CASE("Tensor/ContainsCorrectValueForSingleValue") {
+CATCH_TEST_CASE("Tensor/ContainsCorrectValueForSingleValue") {
auto tensor = at::tensor(123);
- REQUIRE(tensor.numel() == 1);
- REQUIRE(tensor.dtype() == at::kInt);
- REQUIRE(tensor[0].toCInt() == 123);
+ CATCH_REQUIRE(tensor.numel() == 1);
+ CATCH_REQUIRE(tensor.dtype() == at::kInt);
+ CATCH_REQUIRE(tensor[0].toCInt() == 123);
tensor = at::tensor(123.456f);
- REQUIRE(tensor.numel() == 1);
- REQUIRE(tensor.dtype() == at::kFloat);
- REQUIRE(almost_equal(tensor[0], 123.456f));
+ CATCH_REQUIRE(tensor.numel() == 1);
+ CATCH_REQUIRE(tensor.dtype() == at::kFloat);
+ CATCH_REQUIRE(almost_equal(tensor[0], 123.456f));
tensor = at::tensor(123.456);
- REQUIRE(tensor.numel() == 1);
- REQUIRE(tensor.dtype() == at::kDouble);
- REQUIRE(almost_equal(tensor[0], 123.456));
+ CATCH_REQUIRE(tensor.numel() == 1);
+ CATCH_REQUIRE(tensor.dtype() == at::kDouble);
+ CATCH_REQUIRE(almost_equal(tensor[0], 123.456));
}
-TEST_CASE("Tensor/ContainsCorrectValuesForManyValues") {
+CATCH_TEST_CASE("Tensor/ContainsCorrectValuesForManyValues") {
auto tensor = at::tensor({1, 2, 3});
- REQUIRE(tensor.numel() == 3);
- REQUIRE(tensor.dtype() == at::kInt);
- REQUIRE(exactly_equal(tensor[0], 1));
- REQUIRE(exactly_equal(tensor[1], 2));
- REQUIRE(exactly_equal(tensor[2], 3));
+ CATCH_REQUIRE(tensor.numel() == 3);
+ CATCH_REQUIRE(tensor.dtype() == at::kInt);
+ CATCH_REQUIRE(exactly_equal(tensor[0], 1));
+ CATCH_REQUIRE(exactly_equal(tensor[1], 2));
+ CATCH_REQUIRE(exactly_equal(tensor[2], 3));
tensor = at::tensor({1.5, 2.25, 3.125});
- REQUIRE(tensor.numel() == 3);
- REQUIRE(tensor.dtype() == at::kDouble);
- REQUIRE(almost_equal(tensor[0], 1.5));
- REQUIRE(almost_equal(tensor[1], 2.25));
- REQUIRE(almost_equal(tensor[2], 3.125));
+ CATCH_REQUIRE(tensor.numel() == 3);
+ CATCH_REQUIRE(tensor.dtype() == at::kDouble);
+ CATCH_REQUIRE(almost_equal(tensor[0], 1.5));
+ CATCH_REQUIRE(almost_equal(tensor[1], 2.25));
+ CATCH_REQUIRE(almost_equal(tensor[2], 3.125));
}
-TEST_CASE("Tensor/ContainsCorrectValuesForManyValuesVariable") {
+CATCH_TEST_CASE("Tensor/ContainsCorrectValuesForManyValuesVariable") {
auto tensor = torch::tensor({1, 2, 3});
- REQUIRE(tensor.is_variable());
- REQUIRE(tensor.numel() == 3);
- REQUIRE(tensor.dtype() == at::kInt);
- REQUIRE(exactly_equal(tensor[0], 1));
- REQUIRE(exactly_equal(tensor[1], 2));
- REQUIRE(exactly_equal(tensor[2], 3));
+ CATCH_REQUIRE(tensor.is_variable());
+ CATCH_REQUIRE(tensor.numel() == 3);
+ CATCH_REQUIRE(tensor.dtype() == at::kInt);
+ CATCH_REQUIRE(exactly_equal(tensor[0], 1));
+ CATCH_REQUIRE(exactly_equal(tensor[1], 2));
+ CATCH_REQUIRE(exactly_equal(tensor[2], 3));
tensor = torch::tensor({1.5, 2.25, 3.125});
- REQUIRE(tensor.is_variable());
- REQUIRE(tensor.numel() == 3);
- REQUIRE(tensor.dtype() == at::kDouble);
- REQUIRE(almost_equal(tensor[0], 1.5));
- REQUIRE(almost_equal(tensor[1], 2.25));
- REQUIRE(almost_equal(tensor[2], 3.125));
+ CATCH_REQUIRE(tensor.is_variable());
+ CATCH_REQUIRE(tensor.numel() == 3);
+ CATCH_REQUIRE(tensor.dtype() == at::kDouble);
+ CATCH_REQUIRE(almost_equal(tensor[0], 1.5));
+ CATCH_REQUIRE(almost_equal(tensor[1], 2.25));
+ CATCH_REQUIRE(almost_equal(tensor[2], 3.125));
}
-TEST_CASE("Tensor/ContainsCorrectValuesWhenConstructedFromVector") {
+CATCH_TEST_CASE("Tensor/ContainsCorrectValuesWhenConstructedFromVector") {
std::vector<int> v = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
auto tensor = at::tensor(v);
- REQUIRE(tensor.numel() == v.size());
- REQUIRE(tensor.dtype() == at::kInt);
+ CATCH_REQUIRE(tensor.numel() == v.size());
+ CATCH_REQUIRE(tensor.dtype() == at::kInt);
for (size_t i = 0; i < v.size(); ++i) {
- REQUIRE(exactly_equal(tensor[i], v.at(i)));
+ CATCH_REQUIRE(exactly_equal(tensor[i], v.at(i)));
}
std::vector<float> w = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0};
tensor = at::tensor(w);
- REQUIRE(tensor.numel() == w.size());
- REQUIRE(tensor.dtype() == at::kFloat);
+ CATCH_REQUIRE(tensor.numel() == w.size());
+ CATCH_REQUIRE(tensor.dtype() == at::kFloat);
for (size_t i = 0; i < w.size(); ++i) {
- REQUIRE(almost_equal(tensor[i], w.at(i)));
+ CATCH_REQUIRE(almost_equal(tensor[i], w.at(i)));
}
}
-TEST_CASE("Tensor/UsesOptionsThatAreSupplied") {
+CATCH_TEST_CASE("Tensor/UsesOptionsThatAreSupplied") {
auto tensor = at::tensor(123, dtype(at::kFloat)) + 0.5;
- REQUIRE(tensor.numel() == 1);
- REQUIRE(tensor.dtype() == at::kFloat);
- REQUIRE(almost_equal(tensor[0], 123.5));
+ CATCH_REQUIRE(tensor.numel() == 1);
+ CATCH_REQUIRE(tensor.dtype() == at::kFloat);
+ CATCH_REQUIRE(almost_equal(tensor[0], 123.5));
tensor = at::tensor({1.1, 2.2, 3.3}, dtype(at::kInt));
- REQUIRE(tensor.numel() == 3);
- REQUIRE(tensor.dtype() == at::kInt);
- REQUIRE(tensor.layout() == at::kStrided);
- REQUIRE(exactly_equal(tensor[0], 1));
- REQUIRE(exactly_equal(tensor[1], 2));
- REQUIRE(exactly_equal(tensor[2], 3));
+ CATCH_REQUIRE(tensor.numel() == 3);
+ CATCH_REQUIRE(tensor.dtype() == at::kInt);
+ CATCH_REQUIRE(tensor.layout() == at::kStrided);
+ CATCH_REQUIRE(exactly_equal(tensor[0], 1));
+ CATCH_REQUIRE(exactly_equal(tensor[1], 2));
+ CATCH_REQUIRE(exactly_equal(tensor[2], 3));
}
-TEST_CASE("FromBlob") {
+CATCH_TEST_CASE("FromBlob") {
std::vector<int32_t> v = {1, 2, 3};
auto tensor = torch::from_blob(v.data(), v.size(), torch::kInt32);
- REQUIRE(tensor.is_variable());
- REQUIRE(tensor.numel() == 3);
- REQUIRE(tensor[0].toCInt() == 1);
- REQUIRE(tensor[1].toCInt() == 2);
- REQUIRE(tensor[2].toCInt() == 3);
+ CATCH_REQUIRE(tensor.is_variable());
+ CATCH_REQUIRE(tensor.numel() == 3);
+ CATCH_REQUIRE(tensor[0].toCInt() == 1);
+ CATCH_REQUIRE(tensor[1].toCInt() == 2);
+ CATCH_REQUIRE(tensor[2].toCInt() == 3);
}
diff --git a/test/cpp/api/tensor_cuda.cpp b/test/cpp/api/tensor_cuda.cpp
index 82d874e..8f85014 100644
--- a/test/cpp/api/tensor_cuda.cpp
+++ b/test/cpp/api/tensor_cuda.cpp
@@ -1,11 +1,11 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
#include <ATen/ATen.h>
#include <cmath>
-TEST_CASE("Tensor/AllocatesTensorOnTheCorrectDevice", "[multi-cuda]") {
+CATCH_TEST_CASE("Tensor/AllocatesTensorOnTheCorrectDevice", "[multi-cuda]") {
auto tensor = at::tensor({1, 2, 3}, at::device({at::kCUDA, 1}));
- REQUIRE(tensor.device().type() == at::Device::Type::CUDA);
- REQUIRE(tensor.device().index() == 1);
+ CATCH_REQUIRE(tensor.device().type() == at::Device::Type::CUDA);
+ CATCH_REQUIRE(tensor.device().index() == 1);
}
diff --git a/test/cpp/api/tensor_options.cpp b/test/cpp/api/tensor_options.cpp
index ab80c5f..7118a35 100644
--- a/test/cpp/api/tensor_options.cpp
+++ b/test/cpp/api/tensor_options.cpp
@@ -1,4 +1,4 @@
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include <torch/tensor.h>
@@ -14,28 +14,28 @@
// A macro so we don't lose location information when an assertion fails.
#define REQUIRE_OPTIONS(device_, index_, type_, layout_) \
- REQUIRE(options.device().type() == Device((device_), (index_)).type()); \
- REQUIRE(options.device().index() == Device((device_), (index_)).index()); \
- REQUIRE(options.dtype() == (type_)); \
- REQUIRE(options.layout() == (layout_))
+ CATCH_REQUIRE(options.device().type() == Device((device_), (index_)).type()); \
+ CATCH_REQUIRE(options.device().index() == Device((device_), (index_)).index()); \
+ CATCH_REQUIRE(options.dtype() == (type_)); \
+ CATCH_REQUIRE(options.layout() == (layout_))
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
- REQUIRE(tensor.device().type() == Device((device_), (index_)).type()); \
- REQUIRE(tensor.device().index() == Device((device_), (index_)).index()); \
- REQUIRE(tensor.type().scalarType() == (type_)); \
- REQUIRE(tensor.type().layout() == (layout_))
+ CATCH_REQUIRE(tensor.device().type() == Device((device_), (index_)).type()); \
+ CATCH_REQUIRE(tensor.device().index() == Device((device_), (index_)).index()); \
+ CATCH_REQUIRE(tensor.type().scalarType() == (type_)); \
+ CATCH_REQUIRE(tensor.type().layout() == (layout_))
-TEST_CASE("TensorOptions/DefaultsToTheRightValues") {
+CATCH_TEST_CASE("TensorOptions/DefaultsToTheRightValues") {
TensorOptions options;
REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
}
-TEST_CASE("TensorOptions/ReturnsTheCorrectType") {
+CATCH_TEST_CASE("TensorOptions/ReturnsTheCorrectType") {
auto options = TensorOptions().device(kCPU).dtype(kInt).layout(kSparse);
- REQUIRE(at::getType(options) == getNonVariableType(Backend::SparseCPU, kInt));
+ CATCH_REQUIRE(at::getType(options) == getNonVariableType(Backend::SparseCPU, kInt));
}
-TEST_CASE("TensorOptions/UtilityFunctionsReturnTheRightTensorOptions") {
+CATCH_TEST_CASE("TensorOptions/UtilityFunctionsReturnTheRightTensorOptions") {
auto options = dtype(kInt);
REQUIRE_OPTIONS(kCPU, -1, kInt, kStrided);
@@ -52,7 +52,7 @@
REQUIRE_OPTIONS(kCUDA, 3, kByte, kSparse);
}
-TEST_CASE("TensorOptions/ConstructsWellFromCPUTypes") {
+CATCH_TEST_CASE("TensorOptions/ConstructsWellFromCPUTypes") {
TensorOptions options;
REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
@@ -69,7 +69,7 @@
REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
}
-TEST_CASE("TensorOptions/ConstructsWellFromCPUTensors") {
+CATCH_TEST_CASE("TensorOptions/ConstructsWellFromCPUTensors") {
auto options = empty(5, kDouble).options();
REQUIRE_OPTIONS(kCPU, -1, kDouble, kStrided);
@@ -77,37 +77,37 @@
REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
}
-TEST_CASE("TensorOptions/ConstructsWellFromVariables") {
+CATCH_TEST_CASE("TensorOptions/ConstructsWellFromVariables") {
auto options = torch::empty(5).options();
REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
- REQUIRE(!options.requires_grad());
+ CATCH_REQUIRE(!options.requires_grad());
options = torch::empty(5, at::requires_grad()).options();
REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
- REQUIRE(!options.requires_grad());
+ CATCH_REQUIRE(!options.requires_grad());
}
-TEST_CASE("Device/ParsesCorrectlyFromString") {
+CATCH_TEST_CASE("Device/ParsesCorrectlyFromString") {
Device device("cpu:0");
- REQUIRE(device == Device(kCPU, 0));
+ CATCH_REQUIRE(device == Device(kCPU, 0));
device = Device("cpu");
- REQUIRE(device == Device(kCPU));
+ CATCH_REQUIRE(device == Device(kCPU));
device = Device("cuda:123");
- REQUIRE(device == Device(kCUDA, 123));
+ CATCH_REQUIRE(device == Device(kCUDA, 123));
device = Device("cuda");
- REQUIRE(device == Device(kCUDA));
+ CATCH_REQUIRE(device == Device(kCUDA));
std::vector<std::string> badnesses = {
"", "cud:1", "cuda:", "cpu::1", ":1", "3", "tpu:4", "??"};
for (const auto& badness : badnesses) {
- REQUIRE_THROWS(Device(badness));
+ _CATCH_REQUIRE_THROWS(Device(badness));
}
}
-TEST_CASE("OptionsGuard") {
+CATCH_TEST_CASE("OptionsGuard") {
Tensor tensor;
{
OptionsGuard guard(TensorOptions{});
@@ -132,5 +132,5 @@
tensor = torch::empty({10});
}
REQUIRE_TENSOR_OPTIONS(kCPU, -1, kFloat, kStrided);
- REQUIRE(tensor.requires_grad());
+ CATCH_REQUIRE(tensor.requires_grad());
}
diff --git a/test/cpp/api/tensor_options_cuda.cpp b/test/cpp/api/tensor_options_cuda.cpp
index ea33321..edeede8 100644
--- a/test/cpp/api/tensor_options_cuda.cpp
+++ b/test/cpp/api/tensor_options_cuda.cpp
@@ -1,4 +1,4 @@
-#include "catch.hpp"
+#include "catch_utils.hpp"
#include <ATen/Context.h>
#include <ATen/DeviceGuard.h>
@@ -10,18 +10,18 @@
// A macro so we don't lose location information when an assertion fails.
#define REQUIRE_OPTIONS(device_, index_, type_, layout_) \
- REQUIRE(options.device().type() == Device((device_), (index_)).type()); \
- REQUIRE(options.device().index() == Device((device_), (index_)).index()); \
- REQUIRE(options.dtype() == (type_)); \
- REQUIRE(options.layout() == (layout_))
+ CATCH_REQUIRE(options.device().type() == Device((device_), (index_)).type()); \
+ CATCH_REQUIRE(options.device().index() == Device((device_), (index_)).index()); \
+ CATCH_REQUIRE(options.dtype() == (type_)); \
+ CATCH_REQUIRE(options.layout() == (layout_))
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
- REQUIRE(tensor.device().type() == Device((device_), (index_)).type()); \
- REQUIRE(tensor.device().index() == Device((device_), (index_)).index()); \
- REQUIRE(tensor.type().scalarType() == (type_)); \
- REQUIRE(tensor.type().layout() == (layout_))
+ CATCH_REQUIRE(tensor.device().type() == Device((device_), (index_)).type()); \
+ CATCH_REQUIRE(tensor.device().index() == Device((device_), (index_)).index()); \
+ CATCH_REQUIRE(tensor.type().scalarType() == (type_)); \
+ CATCH_REQUIRE(tensor.type().layout() == (layout_))
-TEST_CASE("TensorOptions/ConstructsWellFromCUDATypes", "[cuda]") {
+CATCH_TEST_CASE("TensorOptions/ConstructsWellFromCUDATypes", "[cuda]") {
auto options = CUDA(kFloat).options();
REQUIRE_OPTIONS(kCUDA, -1, kFloat, kStrided);
@@ -41,7 +41,7 @@
REQUIRE_OPTIONS(kCUDA, 5, kFloat, kSparse);
}
-TEST_CASE("TensorOptions/ConstructsWellFromCUDATensors", "[multi-cuda]") {
+CATCH_TEST_CASE("TensorOptions/ConstructsWellFromCUDATensors", "[multi-cuda]") {
auto options = empty(5, device(kCUDA).dtype(kDouble)).options();
REQUIRE_OPTIONS(kCUDA, 0, kDouble, kStrided);
@@ -66,7 +66,7 @@
}
}
-TEST_CASE("OptionsGuardCUDA", "[multi-cuda]") {
+CATCH_TEST_CASE("OptionsGuardCUDA", "[multi-cuda]") {
Tensor tensor;
{
OptionsGuard guard(device(kCUDA));
@@ -87,7 +87,7 @@
REQUIRE_TENSOR_OPTIONS(kCUDA, 0, kInt, kStrided);
}
-TEST_CASE("DeviceGuardOptionsGuardInteraction", "[multi-cuda]") {
+CATCH_TEST_CASE("DeviceGuardOptionsGuardInteraction", "[multi-cuda]") {
Tensor tensor;
{
// Check that OptionsGuard respects any active device before construction.
@@ -112,17 +112,17 @@
}
}
-TEST_CASE("DeviceGuardIsMovable", "[cuda]") {
+CATCH_TEST_CASE("DeviceGuardIsMovable", "[cuda]") {
DeviceGuard first(1);
- REQUIRE(first.original_index() == 0);
- REQUIRE(first.last_index() == 1);
+ CATCH_REQUIRE(first.original_index() == 0);
+ CATCH_REQUIRE(first.last_index() == 1);
DeviceGuard second(std::move(first));
- REQUIRE(second.original_index() == 0);
- REQUIRE(second.last_index() == 1);
- REQUIRE(first.original_index() == -1);
+ CATCH_REQUIRE(second.original_index() == 0);
+ CATCH_REQUIRE(second.last_index() == 1);
+ CATCH_REQUIRE(first.original_index() == -1);
DeviceGuard third;
third = std::move(second);
- REQUIRE(third.original_index() == 0);
- REQUIRE(third.last_index() == 1);
- REQUIRE(second.original_index() == -1);
+ CATCH_REQUIRE(third.original_index() == 0);
+ CATCH_REQUIRE(third.last_index() == 1);
+ CATCH_REQUIRE(second.original_index() == -1);
}
diff --git a/torch/csrc/jit/catch_utils.hpp b/torch/csrc/jit/catch_utils.hpp
new file mode 100644
index 0000000..b9b0a87
--- /dev/null
+++ b/torch/csrc/jit/catch_utils.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+#define CATCH_CONFIG_PREFIX_ALL
+#include <catch.hpp>
+
+// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes a warning;
+// define our own version that doesn't warn.
+#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
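
The wrapper above relies on Catch's CATCH_CONFIG_PREFIX_ALL switch, which exposes every test and assertion macro only under a CATCH_-prefixed name (CATCH_TEST_CASE, CATCH_REQUIRE, CATCH_SECTION, and so on) rather than the bare TEST_CASE/REQUIRE spellings. Below is a minimal sketch, not part of this diff, of what a test written against the wrapper looks like; the test name and contents are hypothetical and only illustrate the prefixed API plus the warning-free _CATCH_REQUIRE_THROWS helper defined above.

#define CATCH_CONFIG_MAIN  // only in the one translation unit that provides main()
#include "catch_utils.hpp"

#include <stdexcept>
#include <vector>

CATCH_TEST_CASE("PrefixedCatchMacros/Sketch") {
  std::vector<int> v = {1, 2, 3};
  // Assertions keep their usual semantics, just under the CATCH_ prefix.
  CATCH_REQUIRE(v.size() == 3);

  CATCH_SECTION("out-of-range access throws") {
    // Uses the warning-free wrapper defined in catch_utils.hpp above.
    _CATCH_REQUIRE_THROWS(v.at(10));
  }
}

Centralizing the prefix in one shared header keeps the rest of the rename mechanical: each test file only swaps its catch.hpp include for catch_utils.hpp and adds the CATCH_ prefix to the macros it already uses, which is the pattern repeated in every hunk above.
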
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index 28bf958..3110fb2c 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -1,13 +1,13 @@
#ifdef USE_CATCH
#define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
using Catch::StartsWith;
#else
-#define REQUIRE JIT_ASSERT
+#define CATCH_REQUIRE JIT_ASSERT
#endif
@@ -110,9 +110,9 @@
e.v("what",{"is","this"});
TemplateEnv c(e);
c.s("hi","foo2");
- REQUIRE(e.s("hi") == "foo");
- REQUIRE(c.s("hi") == "foo2");
- REQUIRE(e.v("what")[0] == "is");
+ CATCH_REQUIRE(e.s("hi") == "foo");
+ CATCH_REQUIRE(c.s("hi") == "foo2");
+ CATCH_REQUIRE(e.v("what")[0] == "is");
}
{
@@ -126,7 +126,7 @@
auto s = ct.format(e);
//std::cout << "'" << s << "'\n";
//std::cout << "'" << ct_expect << "'\n";
- REQUIRE(s == ct_expect);
+ CATCH_REQUIRE(s == ct_expect);
}
}
@@ -146,11 +146,11 @@
auto b = at::rand({4,3}, at::kCUDA).transpose(0,1);
auto o = at::zeros({3,4}, at::kCUDA);
auto outputs = debugLaunchGraph(graph, 0, {a,b});
- REQUIRE(outputs.size() == 1);
+ CATCH_REQUIRE(outputs.size() == 1);
auto o2 = a*b;
float max_diff = (o2 - outputs[0]).abs().max().toCDouble();
//std::cout << "max diff: " << max_diff << "\n";
- REQUIRE(max_diff == 0);
+ CATCH_REQUIRE(max_diff == 0);
};
testSimple();
@@ -200,10 +200,10 @@
auto out0 = t16*t5;
auto outputs = debugLaunchGraph(graph, 0, inputs);
- REQUIRE(outputs.size() == graph.outputs().size());
- REQUIRE(out0.is_same_size(outputs.front()));
+ CATCH_REQUIRE(outputs.size() == graph.outputs().size());
+ CATCH_REQUIRE(out0.is_same_size(outputs.front()));
float max_diff = (outputs.front() - out0).abs().max().toCDouble();
- REQUIRE(max_diff < 1e-6);
+ CATCH_REQUIRE(max_diff < 1e-6);
};
testOne(0,0,0,0);
@@ -234,12 +234,12 @@
auto o_r = a*b;
auto o2_r = at::cat({a, o_r}, dim);
auto outputs = debugLaunchGraph(graph, 0, {a,b});
- REQUIRE(outputs.size() == 2);
+ CATCH_REQUIRE(outputs.size() == 2);
float max_diff = (o_r - outputs[0]).abs().max().toCDouble();
- REQUIRE(max_diff == 0);
+ CATCH_REQUIRE(max_diff == 0);
float max_diff2 = (o2_r - outputs[1]).abs().max().toCDouble();
- REQUIRE(max_diff2 == 0);
+ CATCH_REQUIRE(max_diff2 == 0);
};
testConcat(0);
testConcat(1);
@@ -255,58 +255,58 @@
auto four = attr::perm;
Attr attr;
attr.f_(one,3.4)->i_(two,5)->s_(three,"what");
- REQUIRE(attr.f(one) == 3.4);
- REQUIRE(attr.s(three) == "what");
- REQUIRE(attr.i(two) == 5);
+ CATCH_REQUIRE(attr.f(one) == 3.4);
+ CATCH_REQUIRE(attr.s(three) == "what");
+ CATCH_REQUIRE(attr.i(two) == 5);
attr.s_(one,"no");
- REQUIRE(attr.s(one) == "no");
- REQUIRE(attr.hasAttribute(three));
- REQUIRE(!attr.hasAttribute(four));
+ CATCH_REQUIRE(attr.s(one) == "no");
+ CATCH_REQUIRE(attr.hasAttribute(three));
+ CATCH_REQUIRE(!attr.hasAttribute(four));
attr.ss_(two, {"hi", "now"});
- REQUIRE(attr.ss(two).at(1) == "now");
+ CATCH_REQUIRE(attr.ss(two).at(1) == "now");
Attr attr2;
attr2.copyAttributes(attr);
- REQUIRE(attr2.s(one) == "no");
+ CATCH_REQUIRE(attr2.s(one) == "no");
attr2.f_(one,5);
- REQUIRE(attr.s(one) == "no");
- REQUIRE(attr2.f(one) == 5);
+ CATCH_REQUIRE(attr.s(one) == "no");
+ CATCH_REQUIRE(attr2.f(one) == 5);
}
void internedStringsTests () {
- REQUIRE(prim::Param == Symbol::prim("Param"));
- REQUIRE(prim::Return == Symbol::prim("Return"));
- REQUIRE(prim::Return.toUnqualString() == std::string("Return"));
- REQUIRE(prim::Return.toQualString() == std::string("prim::Return"));
+ CATCH_REQUIRE(prim::Param == Symbol::prim("Param"));
+ CATCH_REQUIRE(prim::Return == Symbol::prim("Return"));
+ CATCH_REQUIRE(prim::Return.toUnqualString() == std::string("Return"));
+ CATCH_REQUIRE(prim::Return.toQualString() == std::string("prim::Return"));
Symbol newsym = Symbol::aten("__NEW_SYMBOL");
size_t symstart = newsym;
- REQUIRE(newsym.toQualString() == std::string("aten::__NEW_SYMBOL"));
+ CATCH_REQUIRE(newsym.toQualString() == std::string("aten::__NEW_SYMBOL"));
// TODO: This test is a bit too close to the implementation details.
- REQUIRE(Symbol::aten("What") == symstart+1);
- REQUIRE(Symbol::aten("What2") == symstart+2);
- REQUIRE(Symbol::aten("What") == symstart+1);
- REQUIRE(Symbol::aten("What2") == symstart+2);
- REQUIRE(Symbol(symstart+2).toUnqualString() == std::string("What2"));
+ CATCH_REQUIRE(Symbol::aten("What") == symstart+1);
+ CATCH_REQUIRE(Symbol::aten("What2") == symstart+2);
+ CATCH_REQUIRE(Symbol::aten("What") == symstart+1);
+ CATCH_REQUIRE(Symbol::aten("What2") == symstart+2);
+ CATCH_REQUIRE(Symbol(symstart+2).toUnqualString() == std::string("What2"));
}
void fromQualStringTests() {
- REQUIRE(Symbol::fromQualString("prim::Param") == Symbol::prim("Param"));
- REQUIRE(Symbol::fromQualString("aten::mm") == Symbol::aten("mm"));
- REQUIRE(Symbol::fromQualString("onnx::LSTM") == Symbol::onnx("LSTM"));
- REQUIRE(Symbol::fromQualString("attr::value") == Symbol::attr("value"));
- REQUIRE(Symbol::fromQualString("scope::") == Symbol::scope(""));
- REQUIRE(Symbol::fromQualString("::").toUnqualString() == std::string(""));
- REQUIRE(Symbol::fromQualString("::").ns().toQualString() == std::string("namespaces::"));
- REQUIRE(Symbol::fromQualString("new_ns::param").toUnqualString() == std::string("param"));
- REQUIRE(Symbol::fromQualString("new_ns::param").ns().toUnqualString() == std::string("new_ns"));
- REQUIRE(Symbol::fromQualString("new_ns::param").ns() == Symbol::fromQualString("namespaces::new_ns"));
+ CATCH_REQUIRE(Symbol::fromQualString("prim::Param") == Symbol::prim("Param"));
+ CATCH_REQUIRE(Symbol::fromQualString("aten::mm") == Symbol::aten("mm"));
+ CATCH_REQUIRE(Symbol::fromQualString("onnx::LSTM") == Symbol::onnx("LSTM"));
+ CATCH_REQUIRE(Symbol::fromQualString("attr::value") == Symbol::attr("value"));
+ CATCH_REQUIRE(Symbol::fromQualString("scope::") == Symbol::scope(""));
+ CATCH_REQUIRE(Symbol::fromQualString("::").toUnqualString() == std::string(""));
+ CATCH_REQUIRE(Symbol::fromQualString("::").ns().toQualString() == std::string("namespaces::"));
+ CATCH_REQUIRE(Symbol::fromQualString("new_ns::param").toUnqualString() == std::string("param"));
+ CATCH_REQUIRE(Symbol::fromQualString("new_ns::param").ns().toUnqualString() == std::string("new_ns"));
+ CATCH_REQUIRE(Symbol::fromQualString("new_ns::param").ns() == Symbol::fromQualString("namespaces::new_ns"));
auto bad_inputs = {"scope", ":", ""};
for (auto input : bad_inputs) {
try {
Symbol::fromQualString(input);
- REQUIRE(0);
+ CATCH_REQUIRE(0);
} catch (std::runtime_error c) {
}
}
@@ -467,8 +467,8 @@
std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
//std::cout << almostEqual(outputs[0],hx) << "\n";
- REQUIRE(exactlyEqual(outputs[0],hx));
- REQUIRE(exactlyEqual(outputs[1],cx));
+ CATCH_REQUIRE(exactlyEqual(outputs[0],hx));
+ CATCH_REQUIRE(exactlyEqual(outputs[1],cx));
}
void interpStageTest() {
@@ -500,8 +500,8 @@
std::tie(hx, cx) = lstm(input[0], hx, cx1, w_ih, w_hh);
//std::cout << almostEqual(outputs[0],hx) << "\n";
- REQUIRE(exactlyEqual(outputs[0],hx));
- REQUIRE(exactlyEqual(outputs[1],cx));
+ CATCH_REQUIRE(exactlyEqual(outputs[0],hx));
+ CATCH_REQUIRE(exactlyEqual(outputs[1],cx));
}
using var_meta_type = std::vector<int64_t>;
@@ -554,10 +554,10 @@
}
void assertAllClose(const tensor_list& a, const tensor_list& b) {
- REQUIRE(a.size() == b.size());
+ CATCH_REQUIRE(a.size() == b.size());
for (size_t i = 0; i < a.size(); ++i) {
- REQUIRE(a[i].is_same_size(b[i]));
- REQUIRE(a[i].allclose(b[i]));
+ CATCH_REQUIRE(a[i].is_same_size(b[i]));
+ CATCH_REQUIRE(a[i].allclose(b[i]));
}
}
@@ -654,11 +654,11 @@
std::vector<size_t> expected_captured_outputs = {1};
std::vector<size_t> expected_input_vjps = {0, 1};
std::vector<size_t> expected_output_vjps = {0, 1};
- REQUIRE(grad_spec.f_real_outputs == 1);
- REQUIRE(grad_spec.df_input_captured_inputs == expected_captured_inputs);
- REQUIRE(grad_spec.df_input_captured_outputs == expected_captured_outputs);
- REQUIRE(grad_spec.df_input_vjps == expected_input_vjps);
- REQUIRE(grad_spec.df_output_vjps == expected_output_vjps);
+ CATCH_REQUIRE(grad_spec.f_real_outputs == 1);
+ CATCH_REQUIRE(grad_spec.df_input_captured_inputs == expected_captured_inputs);
+ CATCH_REQUIRE(grad_spec.df_input_captured_outputs == expected_captured_outputs);
+ CATCH_REQUIRE(grad_spec.df_input_vjps == expected_input_vjps);
+ CATCH_REQUIRE(grad_spec.df_output_vjps == expected_output_vjps);
out << "testDifferentiate\n";
out << *grad_spec.f;
out << *grad_spec.df;
@@ -684,11 +684,11 @@
auto grad_spec = differentiate(graph);
std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
- REQUIRE(grad_spec.f_real_outputs == 2); // we need one temporary %4 = (d + a)
- REQUIRE(grad_spec.df_input_captured_inputs == std::vector<size_t>({0}));
- REQUIRE(grad_spec.df_input_captured_outputs == std::vector<size_t>({2}));
- REQUIRE(grad_spec.df_input_vjps == expected_input_vjps);
- REQUIRE(grad_spec.df_output_vjps == expected_output_vjps);
+ CATCH_REQUIRE(grad_spec.f_real_outputs == 2); // we need one temporary %4 = (d + a)
+ CATCH_REQUIRE(grad_spec.df_input_captured_inputs == std::vector<size_t>({0}));
+ CATCH_REQUIRE(grad_spec.df_input_captured_outputs == std::vector<size_t>({2}));
+ CATCH_REQUIRE(grad_spec.df_input_vjps == expected_input_vjps);
+ CATCH_REQUIRE(grad_spec.df_output_vjps == expected_output_vjps);
out << "testDifferentiateWithRequiresGrad\n";
out << *grad_spec.f;
out << *grad_spec.df;
@@ -718,7 +718,7 @@
}
bool isEqual(const CompleteArgumentInfo & ti, const autograd::Variable & v) {
- REQUIRE(ti.isTensor());
+ CATCH_REQUIRE(ti.isTensor());
if(!ti.defined())
return ti.defined() == v.defined();
return
@@ -754,34 +754,34 @@
CompleteArgumentSpec a(true, list);
CompleteArgumentSpec b(true, list);
- REQUIRE(a.hashCode() == b.hashCode());
+ CATCH_REQUIRE(a.hashCode() == b.hashCode());
- REQUIRE(a == b);
+ CATCH_REQUIRE(a == b);
CompleteArgumentSpec d(true, list2);
- REQUIRE(d == a);
- REQUIRE(d.hashCode() == a.hashCode());
+ CATCH_REQUIRE(d == a);
+ CATCH_REQUIRE(d.hashCode() == a.hashCode());
for(size_t i = 0; i < list.size(); ++i) {
- REQUIRE(isEqual(a.at(i), list[i].toTensor()));
+ CATCH_REQUIRE(isEqual(a.at(i), list[i].toTensor()));
}
CompleteArgumentSpec no_grad(/*with_grad=*/false, list);
- REQUIRE(no_grad != a);
+ CATCH_REQUIRE(no_grad != a);
std::unordered_set<CompleteArgumentSpec> spec;
spec.insert(std::move(a));
- REQUIRE(spec.count(b) > 0);
- REQUIRE(spec.count(no_grad) == 0);
+ CATCH_REQUIRE(spec.count(b) > 0);
+ CATCH_REQUIRE(spec.count(no_grad) == 0);
spec.insert(std::move(no_grad));
- REQUIRE(spec.count(CompleteArgumentSpec(true,list)) == 1);
+ CATCH_REQUIRE(spec.count(CompleteArgumentSpec(true,list)) == 1);
list2[1].toTensor().transpose_(0,1);
CompleteArgumentSpec c(true, list2); // same as list, except for one stride
- REQUIRE(!(c == a));
- REQUIRE(spec.count(c) == 0);
+ CATCH_REQUIRE(!(c == a));
+ CATCH_REQUIRE(spec.count(c) == 0);
Stack stack = { var(CF, {1,2}, true), 3, var(CF, {1,2}, true) };
CompleteArgumentSpec with_const(true, stack);
- REQUIRE(with_const.at(2).sizes().size() == 2);
+ CATCH_REQUIRE(with_const.at(2).sizes().size() == 2);
}
void testGraphExecutor() {
@@ -802,11 +802,11 @@
GraphExecutor executor(g);
auto stack = createStack({v(input), v(hx), v(cx), v(w_ih), v(w_hh)});
executor.run(stack);
- REQUIRE(stack.size() == 2);
+ CATCH_REQUIRE(stack.size() == 2);
at::Tensor r0, r1;
std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh);
- REQUIRE(almostEqual(Variable(stack[0].toTensor()).data(), r0));
- REQUIRE(almostEqual(Variable(stack[1].toTensor()).data(), r1));
+ CATCH_REQUIRE(almostEqual(Variable(stack[0].toTensor()).data(), r0));
+ CATCH_REQUIRE(almostEqual(Variable(stack[1].toTensor()).data(), r1));
}
void testBlocks(std::ostream & out) {
@@ -877,11 +877,11 @@
auto run_binary = [&](const std::string & name, int64_t a, int64_t b) {
return V(run(name, {L(a), L(b)})[0]);
};
- REQUIRE(2 == run_binary("if_test", 1, 2));
- REQUIRE(3 == run_binary("if_test", 3, 2));
- REQUIRE(2 == run_binary("if_one", 2, 3));
- REQUIRE(2 == run_binary("if_one", 3, 2));
- REQUIRE(256 == run_binary("while_test",2,0));
+ CATCH_REQUIRE(2 == run_binary("if_test", 1, 2));
+ CATCH_REQUIRE(3 == run_binary("if_test", 3, 2));
+ CATCH_REQUIRE(2 == run_binary("if_one", 2, 3));
+ CATCH_REQUIRE(2 == run_binary("if_one", 3, 2));
+ CATCH_REQUIRE(256 == run_binary("while_test",2,0));
}
void testIValue() {
@@ -939,18 +939,18 @@
RegisterOperators reg({createOperator(
"foo::bar", [](double a, at::Tensor b) { return a + b; })});
auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
- REQUIRE(ops.size() == 1);
+ CATCH_REQUIRE(ops.size() == 1);
auto& op = ops.front();
- REQUIRE(op->schema().name == "foo::bar");
+ CATCH_REQUIRE(op->schema().name == "foo::bar");
- REQUIRE(op->schema().arguments.size() == 2);
- REQUIRE(op->schema().arguments[0].name == "_0");
- REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType);
- REQUIRE(op->schema().arguments[1].name == "_1");
- REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType);
+ CATCH_REQUIRE(op->schema().arguments.size() == 2);
+ CATCH_REQUIRE(op->schema().arguments[0].name == "_0");
+ CATCH_REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType);
+ CATCH_REQUIRE(op->schema().arguments[1].name == "_1");
+ CATCH_REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType);
- REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType);
+ CATCH_REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType);
Stack stack;
push(stack, 2.0f, autograd::make_variable(at::ones(5)));
@@ -958,7 +958,7 @@
at::Tensor output;
pop(stack, output);
- REQUIRE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
+ CATCH_REQUIRE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
}
{
RegisterOperators reg({createOperator(
@@ -967,19 +967,19 @@
auto& ops =
getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
- REQUIRE(ops.size() == 1);
+ CATCH_REQUIRE(ops.size() == 1);
auto& op = ops.front();
- REQUIRE(op->schema().name == "foo::bar_with_schema");
+ CATCH_REQUIRE(op->schema().name == "foo::bar_with_schema");
- REQUIRE(op->schema().arguments.size() == 2);
- REQUIRE(op->schema().arguments[0].name == "a");
- REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType);
- REQUIRE(op->schema().arguments[1].name == "b");
- REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType);
+ CATCH_REQUIRE(op->schema().arguments.size() == 2);
+ CATCH_REQUIRE(op->schema().arguments[0].name == "a");
+ CATCH_REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType);
+ CATCH_REQUIRE(op->schema().arguments[1].name == "b");
+ CATCH_REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType);
- REQUIRE(op->schema().returns.size() == 1);
- REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType);
+ CATCH_REQUIRE(op->schema().returns.size() == 1);
+ CATCH_REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType);
Stack stack;
push(stack, 2.0f, autograd::make_variable(at::ones(5)));
@@ -987,7 +987,7 @@
at::Tensor output;
pop(stack, output);
- REQUIRE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
+ CATCH_REQUIRE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
}
{
// Check that lists work well.
@@ -999,21 +999,21 @@
auto& ops =
getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
- REQUIRE(ops.size() == 1);
+ CATCH_REQUIRE(ops.size() == 1);
auto& op = ops.front();
- REQUIRE(op->schema().name == "foo::lists");
+ CATCH_REQUIRE(op->schema().name == "foo::lists");
- REQUIRE(op->schema().arguments.size() == 3);
- REQUIRE(op->schema().arguments[0].name == "ints");
- REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofInts()));
- REQUIRE(op->schema().arguments[1].name == "floats");
- REQUIRE(op->schema().arguments[1].type->isSubtypeOf(ListType::ofFloats()));
- REQUIRE(op->schema().arguments[2].name == "tensors");
- REQUIRE(op->schema().arguments[2].type->isSubtypeOf(ListType::ofTensors()));
+ CATCH_REQUIRE(op->schema().arguments.size() == 3);
+ CATCH_REQUIRE(op->schema().arguments[0].name == "ints");
+ CATCH_REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofInts()));
+ CATCH_REQUIRE(op->schema().arguments[1].name == "floats");
+ CATCH_REQUIRE(op->schema().arguments[1].type->isSubtypeOf(ListType::ofFloats()));
+ CATCH_REQUIRE(op->schema().arguments[2].name == "tensors");
+ CATCH_REQUIRE(op->schema().arguments[2].type->isSubtypeOf(ListType::ofTensors()));
- REQUIRE(op->schema().returns.size() == 1);
- REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofFloats()));
+ CATCH_REQUIRE(op->schema().returns.size() == 1);
+ CATCH_REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofFloats()));
Stack stack;
push(stack, std::vector<int64_t>{1, 2});
@@ -1023,9 +1023,9 @@
std::vector<double> output;
pop(stack, output);
- REQUIRE(output.size() == 2);
- REQUIRE(output[0] == 1.0);
- REQUIRE(output[1] == 2.0);
+ CATCH_REQUIRE(output.size() == 2);
+ CATCH_REQUIRE(output[0] == 1.0);
+ CATCH_REQUIRE(output[1] == 2.0);
}
{
RegisterOperators reg(
@@ -1034,17 +1034,17 @@
auto& ops =
getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
- REQUIRE(ops.size() == 1);
+ CATCH_REQUIRE(ops.size() == 1);
auto& op = ops.front();
- REQUIRE(op->schema().name == "foo::lists2");
+ CATCH_REQUIRE(op->schema().name == "foo::lists2");
- REQUIRE(op->schema().arguments.size() == 1);
- REQUIRE(op->schema().arguments[0].name == "tensors");
- REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofTensors()));
+ CATCH_REQUIRE(op->schema().arguments.size() == 1);
+ CATCH_REQUIRE(op->schema().arguments[0].name == "tensors");
+ CATCH_REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofTensors()));
- REQUIRE(op->schema().returns.size() == 1);
- REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofTensors()));
+ CATCH_REQUIRE(op->schema().returns.size() == 1);
+ CATCH_REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofTensors()));
Stack stack;
push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))});
@@ -1052,31 +1052,31 @@
std::vector<at::Tensor> output;
pop(stack, output);
- REQUIRE(output.size() == 1);
- REQUIRE(output[0].allclose(autograd::make_variable(at::ones(5))));
+ CATCH_REQUIRE(output.size() == 1);
+ CATCH_REQUIRE(output[0].allclose(autograd::make_variable(at::ones(5))));
}
{
#ifdef USE_CATCH
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
createOperator(
"foo::bar_with_bad_schema(Tensor a) -> Tensor",
[](double a, at::Tensor b) { return a + b; }),
StartsWith("Inferred 2 argument(s) for operator implementation, "
"but the provided schema specified 1 argument(s)."));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
createOperator(
"foo::bar_with_bad_schema(Tensor a) -> Tensor",
[](double a) { return a; }),
StartsWith("Inferred type for argument #0 was float, "
"but the provided schema specified type Dynamic "
"for the argument in that position"));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
createOperator(
"foo::bar_with_bad_schema(float a) -> (float, float)",
[](double a) { return a; }),
StartsWith("Inferred 1 return value(s) for operator implementation, "
"but the provided schema specified 2 return value(s)."));
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
createOperator(
"foo::bar_with_bad_schema(float a) -> Tensor",
[](double a) { return a; }),
@@ -1109,7 +1109,7 @@
break;
}
}
- REQUIRE(contains_traced_op);
+ CATCH_REQUIRE(contains_traced_op);
}
{
#ifdef USE_CATCH
@@ -1124,7 +1124,7 @@
Stack stack;
push(stack, std::vector<double>{1.0});
- REQUIRE_THROWS_WITH(
+ CATCH_REQUIRE_THROWS_WITH(
op.getOperation()(stack),
StartsWith("Tracing float lists currently not supported!"));
#endif
@@ -1156,42 +1156,42 @@
#ifdef USE_CATCH
-TEST_CASE( "jit test CPU", "[cpu]" ) {
+CATCH_TEST_CASE( "jit test CPU", "[cpu]" ) {
std::stringstream out;
- SECTION( "control flow" )
+ CATCH_SECTION( "control flow" )
testControlFlow();
- SECTION( "blocks" )
+ CATCH_SECTION( "blocks" )
testBlocks(out);
- SECTION( "create autodiff subgraphs" )
+ CATCH_SECTION( "create autodiff subgraphs" )
testCreateAutodiffSubgraphs(out);
- SECTION( "differentiate" )
+ CATCH_SECTION( "differentiate" )
testDifferentiate(out);
- SECTION( "differentiate with requires grad" )
+ CATCH_SECTION( "differentiate with requires grad" )
testDifferentiateWithRequiresGrad(out);
- SECTION( "AD formulas" )
+ CATCH_SECTION( "AD formulas" )
testADFormulas();
- SECTION( "code template" )
+ CATCH_SECTION( "code template" )
codeTemplateTest();
- SECTION( "attributes" )
+ CATCH_SECTION( "attributes" )
attributesTest();
- SECTION( "interned strings" )
+ CATCH_SECTION( "interned strings" )
internedStringsTests();
- SECTION( "custom operators" )
+ CATCH_SECTION( "custom operators" )
testCustomOperators();
}
-TEST_CASE( "jit test CUDA", "[cuda]" ) {
+CATCH_TEST_CASE( "jit test CUDA", "[cuda]" ) {
- SECTION( "graph executor" )
+ CATCH_SECTION( "graph executor" )
testGraphExecutor();
- SECTION( "fusion" )
+ CATCH_SECTION( "fusion" )
fusionTests();
- SECTION( "interp" )
+ CATCH_SECTION( "interp" )
interpTest();
- SECTION( "interp stage" )
+ CATCH_SECTION( "interp stage" )
interpStageTest();
- SECTION( "argument spec" )
+ CATCH_SECTION( "argument spec" )
argumentSpecTest();
}