Use the CATCH prefix for Catch test macros to avoid name conflicts with Caffe2.

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11780

Differential Revision: D9889925

Pulled By: gchanan

fbshipit-source-id: 5eca849c36ced00b8ae7482b7945b445a3e1687e
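
For context, here is a minimal sketch (not part of this patch) of how the prefix configuration behaves. Defining CATCH_CONFIG_PREFIX_ALL before including catch.hpp, as the new catch_utils.hpp header below does, makes the framework expose only the CATCH_-prefixed macro names (CATCH_TEST_CASE, CATCH_REQUIRE, ...), so the unprefixed TEST_CASE/REQUIRE names no longer collide with Caffe2's test macros. The test name and assertion here are illustrative only.

    #define CATCH_CONFIG_PREFIX_ALL   // expose only CATCH_-prefixed macro names
    #define CATCH_CONFIG_MAIN         // let Catch provide main() for this test binary
    #include <catch.hpp>

    CATCH_TEST_CASE("prefix demo", "[example]") {
      CATCH_REQUIRE(1 + 1 == 2);   // compiles: the prefixed form is defined
      // REQUIRE(1 + 1 == 2);      // would not compile: unprefixed names are disabled
    }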
diff --git a/aten/src/ATen/test/apply_test.cpp b/aten/src/ATen/test/apply_test.cpp
index 986f599..fc39ecc 100644
--- a/aten/src/ATen/test/apply_test.cpp
+++ b/aten/src/ATen/test/apply_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "cuda.h"
 #include "cuda_runtime.h"
@@ -11,111 +11,111 @@
 */
 #ifndef _WIN32
 
-TEST_CASE("2D Contiguous", "Collapses a 2D contiguous tensor to 1D contiguous") {
+CATCH_TEST_CASE("2D Contiguous", "Collapses a 2D contiguous tensor to 1D contiguous") {
     int sizes[] = {4, 4};
     int strides[] = {4, 1};
     ::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 2, sizes, strides};
     ti.collapseDims();
-    REQUIRE(ti.dims == 1);
-    REQUIRE(ti.sizes[0] == (4 * 4));
+    CATCH_REQUIRE(ti.dims == 1);
+    CATCH_REQUIRE(ti.sizes[0] == (4 * 4));
 }
 
-TEST_CASE("3D Contiguous", "Collapses a 3D contiguous tensor to a 1D contiguous") {
+CATCH_TEST_CASE("3D Contiguous", "Collapses a 3D contiguous tensor to a 1D contiguous") {
     int sizes[] = {6, 3, 7};
     int strides[] = {3 * 7, 7, 1};
     ::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
     ti.collapseDims();
-    REQUIRE(ti.dims == 1);
-    REQUIRE(ti.sizes[0] == (6 * 3 * 7));
+    CATCH_REQUIRE(ti.dims == 1);
+    CATCH_REQUIRE(ti.sizes[0] == (6 * 3 * 7));
 }
 
-TEST_CASE("3D Partial Collapse", "Collapses a 3D noncontiguous tensor to a 2D tensor") {
+CATCH_TEST_CASE("3D Partial Collapse", "Collapses a 3D noncontiguous tensor to a 2D tensor") {
     int sizes[] = {4, 3, 2};
     int strides[] = {3 * 3, 3, 1};
     ::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
     ti.collapseDims();
-    REQUIRE(ti.dims == 2);
-    REQUIRE(ti.sizes[0] == (4 * 3));
-    REQUIRE(ti.sizes[1] == 2);
+    CATCH_REQUIRE(ti.dims == 2);
+    CATCH_REQUIRE(ti.sizes[0] == (4 * 3));
+    CATCH_REQUIRE(ti.sizes[1] == 2);
 }
 
-TEST_CASE("2D Strided Collapse", "Collapses a 2D skip contiguous tensor to a 1D skip contiguous tensor") {
+CATCH_TEST_CASE("2D Strided Collapse", "Collapses a 2D skip contiguous tensor to a 1D skip contiguous tensor") {
     int sizes[] = {3, 2};
     int strides[] = {2 * 2, 2};
     ::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 2, sizes, strides};
     ti.collapseDims();
-    REQUIRE(ti.dims == 1);
-    REQUIRE(ti.sizes[0] == (3 * 2));
-    REQUIRE(ti.strides[0] == 2);
+    CATCH_REQUIRE(ti.dims == 1);
+    CATCH_REQUIRE(ti.sizes[0] == (3 * 2));
+    CATCH_REQUIRE(ti.strides[0] == 2);
 }
 
-TEST_CASE("4D Partial Strided Collapse", "Collapses a 4D tensor to a 2D tensor"){
+CATCH_TEST_CASE("4D Partial Strided Collapse", "Collapses a 4D tensor to a 2D tensor"){
     int sizes[] = {3, 6, 5, 2};
     int strides[] = {6 * 22, 22, 2 * 2, 2};
     ::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
     ti.collapseDims();
-    REQUIRE(ti.dims == 2);
-    REQUIRE(ti.sizes[0] == (3 * 6));
-    REQUIRE(ti.strides[0] == 22);
-    REQUIRE(ti.sizes[1] == (5 * 2));
-    REQUIRE(ti.strides[1] == 2);
+    CATCH_REQUIRE(ti.dims == 2);
+    CATCH_REQUIRE(ti.sizes[0] == (3 * 6));
+    CATCH_REQUIRE(ti.strides[0] == 22);
+    CATCH_REQUIRE(ti.sizes[1] == (5 * 2));
+    CATCH_REQUIRE(ti.strides[1] == 2);
 }
 
-TEST_CASE("Collapsing Zeros and Ones", "Collapses a 5D tensor to a 1D tensor") {
+CATCH_TEST_CASE("Collapsing Zeros and Ones", "Collapses a 5D tensor to a 1D tensor") {
     int sizes[] = {1, 10, 1, 5, 4};
     int strides[] = {4, 0, 16, 0, 1};
     ::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 5, sizes, strides};
     ti.collapseDims();
-    REQUIRE(ti.dims == 2);
-    REQUIRE(ti.sizes[0] == (10 * 5));
-    REQUIRE(ti.strides[0] == 0);
-    REQUIRE(ti.sizes[1] == 4);
-    REQUIRE(ti.strides[1] == 1);
+    CATCH_REQUIRE(ti.dims == 2);
+    CATCH_REQUIRE(ti.sizes[0] == (10 * 5));
+    CATCH_REQUIRE(ti.strides[0] == 0);
+    CATCH_REQUIRE(ti.sizes[1] == 4);
+    CATCH_REQUIRE(ti.strides[1] == 1);
 }
 
-TEST_CASE("Collapsing to a Point Tensor", "Collapses a 3D tensor to a point tensor") {
+CATCH_TEST_CASE("Collapsing to a Point Tensor", "Collapses a 3D tensor to a point tensor") {
     int sizes[] = {1, 1, 1};
     int strides[] = {17, 12, 3};
     ::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
-    REQUIRE(ti.collapseDims() == 0);
-    REQUIRE(ti.dims == 1);
-    REQUIRE(ti.sizes[0] == 1);
-    REQUIRE(ti.strides[0] == 1);
+    CATCH_REQUIRE(ti.collapseDims() == 0);
+    CATCH_REQUIRE(ti.dims == 1);
+    CATCH_REQUIRE(ti.sizes[0] == 1);
+    CATCH_REQUIRE(ti.strides[0] == 1);
 }
 
-TEST_CASE("Excluding in a 4D Contiguous", "Collapses a 4D tensor to a 3D tensor") {
+CATCH_TEST_CASE("Excluding in a 4D Contiguous", "Collapses a 4D tensor to a 3D tensor") {
     int sizes[] = {3, 6, 5, 2};
     int strides[] = {6 * 22, 22, 2 * 2, 2};
     ::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
-    REQUIRE(ti.collapseDims(1) == 1);
-    REQUIRE(ti.dims == 3);
-    REQUIRE(ti.sizes[0] == 3);
-    REQUIRE(ti.strides[0] == (6 * 22));
-    REQUIRE(ti.sizes[1] == 6);
-    REQUIRE(ti.strides[1] == 22);
-    REQUIRE(ti.sizes[2] == (5 * 2));
-    REQUIRE(ti.strides[2] == 2);
+    CATCH_REQUIRE(ti.collapseDims(1) == 1);
+    CATCH_REQUIRE(ti.dims == 3);
+    CATCH_REQUIRE(ti.sizes[0] == 3);
+    CATCH_REQUIRE(ti.strides[0] == (6 * 22));
+    CATCH_REQUIRE(ti.sizes[1] == 6);
+    CATCH_REQUIRE(ti.strides[1] == 22);
+    CATCH_REQUIRE(ti.sizes[2] == (5 * 2));
+    CATCH_REQUIRE(ti.strides[2] == 2);
 }
 
-TEST_CASE("Roving Exclusion", "Collapses a 4D tensor to a 3D tensor") {
+CATCH_TEST_CASE("Roving Exclusion", "Collapses a 4D tensor to a 3D tensor") {
     int sizes[] = {3, 6, 5, 2};
     int strides[] = {6 * 22, 22, 2 * 2, 2};
     ::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
-    REQUIRE(ti.collapseDims(2) == 1);
-    REQUIRE(ti.dims == 3);
-    REQUIRE(ti.sizes[0] == (3 * 6));
-    REQUIRE(ti.strides[0] == 22);
-    REQUIRE(ti.sizes[1] == 5);
-    REQUIRE(ti.strides[1] == 4);
-    REQUIRE(ti.sizes[2] == 2);
-    REQUIRE(ti.strides[2] == 2);
+    CATCH_REQUIRE(ti.collapseDims(2) == 1);
+    CATCH_REQUIRE(ti.dims == 3);
+    CATCH_REQUIRE(ti.sizes[0] == (3 * 6));
+    CATCH_REQUIRE(ti.strides[0] == 22);
+    CATCH_REQUIRE(ti.sizes[1] == 5);
+    CATCH_REQUIRE(ti.strides[1] == 4);
+    CATCH_REQUIRE(ti.sizes[2] == 2);
+    CATCH_REQUIRE(ti.strides[2] == 2);
 }
 
-TEST_CASE("Invalid Exclusion", "Attempts to exclude a nonexisting dimension") {
+CATCH_TEST_CASE("Invalid Exclusion", "Attempts to exclude a nonexisting dimension") {
     int sizes[] = {1, 1, 1};
     int strides[] = {17, 12, 3};
     ::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
-    REQUIRE_THROWS(ti.collapseDims(5));
+    _CATCH_REQUIRE_THROWS(ti.collapseDims(5));
 } 
 
 #endif
diff --git a/aten/src/ATen/test/apply_utils_test.cpp b/aten/src/ATen/test/apply_utils_test.cpp
index 38027ba..22be6de 100644
--- a/aten/src/ATen/test/apply_utils_test.cpp
+++ b/aten/src/ATen/test/apply_utils_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "ATen/CPUApplyUtils.h"
@@ -108,32 +108,32 @@
   });
 }
 
-TEST_CASE("apply utils test 2-dim small contiguous", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 2-dim small contiguous", "[cpu]") {
   manual_seed(123, at::kCPU);
   test(CPU(kDouble), {2, 1}, -1, -1);
 }
 
-TEST_CASE("apply utils test 2-dim small", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 2-dim small", "[cpu]") {
   manual_seed(123, at::kCPU);
   test(CPU(kDouble), {2, 1});
 }
 
-TEST_CASE("apply utils test 2-dim", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 2-dim", "[cpu]") {
   manual_seed(123, at::kCPU);
   test(CPU(kDouble), {20, 10});
 }
 
-TEST_CASE("apply utils test 3-dim", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 3-dim", "[cpu]") {
   manual_seed(123, at::kCPU);
   test(CPU(kDouble), {3, 4, 2});
 }
 
-TEST_CASE("apply utils test 3-dim medium", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 3-dim medium", "[cpu]") {
   manual_seed(123, at::kCPU);
   test(CPU(kDouble), {3, 40, 2});
 }
 
-TEST_CASE("apply utils test 10-dim", "[cpu]") {
+CATCH_TEST_CASE("apply utils test 10-dim", "[cpu]") {
   manual_seed(123, at::kCPU);
   test(CPU(kDouble), {3, 4, 2, 5, 2, 1, 3, 4, 2, 3});
 }
diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp
index 9498812..c64fdec 100644
--- a/aten/src/ATen/test/basic.cpp
+++ b/aten/src/ATen/test/basic.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "ATen/core/Reduction.h"
@@ -20,66 +20,66 @@
 using Catch::Matchers::StartsWith;
 
 static void test(Type & type) {
-  SECTION( "resize" ) {
+  CATCH_SECTION( "resize" ) {
     auto a = type.tensor();
     a.resize_({3,4});
-    REQUIRE(a.numel() == 12);
+    CATCH_REQUIRE(a.numel() == 12);
     a.resize_({5, 7});
-    REQUIRE(a.numel() == 35);
+    CATCH_REQUIRE(a.numel() == 35);
 
   }
 
-  SECTION( "ones and dot" ) {
+  CATCH_SECTION( "ones and dot" ) {
     Tensor b0 = ones({1, 1}, type);
-    REQUIRE(2 == (b0+b0).sum().toCDouble());
+    CATCH_REQUIRE(2 == (b0+b0).sum().toCDouble());
 
     Tensor b1 = ones({1, 2}, type);
-    REQUIRE(4 == (b1+b1).sum().toCDouble());
+    CATCH_REQUIRE(4 == (b1+b1).sum().toCDouble());
 
     Tensor b = ones({3, 4}, type);
-    REQUIRE(24 == (b+b).sum().toCDouble());
-    REQUIRE(12 == b.numel());
-    REQUIRE(b.view(-1).dot(b.view(-1)).toCDouble() == 12);
+    CATCH_REQUIRE(24 == (b+b).sum().toCDouble());
+    CATCH_REQUIRE(12 == b.numel());
+    CATCH_REQUIRE(b.view(-1).dot(b.view(-1)).toCDouble() == 12);
   }
 
-  SECTION( "rand" ) {
+  CATCH_SECTION( "rand" ) {
     for(auto i = 0; i < 10; i++) {
       Tensor a = rand({3,4}, type.toScalarType(i % 2 == 0 ? kFloat : kDouble));
     }
   }
 
-  SECTION( "sort" ) {
+  CATCH_SECTION( "sort" ) {
     Tensor b = rand({3, 4}, type);
 
     auto z = b.sort(1);
     auto z_sorted = std::get<0>(z);
 
-    REQUIRE(z_sorted[0][0].toCFloat() < z_sorted[0][1].toCFloat());
+    CATCH_REQUIRE(z_sorted[0][0].toCFloat() < z_sorted[0][1].toCFloat());
   }
 
   if(type.backend() != Backend::CUDA)
-  SECTION( "randperm" ) {
+  CATCH_SECTION( "randperm" ) {
     Tensor b = randperm(15, type);
     Tensor rv, ri;
     std::tie(rv, ri) = sort(b, 0);
-    REQUIRE(rv[0].toCFloat() <= rv[1].toCFloat());
+    CATCH_REQUIRE(rv[0].toCFloat() <= rv[1].toCFloat());
   }
 
-  SECTION( "context" ) {
+  CATCH_SECTION( "context" ) {
     std::stringstream ss;
     ss << "context: " << std::hex << (int64_t)&globalContext() << std::endl;
   }
 
-  SECTION( "add" ) {
+  CATCH_SECTION( "add" ) {
     Tensor a = rand({3, 4}, type);
     Tensor b = rand({3, 4}, type);
     Tensor c = add(a, add(a, b));
     //TODO:0-dim Tensor d(3.f);
     Scalar d = 3.f;
-    REQUIRE( add(c, d).allclose(a + a + b + d) );
+    CATCH_REQUIRE( add(c, d).allclose(a + a + b + d) );
   }
 
-  SECTION( "loads of adds" ) {
+  CATCH_SECTION( "loads of adds" ) {
     auto begin = std::chrono::high_resolution_clock::now();
     Tensor d = ones({3, 4}, type);
     Tensor r = zeros({3, 4}, type);
@@ -89,10 +89,10 @@
     auto end = std::chrono::high_resolution_clock::now();
     //TODO TEST PERF?
     std::cout << std::dec << "   " << std::chrono::duration_cast<std::chrono::milliseconds>(end-begin).count() << " ms" << std::endl;
-    REQUIRE(norm(100000*d).toCDouble() == norm(r).toCDouble());
+    CATCH_REQUIRE(norm(100000*d).toCDouble() == norm(r).toCDouble());
   }
 
-  SECTION( "loads of adds (with copy)" ) {
+  CATCH_SECTION( "loads of adds (with copy)" ) {
     auto begin = std::chrono::high_resolution_clock::now();
     Tensor d = ones({3, 4}, type);
     Tensor r = zeros({3, 4}, type);
@@ -102,59 +102,59 @@
     auto end = std::chrono::high_resolution_clock::now();
     //TODO TEST PERF?
     std::cout << std::dec << "   " << std::chrono::duration_cast<std::chrono::milliseconds>(end-begin).count() << " ms" << std::endl;
-    REQUIRE(norm(100000*d).toCDouble() == norm(r).toCDouble());
+    CATCH_REQUIRE(norm(100000*d).toCDouble() == norm(r).toCDouble());
   }
 
-  SECTION( "isContiguous" ) {
+  CATCH_SECTION( "isContiguous" ) {
     Tensor a = rand({3, 4}, type);
-    REQUIRE(a.is_contiguous());
+    CATCH_REQUIRE(a.is_contiguous());
     a = a.transpose(0, 1);
-    REQUIRE(!a.is_contiguous());
+    CATCH_REQUIRE(!a.is_contiguous());
   }
 
-  SECTION( "permute" ) {
+  CATCH_SECTION( "permute" ) {
     Tensor a = rand({3, 4, 5}, type);
     Tensor b = a.permute({1, 2, 0});
-    REQUIRE(b.sizes().equals({4, 5, 3}));
-    REQUIRE(b.strides().equals({5, 1, 20}));
+    CATCH_REQUIRE(b.sizes().equals({4, 5, 3}));
+    CATCH_REQUIRE(b.strides().equals({5, 1, 20}));
   }
 
-  SECTION( "mm" ) {
+  CATCH_SECTION( "mm" ) {
     Tensor a = rand({3, 4}, type);
     Tensor b = rand({4}, type);
     Tensor c = mv(a, b);
-    REQUIRE(c.equal(addmv(zeros({3}, type), a, b, 0, 1)));
+    CATCH_REQUIRE(c.equal(addmv(zeros({3}, type), a, b, 0, 1)));
   }
 
-  SECTION( "squeeze" ) {
+  CATCH_SECTION( "squeeze" ) {
     Tensor a = rand({2, 1}, type);
     Tensor b = squeeze(a);
-    REQUIRE(b.dim() == 1);
+    CATCH_REQUIRE(b.dim() == 1);
     a = rand({1}, type);
     b = squeeze(a);
     //TODO 0-dim squeeze
-    REQUIRE(a[0].equal(b));
+    CATCH_REQUIRE(a[0].equal(b));
   }
 
-  SECTION( "copy" ) {
+  CATCH_SECTION( "copy" ) {
     Tensor a = zeros({4, 3}, type);
     Tensor e = rand({4, 3}, type);
     a.copy_(e);
-    REQUIRE(a.equal(e));
+    CATCH_REQUIRE(a.equal(e));
   }
 
-  SECTION( "copy (broadcasting)" ) {
+  CATCH_SECTION( "copy (broadcasting)" ) {
     Tensor a = zeros({4, 3}, type);
     Tensor e = rand({3}, type);
     a.copy_(e);
     for (int i = 0; i < 4; ++i) {
-      REQUIRE(a[i].equal(e));
+      CATCH_REQUIRE(a[i].equal(e));
     }
   }
 
-  SECTION( "abs(value)" ) {
+  CATCH_SECTION( "abs(value)" ) {
     Tensor r = at::abs(type.scalarTensor(-3));
-    REQUIRE(r.toCInt() == 3);
+    CATCH_REQUIRE(r.toCInt() == 3);
   }
 
 //TODO(zach): operator overloads
@@ -168,120 +168,120 @@
   }
 #endif
 
-  SECTION( "adding a value with a scalar" ) {
+  CATCH_SECTION( "adding a value with a scalar" ) {
     Tensor a = rand({4, 3}, type);
-    REQUIRE((ones({4,3}, type) + a).equal(add(a,1)));
+    CATCH_REQUIRE((ones({4,3}, type) + a).equal(add(a,1)));
   }
 
-  SECTION( "select" ) {
+  CATCH_SECTION( "select" ) {
     Tensor a = rand({3, 7}, type);
     auto a_13 = select(a, 1, 3);
     auto a_13_02 = select(select(a, 1, 3), 0, 2);
-    REQUIRE( a[0][3].equal(a_13[0]) );
-    REQUIRE( a[2][3].equal(a_13_02) );
+    CATCH_REQUIRE( a[0][3].equal(a_13[0]) );
+    CATCH_REQUIRE( a[2][3].equal(a_13_02) );
   }
 
-  SECTION( "zero-dim" ) {
+  CATCH_SECTION( "zero-dim" ) {
     Tensor a =  type.scalarTensor(4); //rand(type, {1});
 
     Tensor b = rand({3,4}, type);
-    REQUIRE((a + a).dim() == 0);
-    REQUIRE((1 + a).dim() == 0);
-    REQUIRE((b + a).dim() == 2);
-    REQUIRE((a + b).dim() == 2);
+    CATCH_REQUIRE((a + a).dim() == 0);
+    CATCH_REQUIRE((1 + a).dim() == 0);
+    CATCH_REQUIRE((b + a).dim() == 2);
+    CATCH_REQUIRE((a + b).dim() == 2);
     auto c = rand({3,4}, type);
-    REQUIRE(c[1][2].dim() == 0);
+    CATCH_REQUIRE(c[1][2].dim() == 0);
 
     auto f = rand({3,4}, type);
     f[2] = zeros({4}, type);
     f[1][0] = -1;
-    REQUIRE(f[2][0].toCDouble() == 0);
+    CATCH_REQUIRE(f[2][0].toCDouble() == 0);
   }
 
-  SECTION( "tensor from TH" ) {
+  CATCH_SECTION( "tensor from TH" ) {
     int a = 4;
     THFloatTensor *t = THFloatTensor_newWithSize2d(a, a);
     THFloatTensor_fill(t, a);
     Tensor tt = CPU(kFloat).unsafeTensorFromTH(t,false);
-    REQUIRE_NOTHROW(tt);
+    CATCH_REQUIRE_NOTHROW(tt);
   }
 
-  SECTION( "toCFloat" ) {
+  CATCH_SECTION( "toCFloat" ) {
     Tensor a = zeros({3,4});
     Tensor b = ones({3,7});
     Tensor c = cat({a,b},1);
-    REQUIRE(c.size(1) == 11);
+    CATCH_REQUIRE(c.size(1) == 11);
 
     Tensor e = rand({});
-    REQUIRE(*e.data<float>() == e.sum().toCFloat());
+    CATCH_REQUIRE(*e.data<float>() == e.sum().toCFloat());
   }
 
-  SECTION( "to string" ) {
+  CATCH_SECTION( "to string" ) {
     Tensor b = ones({3,7})*.0000001f;
     std::stringstream s;
     s << b << "\n";
     std::string expect = "1e-07 *";
-    REQUIRE(s.str().substr(0,expect.size()) == expect);
+    CATCH_REQUIRE(s.str().substr(0,expect.size()) == expect);
   }
-  SECTION("indexing by Scalar") {
+  CATCH_SECTION("indexing by Scalar") {
     Tensor tensor = arange(0, 10, kInt);
     Tensor one = ones({}, kInt);
     for (int64_t i = 0; i < tensor.numel(); ++i) {
-      REQUIRE(tensor[i].equal(one * i));
+      CATCH_REQUIRE(tensor[i].equal(one * i));
     }
     for (size_t i = 0; i < static_cast<uint64_t>(tensor.numel()); ++i) {
-      REQUIRE(tensor[i].equal(one * static_cast<int64_t>(i)));
+      CATCH_REQUIRE(tensor[i].equal(one * static_cast<int64_t>(i)));
     }
     for (int i = 0; i < tensor.numel(); ++i) {
-      REQUIRE(tensor[i].equal(one * i));
+      CATCH_REQUIRE(tensor[i].equal(one * i));
     }
     for (int16_t i = 0; i < tensor.numel(); ++i) {
-      REQUIRE(tensor[i].equal(one * i));
+      CATCH_REQUIRE(tensor[i].equal(one * i));
     }
     for (int8_t i = 0; i < tensor.numel(); ++i) {
-      REQUIRE(tensor[i].equal(one * i));
+      CATCH_REQUIRE(tensor[i].equal(one * i));
     }
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         tensor[Scalar(3.14)].equal(one),
         StartsWith(
             "Can only index tensors with integral scalars"));
   }
-  SECTION("indexing by zero-dim tensor") {
+  CATCH_SECTION("indexing by zero-dim tensor") {
     Tensor tensor = arange(0, 10, kInt);
     Tensor one = ones({}, kInt);
     for (int i = 0; i < tensor.numel(); ++i) {
-      REQUIRE(tensor[one * i].equal(one * i));
+      CATCH_REQUIRE(tensor[one * i].equal(one * i));
     }
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         tensor[ones({}) * 3.14].equal(one),
         StartsWith(
             "Can only index tensors with integral scalars"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         tensor[Tensor()].equal(one),
         StartsWith("Can only index with tensors that are defined"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         tensor[ones({2, 3, 4}, kInt)].equal(one),
         StartsWith("Can only index with tensors that are scalars (zero-dim)"));
   }
-  SECTION("dispatch") {
+  CATCH_SECTION("dispatch") {
     Tensor tensor = randn({20, 20});
     Tensor other = randn({20, 20});
     auto result = tensor.m(relu).m(mse_loss, other, Reduction::ElementwiseMean);
-    REQUIRE(result.allclose(mse_loss(relu(tensor), other)));
+    CATCH_REQUIRE(result.allclose(mse_loss(relu(tensor), other)));
   }
-  SECTION("core") {
+  CATCH_SECTION("core") {
     int i = CoreTest();
-    REQUIRE(i + 1 == CoreTest());
+    CATCH_REQUIRE(i + 1 == CoreTest());
   }
 }
 
-TEST_CASE( "basic tests CPU", "[cpu]" ) {
+CATCH_TEST_CASE( "basic tests CPU", "[cpu]" ) {
   manual_seed(123, at::kCPU);
 
   test(CPU(kFloat));
 }
 
-TEST_CASE( "basic tests GPU", "[cuda]" ) {
+CATCH_TEST_CASE( "basic tests GPU", "[cuda]" ) {
   manual_seed(123, at::kCUDA);
 
   if(at::hasCUDA()) {
diff --git a/aten/src/ATen/test/broadcast_test.cpp b/aten/src/ATen/test/broadcast_test.cpp
index cd5c43d..822a1d7 100644
--- a/aten/src/ATen/test/broadcast_test.cpp
+++ b/aten/src/ATen/test/broadcast_test.cpp
@@ -1,154 +1,154 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "test_seed.h"
 
 using namespace at;
 
-TEST_CASE( "broadcast", "[]" ) {
+CATCH_TEST_CASE( "broadcast", "[]" ) {
 
   manual_seed(123, at::kCPU);
 
   Type & T = CPU(kFloat);
 
   // 0) pre-req tests:
-  SECTION( "can't expand empty tensor" ) {
+  CATCH_SECTION( "can't expand empty tensor" ) {
     auto empty = randn({0}, T);
-    REQUIRE_THROWS(empty.expand({3}));
+    _CATCH_REQUIRE_THROWS(empty.expand({3}));
   }
 
   // 1) out-place function with 2 args
-  SECTION( "out-place function with 2 args" ) {
+  CATCH_SECTION( "out-place function with 2 args" ) {
 
-    SECTION( "basic" ) {
+    CATCH_SECTION( "basic" ) {
       auto a = randn({3, 1}, T);
       auto b = randn({5}, T);
       std::vector<int64_t> expanded_sizes = {3, 5};
-      REQUIRE((a + b).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes)));
+      CATCH_REQUIRE((a + b).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes)));
     }
 
-    SECTION( "with scalar" ) {
+    CATCH_SECTION( "with scalar" ) {
       auto aScalar = ones({1}, T);
       aScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
       auto b = randn({3, 5}, T);
-      REQUIRE((aScalar + b).equal(aScalar.expand(b.sizes()) + b.expand(b.sizes())));
+      CATCH_REQUIRE((aScalar + b).equal(aScalar.expand(b.sizes()) + b.expand(b.sizes())));
     }
 
-    SECTION( "old fallback behavior yields error" ) {
+    CATCH_SECTION( "old fallback behavior yields error" ) {
       auto a = randn({3, 5}, T);
       auto b = randn({5, 3}, T);
-      REQUIRE_THROWS(a + b);
+      _CATCH_REQUIRE_THROWS(a + b);
     }
 
-    SECTION( "with mismatched sizes" ) {
+    CATCH_SECTION( "with mismatched sizes" ) {
       auto a = randn({3, 5}, T);
       auto b = randn({7, 5}, T);
-      REQUIRE_THROWS(a + b);
+      _CATCH_REQUIRE_THROWS(a + b);
     }
   }
 
-  SECTION( "out-place function with 3 args" ) {
+  CATCH_SECTION( "out-place function with 3 args" ) {
 
-    SECTION( "basic" ) {
+    CATCH_SECTION( "basic" ) {
       auto a = randn({3, 1, 1}, T);
       auto b = randn({1, 2, 1}, T);
       auto c = randn({1, 1, 5}, T);
       std::vector<int64_t> expanded_sizes = {3, 2, 5};
-      REQUIRE((a + b + c).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes) + c.expand(expanded_sizes)));
+      CATCH_REQUIRE((a + b + c).equal(a.expand(expanded_sizes) + b.expand(expanded_sizes) + c.expand(expanded_sizes)));
     }
 
-    SECTION( "with scalar" ) {
+    CATCH_SECTION( "with scalar" ) {
       auto aTensorScalar = ones({1}, T);
       aTensorScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
       auto b = randn({3, 2, 1}, T);
       auto c = randn({1, 2, 5}, T);
       std::vector<int64_t> expanded_sizes = {3, 2, 5};
-      REQUIRE(aTensorScalar.addcmul(b, c).equal(
+      CATCH_REQUIRE(aTensorScalar.addcmul(b, c).equal(
                 aTensorScalar.expand(expanded_sizes).addcmul(b.expand(expanded_sizes), c.expand(expanded_sizes))));
     }
 
-    SECTION( "old fallback behavior yields error" ) {
+    CATCH_SECTION( "old fallback behavior yields error" ) {
       auto a = randn({3, 2, 5}, T);
       auto b = randn({2, 3, 5}, T);
       auto c = randn({5, 3, 2}, T);
-      REQUIRE_THROWS(a.addcmul(b, c));
+      _CATCH_REQUIRE_THROWS(a.addcmul(b, c));
     }
 
-    SECTION( "with mismatched sizes" ){
+    CATCH_SECTION( "with mismatched sizes" ){
       auto a = randn({3, 2, 5}, T);
       auto b = randn({2, 3, 5}, T);
       auto c = randn({5, 5, 5}, T);
-      REQUIRE_THROWS(a.addcmul(b, c));
+      _CATCH_REQUIRE_THROWS(a.addcmul(b, c));
     }
   }
 
-  SECTION( "in-place function with 2 args" ) {
-    SECTION( "basic" ) {
+  CATCH_SECTION( "in-place function with 2 args" ) {
+    CATCH_SECTION( "basic" ) {
       auto a = randn({3, 5}, T);
       auto b = randn({3, 1}, T);
-      REQUIRE((a + b).equal(a + b.expand({3, 5})));
+      CATCH_REQUIRE((a + b).equal(a + b.expand({3, 5})));
     }
 
-    SECTION( "with scalar" ) {
+    CATCH_SECTION( "with scalar" ) {
       auto a = randn({3, 5}, T);
       auto bScalar = ones({1}, T);
       bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
-      REQUIRE((a + bScalar).equal(a + bScalar.expand(a.sizes())));
+      CATCH_REQUIRE((a + bScalar).equal(a + bScalar.expand(a.sizes())));
     }
 
-    SECTION( "error: would have to expand inplace arg" ) {
+    CATCH_SECTION( "error: would have to expand inplace arg" ) {
       auto a = randn({1, 5}, T);
       auto b = randn({3, 1}, T);
-      REQUIRE_THROWS(a.add_(b));
+      _CATCH_REQUIRE_THROWS(a.add_(b));
     }
   }
 
-  SECTION( "in-place function with 3 args" ) {
+  CATCH_SECTION( "in-place function with 3 args" ) {
 
     auto a = randn({3, 5, 2}, T);
     auto b = randn({3, 1, 2}, T);
     auto c = randn({1, 5, 1}, T);
 
-    SECTION( "basic" ) {
+    CATCH_SECTION( "basic" ) {
       auto aClone = a.clone();
-      REQUIRE(a.addcmul_(b, c).equal(aClone.addcmul_(b.expand(a.sizes()), c.expand(a.sizes()))));
+      CATCH_REQUIRE(a.addcmul_(b, c).equal(aClone.addcmul_(b.expand(a.sizes()), c.expand(a.sizes()))));
     }
 
-    SECTION( "with scalar" ) {
+    CATCH_SECTION( "with scalar" ) {
       auto aClone = a.clone();
       auto bScalar = ones({1}, T);
       bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
-      REQUIRE(a.addcmul_(bScalar, c).equal(aClone.addcmul_(bScalar.expand(a.sizes()), c.expand(a.sizes()))));
+      CATCH_REQUIRE(a.addcmul_(bScalar, c).equal(aClone.addcmul_(bScalar.expand(a.sizes()), c.expand(a.sizes()))));
     }
 
-    SECTION( "error: would have to expand inplace arg" ) {
+    CATCH_SECTION( "error: would have to expand inplace arg" ) {
       auto a = randn({1, 3, 5}, T);
       auto b = randn({4, 1, 1}, T);
       auto c = randn({1, 3, 1}, T);
-      REQUIRE_THROWS(a.addcmul_(b, c));
+      _CATCH_REQUIRE_THROWS(a.addcmul_(b, c));
     }
   }
 
-  SECTION( "explicit dim specification" ) {
+  CATCH_SECTION( "explicit dim specification" ) {
 
     auto a = randn({1}, T);
     auto b = randn({5, 3}, T);
     auto c = randn({3, 7}, T);
 
-    SECTION( "basic" ) {
-      REQUIRE(a.addmm(b, c).equal(a.expand({5,7}).addmm(b, c)));
+    CATCH_SECTION( "basic" ) {
+      CATCH_REQUIRE(a.addmm(b, c).equal(a.expand({5,7}).addmm(b, c)));
     }
 
-    SECTION( "with scalar" ) {
+    CATCH_SECTION( "with scalar" ) {
       Tensor aScalar = ones({1}, T);
       aScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
-      REQUIRE(aScalar.addmm(b, c).equal(aScalar.expand({5, 7}).addmm(b, c)));
+      CATCH_REQUIRE(aScalar.addmm(b, c).equal(aScalar.expand({5, 7}).addmm(b, c)));
     }
 
-    SECTION( "with mismatched sizes" ) {
+    CATCH_SECTION( "with mismatched sizes" ) {
       auto a = randn({3, 3}, T);
-      REQUIRE_THROWS(a.addmm(b, c));
+      _CATCH_REQUIRE_THROWS(a.addmm(b, c));
     }
   }
 }
diff --git a/aten/src/ATen/test/catch_utils.hpp b/aten/src/ATen/test/catch_utils.hpp
new file mode 100644
index 0000000..b9b0a87
--- /dev/null
+++ b/aten/src/ATen/test/catch_utils.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+#define CATCH_CONFIG_PREFIX_ALL
+#include <catch.hpp>
+
+// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes a warning;
+// define our own version that doesn't warn.
+#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
diff --git a/aten/src/ATen/test/cuda_half_test.cu b/aten/src/ATen/test/cuda_half_test.cu
index fa00e53..cce2671 100644
--- a/aten/src/ATen/test/cuda_half_test.cu
+++ b/aten/src/ATen/test/cuda_half_test.cu
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "ATen/cuda/NumericLimits.cuh"
@@ -82,9 +82,9 @@
   kernel<<<1,1>>>();
 }
 
-TEST_CASE( "half common math functions tests in device", "[cuda]" ) {
+CATCH_TEST_CASE( "half common math functions tests in device", "[cuda]" ) {
   launch_function();
   cudaError_t err = cudaDeviceSynchronize();
-  REQUIRE(err == cudaSuccess);
+  CATCH_REQUIRE(err == cudaSuccess);
 }
 
diff --git a/aten/src/ATen/test/cuda_optional_test.cu b/aten/src/ATen/test/cuda_optional_test.cu
index 9956dcf..b64c530 100644
--- a/aten/src/ATen/test/cuda_optional_test.cu
+++ b/aten/src/ATen/test/cuda_optional_test.cu
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "ATen/optional.h"
@@ -8,15 +8,15 @@
 
 using namespace at;
 
-TEST_CASE( "optional in cuda files", "[cuda]" ) {
+CATCH_TEST_CASE( "optional in cuda files", "[cuda]" ) {
   at::optional<int64_t> trivially_destructible;
   at::optional<std::vector<int64_t>> non_trivially_destructible;
-  REQUIRE(!trivially_destructible.has_value());
-  REQUIRE(!non_trivially_destructible.has_value());
+  CATCH_REQUIRE(!trivially_destructible.has_value());
+  CATCH_REQUIRE(!non_trivially_destructible.has_value());
 
   trivially_destructible = {5};
   non_trivially_destructible = std::vector<int64_t>{5, 10};
-  REQUIRE(trivially_destructible.has_value());
-  REQUIRE(non_trivially_destructible.has_value());
+  CATCH_REQUIRE(trivially_destructible.has_value());
+  CATCH_REQUIRE(non_trivially_destructible.has_value());
 }
 
diff --git a/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu b/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu
index f1eb5cb..a529f38 100644
--- a/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu
+++ b/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "test_seed.h"
@@ -22,7 +22,7 @@
   }
 }
 
-TEST_CASE( "test PackedTensorAccessor and Tensor.packed_accessor", "[cuda]" ) {
+CATCH_TEST_CASE( "test PackedTensorAccessor and Tensor.packed_accessor", "[cuda]" ) {
   manual_seed(123, at::kCPU);
   manual_seed(123, at::kCUDA);
 
@@ -38,9 +38,9 @@
   
   test_tensor_packed_accessor_kernel<<<1, 1, 0, stream>>>(resa, t1a, t2a);
   cudaError_t err = cudaDeviceSynchronize();
-  REQUIRE(err == cudaSuccess);
+  CATCH_REQUIRE(err == cudaSuccess);
 
   auto expected = mv(t1, t2);
 
-  REQUIRE(res.allclose(expected));
+  CATCH_REQUIRE(res.allclose(expected));
 }
diff --git a/aten/src/ATen/test/cuda_rng_test.cpp b/aten/src/ATen/test/cuda_rng_test.cpp
index d32903d..7b14174 100644
--- a/aten/src/ATen/test/cuda_rng_test.cpp
+++ b/aten/src/ATen/test/cuda_rng_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "cuda.h"
@@ -21,7 +21,7 @@
   }
 };
 
-TEST_CASE( "CUDA RNG test", "[cuda]" ) {
-  SECTION( "multithread" )
+CATCH_TEST_CASE( "CUDA RNG test", "[cuda]" ) {
+  CATCH_SECTION( "multithread" )
     testCudaRNGMultithread();
 }
diff --git a/aten/src/ATen/test/cudnn_test.cpp b/aten/src/ATen/test/cudnn_test.cpp
index 31786e8..4391867 100644
--- a/aten/src/ATen/test/cudnn_test.cpp
+++ b/aten/src/ATen/test/cudnn_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "ATen/cudnn/Descriptors.h"
@@ -9,7 +9,7 @@
 using namespace at;
 using namespace at::native;
 
-TEST_CASE( "cudnn", "[cuda]" ) {
+CATCH_TEST_CASE( "cudnn", "[cuda]" ) {
   manual_seed(123, at::kCUDA);
 
 #if CUDNN_VERSION < 7000
@@ -18,8 +18,8 @@
   desc1.initialize_rng(at::CUDA(kByte), handle, 0.5, 42);
   desc2.set(handle, 0.5, desc1.state);
 
-  REQUIRE(desc1.desc()->dropout == desc2.desc()->dropout);
-  REQUIRE(desc1.desc()->nstates == desc2.desc()->nstates);
-  REQUIRE(desc1.desc()->states == desc2.desc()->states);
+  CATCH_REQUIRE(desc1.desc()->dropout == desc2.desc()->dropout);
+  CATCH_REQUIRE(desc1.desc()->nstates == desc2.desc()->nstates);
+  CATCH_REQUIRE(desc1.desc()->states == desc2.desc()->states);
 #endif
 }
diff --git a/aten/src/ATen/test/dlconvertor_test.cpp b/aten/src/ATen/test/dlconvertor_test.cpp
index 4882929..bf0cf93 100644
--- a/aten/src/ATen/test/dlconvertor_test.cpp
+++ b/aten/src/ATen/test/dlconvertor_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "ATen/DLConvertor.h"
@@ -11,17 +11,17 @@
 
 using namespace at;
 
-TEST_CASE( "dlconvertor", "[cpu]" ) {
+CATCH_TEST_CASE( "dlconvertor", "[cpu]" ) {
 
   manual_seed(123, at::kCPU);
 
-  INFO( "convert ATen to DLTensor" );
+  CATCH_INFO( "convert ATen to DLTensor" );
 
   Tensor a = rand({3,4});
   DLManagedTensor* dlMTensor = toDLPack(a);
 
-  INFO( "convert DLTensor to ATen" );
+  CATCH_INFO( "convert DLTensor to ATen" );
   Tensor b = fromDLPack(dlMTensor);
 
-  REQUIRE(a.equal(b));
+  CATCH_REQUIRE(a.equal(b));
 }
diff --git a/aten/src/ATen/test/half_test.cpp b/aten/src/ATen/test/half_test.cpp
index 3b29448..3217770 100644
--- a/aten/src/ATen/test/half_test.cpp
+++ b/aten/src/ATen/test/half_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include <ATen/ATen.h>
 #include <iostream>
@@ -12,53 +12,53 @@
 
 using namespace at;
 
-TEST_CASE( "half arithmetic", "[]" ) {
+CATCH_TEST_CASE( "half arithmetic", "[]" ) {
   Half zero = 0;
   Half one = 1;
-  REQUIRE(zero + one == one);
-  REQUIRE(zero + zero == zero);
-  REQUIRE(zero * one == zero);
-  REQUIRE(one * one == one);
-  REQUIRE(one / one == one);
-  REQUIRE(one - one == zero);
-  REQUIRE(one - zero == one);
-  REQUIRE(zero - one == -one);
-  REQUIRE(one + one == Half(2));
-  REQUIRE(one + one == 2);
+  CATCH_REQUIRE(zero + one == one);
+  CATCH_REQUIRE(zero + zero == zero);
+  CATCH_REQUIRE(zero * one == zero);
+  CATCH_REQUIRE(one * one == one);
+  CATCH_REQUIRE(one / one == one);
+  CATCH_REQUIRE(one - one == zero);
+  CATCH_REQUIRE(one - zero == one);
+  CATCH_REQUIRE(zero - one == -one);
+  CATCH_REQUIRE(one + one == Half(2));
+  CATCH_REQUIRE(one + one == 2);
 }
 
-TEST_CASE( "half comparisons", "[]" ) {
+CATCH_TEST_CASE( "half comparisons", "[]" ) {
   Half zero = 0;
   Half one = 1;
-  REQUIRE(zero < one);
-  REQUIRE(zero < 1);
-  REQUIRE(1 > zero);
-  REQUIRE(0 >= zero);
-  REQUIRE(0 != one);
-  REQUIRE(zero == 0);
-  REQUIRE(zero == zero);
-  REQUIRE(zero == -zero);
+  CATCH_REQUIRE(zero < one);
+  CATCH_REQUIRE(zero < 1);
+  CATCH_REQUIRE(1 > zero);
+  CATCH_REQUIRE(0 >= zero);
+  CATCH_REQUIRE(0 != one);
+  CATCH_REQUIRE(zero == 0);
+  CATCH_REQUIRE(zero == zero);
+  CATCH_REQUIRE(zero == -zero);
 }
 
-TEST_CASE( "half cast", "[]" ) {
+CATCH_TEST_CASE( "half cast", "[]" ) {
   Half value = 1.5f;
-  REQUIRE((int)value == 1);
-  REQUIRE((short)value == 1);
-  REQUIRE((long long)value == 1LL);
-  REQUIRE((float)value == 1.5f);
-  REQUIRE((double)value == 1.5);
-  REQUIRE((bool)value == true);
-  REQUIRE((bool)Half(0.0f) == false);
+  CATCH_REQUIRE((int)value == 1);
+  CATCH_REQUIRE((short)value == 1);
+  CATCH_REQUIRE((long long)value == 1LL);
+  CATCH_REQUIRE((float)value == 1.5f);
+  CATCH_REQUIRE((double)value == 1.5);
+  CATCH_REQUIRE((bool)value == true);
+  CATCH_REQUIRE((bool)Half(0.0f) == false);
 }
 
-TEST_CASE( "half construction", "[]" ) {
-  REQUIRE(Half((short)3) == Half(3.0f));
-  REQUIRE(Half((unsigned short)3) == Half(3.0f));
-  REQUIRE(Half(3) == Half(3.0f));
-  REQUIRE(Half(3U) == Half(3.0f));
-  REQUIRE(Half(3LL) == Half(3.0f));
-  REQUIRE(Half(3ULL) == Half(3.0f));
-  REQUIRE(Half(3.5) == Half(3.5f));
+CATCH_TEST_CASE( "half construction", "[]" ) {
+  CATCH_REQUIRE(Half((short)3) == Half(3.0f));
+  CATCH_REQUIRE(Half((unsigned short)3) == Half(3.0f));
+  CATCH_REQUIRE(Half(3) == Half(3.0f));
+  CATCH_REQUIRE(Half(3U) == Half(3.0f));
+  CATCH_REQUIRE(Half(3LL) == Half(3.0f));
+  CATCH_REQUIRE(Half(3ULL) == Half(3.0f));
+  CATCH_REQUIRE(Half(3.5) == Half(3.5f));
 }
 
 static std::string to_string(const Half& h) {
@@ -67,22 +67,22 @@
   return ss.str();
 }
 
-TEST_CASE( "half to string", "[]" ) {
-  REQUIRE(to_string(Half(3.5f)) == "3.5");
-  REQUIRE(to_string(Half(-100.0f)) == "-100");
+CATCH_TEST_CASE( "half to string", "[]" ) {
+  CATCH_REQUIRE(to_string(Half(3.5f)) == "3.5");
+  CATCH_REQUIRE(to_string(Half(-100.0f)) == "-100");
 }
 
-TEST_CASE( "half numeric limits", "[]" ) {
+CATCH_TEST_CASE( "half numeric limits", "[]" ) {
   using limits = std::numeric_limits<Half>;
-  REQUIRE(limits::lowest() == -65504.0f);
-  REQUIRE(limits::max() == 65504.0f);
-  REQUIRE(limits::min() > 0);
-  REQUIRE(limits::min() < 1);
-  REQUIRE(limits::denorm_min() > 0);
-  REQUIRE(limits::denorm_min() / 2  == 0);
-  REQUIRE(limits::infinity() == std::numeric_limits<float>::infinity());
-  REQUIRE(limits::quiet_NaN() != limits::quiet_NaN());
-  REQUIRE(limits::signaling_NaN() != limits::signaling_NaN());
+  CATCH_REQUIRE(limits::lowest() == -65504.0f);
+  CATCH_REQUIRE(limits::max() == 65504.0f);
+  CATCH_REQUIRE(limits::min() > 0);
+  CATCH_REQUIRE(limits::min() < 1);
+  CATCH_REQUIRE(limits::denorm_min() > 0);
+  CATCH_REQUIRE(limits::denorm_min() / 2  == 0);
+  CATCH_REQUIRE(limits::infinity() == std::numeric_limits<float>::infinity());
+  CATCH_REQUIRE(limits::quiet_NaN() != limits::quiet_NaN());
+  CATCH_REQUIRE(limits::signaling_NaN() != limits::signaling_NaN());
 }
 
 // Check the declared type of members of numeric_limits<Half> matches
@@ -119,7 +119,7 @@
 ASSERT_SAME_TYPE(traps);
 ASSERT_SAME_TYPE(tinyness_before);
 
-TEST_CASE( "half common math functions test", "[]" ) {
+CATCH_TEST_CASE( "half common math functions test", "[]" ) {
   float threshold = 0.00001;
   assert(std::abs(std::lgamma(Half(10.0)) - std::lgamma(10.0f)) <= threshold);
   assert(std::abs(std::exp(Half(1.0)) - std::exp(1.0f)) <= threshold);
diff --git a/aten/src/ATen/test/integer_divider_test.cu b/aten/src/ATen/test/integer_divider_test.cu
index 4c63ab3..d09a423 100644
--- a/aten/src/ATen/test/integer_divider_test.cu
+++ b/aten/src/ATen/test/integer_divider_test.cu
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 // Test IntegerDivider: this tests *all* 32-bit pairs (a, b) where a % b is 0 or
 // (b-1), so it takes a few minutes to run.
@@ -62,18 +62,18 @@
     cudaError_t err;
 
     err = cudaMalloc(&dividersBuf_, NUM_CASES * sizeof(IntDivider<Value>));
-    REQUIRE(err == cudaSuccess);
+    CATCH_REQUIRE(err == cudaSuccess);
     err = cudaMalloc(&testCasesBuf_, NUM_CASES * sizeof(TestCase<Value>));
-    REQUIRE(err == cudaSuccess);
+    CATCH_REQUIRE(err == cudaSuccess);
   }
 
   ~IntDividerTester() {
     cudaError_t err;
 
     err = cudaFree(dividersBuf_);
-    REQUIRE(err == cudaSuccess);
+    CATCH_REQUIRE(err == cudaSuccess);
     err = cudaFree(testCasesBuf_);
-    REQUIRE(err == cudaSuccess);
+    CATCH_REQUIRE(err == cudaSuccess);
   }
 
   void addTestCase(Value dividend, Value divisor, int steps) {
@@ -92,18 +92,18 @@
     cudaError_t err;
 
     if (testCases_.empty()) return;
-    REQUIRE(!dividers_.empty());
+    CATCH_REQUIRE(!dividers_.empty());
 
-    REQUIRE(dividers_.size() <= NUM_CASES);
-    REQUIRE(testCases_.size() <= NUM_CASES);
+    CATCH_REQUIRE(dividers_.size() <= NUM_CASES);
+    CATCH_REQUIRE(testCases_.size() <= NUM_CASES);
     err = cudaMemcpy(dividersBuf_, dividers_.data(),
                      dividers_.size() * sizeof(IntDivider<Value>),
                      cudaMemcpyHostToDevice);
-    REQUIRE(err == cudaSuccess);
+    CATCH_REQUIRE(err == cudaSuccess);
     err = cudaMemcpy(testCasesBuf_, testCases_.data(),
                      testCases_.size() * sizeof(TestCase<Value>),
                      cudaMemcpyHostToDevice);
-    REQUIRE(err == cudaSuccess);
+    CATCH_REQUIRE(err == cudaSuccess);
 
     int numCases = testCases_.size();
     testIntDivider<Value><<<512, 512>>>(
@@ -180,11 +180,11 @@
   tester.flush();
 }
 
-TEST_CASE( "CUDA integer divider", "[cuda]" ) {
+CATCH_TEST_CASE( "CUDA integer divider", "[cuda]" ) {
 
   testUint64Divider();
   testUint32Divider();
 
   cudaError_t err = cudaDeviceSynchronize();
-  REQUIRE(err == cudaSuccess);
+  CATCH_REQUIRE(err == cudaSuccess);
 }
diff --git a/aten/src/ATen/test/native_test.cpp b/aten/src/ATen/test/native_test.cpp
index e10de30..4c57b7d 100644
--- a/aten/src/ATen/test/native_test.cpp
+++ b/aten/src/ATen/test/native_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "test_seed.h"
@@ -9,18 +9,18 @@
 using Catch::Matchers::StartsWith;
 
 #define REQUIRE_EQUAL(t1, t2) \
-  REQUIRE(t1.equal(t2));
+  CATCH_REQUIRE(t1.equal(t2));
 
 #define REQUIRE_ALLCLOSE(t1, t2)   \
-  REQUIRE(t1.is_same_size(t2));    \
-  REQUIRE(t1.allclose(t2));
+  CATCH_REQUIRE(t1.is_same_size(t2));    \
+  CATCH_REQUIRE(t1.allclose(t2));
 
 #define REQUIRE_ALLCLOSE_TOLERANCES(t1, t2, atol, rtol)   \
-  REQUIRE(t1.is_same_size(t2));    \
-  REQUIRE(t1.allclose(t2, atol, rtol));
+  CATCH_REQUIRE(t1.is_same_size(t2));    \
+  CATCH_REQUIRE(t1.allclose(t2, atol, rtol));
 
 void requireEqualTensorList(TensorList t1, TensorList t2) {
-  REQUIRE(t1.size() == t2.size());
+  CATCH_REQUIRE(t1.size() == t2.size());
   for (size_t i = 0; i < t1.size(); ++i) {
     REQUIRE_EQUAL(t1[ i ], t2[ i ]);
   }
@@ -29,7 +29,7 @@
 void test(Type & T, Type & AccT) {
   auto t = randn({3, 3}, T);
 
-  SECTION( "split: test method, type, namespace give same result" ) {
+  CATCH_SECTION( "split: test method, type, namespace give same result" ) {
     auto splitMethod = t.split(1, 0);
     auto splitType = T.split(t, 1, 0);
     auto splitNs = at::split(t, 1, 0);
@@ -40,7 +40,7 @@
     REQUIRE_EQUAL(at::cat(splitMethod, 0), t);
   }
 
-  SECTION( "chunk: test method, type, namespace give same result" ) {
+  CATCH_SECTION( "chunk: test method, type, namespace give same result" ) {
     // test method, type, namespace give same result
     auto chunkMethod = t.chunk(3, 0);
     auto chunkType = T.chunk(t, 3, 0);
@@ -53,7 +53,7 @@
   }
 
   // stack
-  SECTION( "stack" ) {
+  CATCH_SECTION( "stack" ) {
     auto x = rand({2, 3, 4});
     auto y = rand({2, 3, 4});
     auto z = rand({2, 3, 4});
@@ -66,36 +66,36 @@
       expected_size.insert(expected_size.end(), x.sizes().begin() + dim, x.sizes().end());
 
       REQUIRE_EQUAL(res, res_neg);
-      REQUIRE(res.sizes().equals(expected_size));
+      CATCH_REQUIRE(res.sizes().equals(expected_size));
       REQUIRE_EQUAL(res.select(dim, 0), x);
       REQUIRE_EQUAL(res.select(dim, 1), y);
       REQUIRE_EQUAL(res.select(dim, 2), z);
     }
   }
 
-  SECTION( "size / stride" ) {
+  CATCH_SECTION( "size / stride" ) {
     auto scalar = randn({}, T);
-    REQUIRE_THROWS_WITH(scalar.size(0), StartsWith("dimension specified as 0 but tensor has no dimensions"));
-    REQUIRE_THROWS_WITH(scalar.size(-1), StartsWith("dimension specified as -1 but tensor has no dimensions"));
-    REQUIRE_THROWS_WITH(scalar.stride(0), StartsWith("dimension specified as 0 but tensor has no dimensions"));
-    REQUIRE_THROWS_WITH(scalar.stride(-1), StartsWith("dimension specified as -1 but tensor has no dimensions"));
+    CATCH_REQUIRE_THROWS_WITH(scalar.size(0), StartsWith("dimension specified as 0 but tensor has no dimensions"));
+    CATCH_REQUIRE_THROWS_WITH(scalar.size(-1), StartsWith("dimension specified as -1 but tensor has no dimensions"));
+    CATCH_REQUIRE_THROWS_WITH(scalar.stride(0), StartsWith("dimension specified as 0 but tensor has no dimensions"));
+    CATCH_REQUIRE_THROWS_WITH(scalar.stride(-1), StartsWith("dimension specified as -1 but tensor has no dimensions"));
 
     auto empty = randn({0}, T);
-    REQUIRE(empty.size(0) == 0);
-    REQUIRE(empty.size(-1) == 0);
-    REQUIRE(empty.stride(0) == 1);
-    REQUIRE(empty.stride(-1) == 1);
+    CATCH_REQUIRE(empty.size(0) == 0);
+    CATCH_REQUIRE(empty.size(-1) == 0);
+    CATCH_REQUIRE(empty.stride(0) == 1);
+    CATCH_REQUIRE(empty.stride(-1) == 1);
   }
 
   // matmul
-  SECTION( "matmul" ) {
+  CATCH_SECTION( "matmul" ) {
     auto scalar = randn({}, T);
     auto d1 = randn({3}, T);
     auto d2 = randn({2, 3}, T);
 
     // 0-d
-    REQUIRE_THROWS_WITH(scalar.matmul(d2), Catch::StartsWith("both arguments to matmul need to be at least 1D"));
-    REQUIRE_THROWS_WITH(d2.matmul(scalar), Catch::StartsWith("both arguments to matmul need to be at least 1D"));
+    CATCH_REQUIRE_THROWS_WITH(scalar.matmul(d2), Catch::StartsWith("both arguments to matmul need to be at least 1D"));
+    CATCH_REQUIRE_THROWS_WITH(d2.matmul(scalar), Catch::StartsWith("both arguments to matmul need to be at least 1D"));
 
     // 1-d
     REQUIRE_ALLCLOSE(d1.matmul(d1), d1.dot(d1));
@@ -140,11 +140,11 @@
 
     // non-expandable case
     auto d5wrong = randn({2, 4, 2, 4, 3, 2}, T);
-    REQUIRE_THROWS_WITH(d5.matmul(d5wrong), Catch::Contains("must match the size"));
+    CATCH_REQUIRE_THROWS_WITH(d5.matmul(d5wrong), Catch::Contains("must match the size"));
   }
 
   // _standard_gamma_grad
-  SECTION( "_standard_gamma_grad" ) {
+  CATCH_SECTION( "_standard_gamma_grad" ) {
     // check empty
     auto empty = ones({0}, T);
     REQUIRE_EQUAL(empty, at::_standard_gamma_grad(empty, empty));
@@ -158,10 +158,10 @@
     // check mixing types
     auto t1 = randn({3, 4}, T);
     auto t2 = randn({3, 4}, T).toType(kDouble);
-    REQUIRE_THROWS_WITH(at::_standard_gamma_grad(t1, t2), Catch::StartsWith("expected scalar type"));
+    CATCH_REQUIRE_THROWS_WITH(at::_standard_gamma_grad(t1, t2), Catch::StartsWith("expected scalar type"));
   }
 
-  SECTION( "where" ) {
+  CATCH_SECTION( "where" ) {
     // empty
     auto empty = ones({0}, T);
     auto &bT = T.toScalarType(ScalarType::Byte);
@@ -180,13 +180,13 @@
   }
 }
 
-TEST_CASE( "native test CPU", "[cpu]" ) {
+CATCH_TEST_CASE( "native test CPU", "[cpu]" ) {
   manual_seed(123, at::kCPU);
 
   test(CPU(kFloat), CPU(kDouble));
 }
 
-TEST_CASE( "native test CUDA", "[cuda]" ) {
+CATCH_TEST_CASE( "native test CUDA", "[cuda]" ) {
   manual_seed(123, at::kCUDA);
 
   if (at::hasCUDA()) {
diff --git a/aten/src/ATen/test/scalar_tensor_test.cpp b/aten/src/ATen/test/scalar_tensor_test.cpp
index d52dc27..964f626 100644
--- a/aten/src/ATen/test/scalar_tensor_test.cpp
+++ b/aten/src/ATen/test/scalar_tensor_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "test_seed.h"
@@ -18,14 +18,14 @@
       _passed = true;                                           \
       els;                                                      \
     } catch (std::exception &e) {                               \
-      REQUIRE(!_passed);                                        \
+      CATCH_REQUIRE(!_passed);                                        \
       catc;                                                     \
     }                                                           \
   }
 
 void require_equal_size_dim(const Tensor &lhs, const Tensor &rhs) {
-  REQUIRE(lhs.dim() == rhs.dim());
-  REQUIRE(lhs.sizes().equals(rhs.sizes()));
+  CATCH_REQUIRE(lhs.dim() == rhs.dim());
+  CATCH_REQUIRE(lhs.sizes().equals(rhs.sizes()));
 }
 
 bool should_expand(const IntList &from_size, const IntList &to_size) {
@@ -49,15 +49,15 @@
   for (auto s = sizes.begin(); s != sizes.end(); ++s) {
     // verify that the dim, sizes, strides, etc match what was requested.
     auto t = ones(*s, T);
-    REQUIRE((size_t)t.dim() == s->size());
-    REQUIRE((size_t)t.ndimension() == s->size());
-    REQUIRE(t.sizes().equals(*s));
-    REQUIRE(t.strides().size() == s->size());
+    CATCH_REQUIRE((size_t)t.dim() == s->size());
+    CATCH_REQUIRE((size_t)t.ndimension() == s->size());
+    CATCH_REQUIRE(t.sizes().equals(*s));
+    CATCH_REQUIRE(t.strides().size() == s->size());
     auto numel = std::accumulate(s->begin(), s->end(), 1, std::multiplies<int64_t>());
-    REQUIRE(t.numel() == numel);
+    CATCH_REQUIRE(t.numel() == numel);
     // verify we can output
     std::stringstream ss;
-    REQUIRE_NOTHROW(ss << t << std::endl);
+    CATCH_REQUIRE_NOTHROW(ss << t << std::endl);
 
     // set_
     auto t2 = ones(*s, T);
@@ -65,22 +65,22 @@
     require_equal_size_dim(t2, ones({0}, T));
 
     // unsqueeze
-    REQUIRE(t.unsqueeze(0).dim() == t.dim() + 1);
+    CATCH_REQUIRE(t.unsqueeze(0).dim() == t.dim() + 1);
 
     // unsqueeze_
     {
       auto t2 = ones(*s, T);
       auto r = t2.unsqueeze_(0);
-      REQUIRE(r.dim() == t.dim() + 1);
+      CATCH_REQUIRE(r.dim() == t.dim() + 1);
     }
 
     // squeeze (with dimension argument)
     if (t.dim() == 0 || t.sizes()[0] == 1) {
-      REQUIRE(t.squeeze(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
+      CATCH_REQUIRE(t.squeeze(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
     } else {
       // In PyTorch, it is a no-op to try to squeeze a dimension that has size != 1;
       // in NumPy this is an error.
-      REQUIRE(t.squeeze(0).dim() == t.dim());
+      CATCH_REQUIRE(t.squeeze(0).dim() == t.dim());
     }
 
     // squeeze (with no dimension argument)
@@ -99,11 +99,11 @@
       // squeeze_ (with dimension argument)
       auto t2 = ones(*s, T);
       if (t2.dim() == 0 ||  t2.sizes()[0] == 1) {
-        REQUIRE(t2.squeeze_(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
+        CATCH_REQUIRE(t2.squeeze_(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
       } else {
         // In PyTorch, it is a no-op to try to squeeze a dimension that has size != 1;
         // in NumPy this is an error.
-        REQUIRE(t2.squeeze_(0).dim() == t.dim());
+        CATCH_REQUIRE(t2.squeeze_(0).dim() == t.dim());
       }
     }
 
@@ -122,31 +122,31 @@
 
     // reduce (with dimension argument and with 1 return argument)
     if (t.numel() != 0) {
-      REQUIRE(t.sum(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
+      CATCH_REQUIRE(t.sum(0).dim() == std::max<int64_t>(t.dim() - 1, 0));
     } else {
-      REQUIRE(t.sum(0).equal(at::zeros({}, T)));
+      CATCH_REQUIRE(t.sum(0).equal(at::zeros({}, T)));
     }
 
     // reduce (with dimension argument and with 2 return arguments)
     if (t.numel() != 0) {
       auto ret = t.min(0);
-      REQUIRE(std::get<0>(ret).dim() == std::max<int64_t>(t.dim() - 1, 0));
-      REQUIRE(std::get<1>(ret).dim() == std::max<int64_t>(t.dim() - 1, 0));
+      CATCH_REQUIRE(std::get<0>(ret).dim() == std::max<int64_t>(t.dim() - 1, 0));
+      CATCH_REQUIRE(std::get<1>(ret).dim() == std::max<int64_t>(t.dim() - 1, 0));
     } else {
-      REQUIRE_THROWS(t.min(0));
+      _CATCH_REQUIRE_THROWS(t.min(0));
     }
 
     // simple indexing
     if (t.dim() > 0 && t.numel() != 0) {
-      REQUIRE(t[0].dim() == std::max<int64_t>(t.dim() - 1, 0));
+      CATCH_REQUIRE(t[0].dim() == std::max<int64_t>(t.dim() - 1, 0));
     } else {
-      REQUIRE_THROWS(t[0]);
+      _CATCH_REQUIRE_THROWS(t[0]);
     }
 
     // fill_ (argument to fill_ can only be a 0-dim tensor)
     TRY_CATCH_ELSE(t.fill_(t.sum(0)),
-                   REQUIRE(t.dim() > 1),
-                   REQUIRE(t.dim() <= 1));
+                   CATCH_REQUIRE(t.dim() > 1),
+                   CATCH_REQUIRE(t.dim() <= 1));
   }
 
   for (auto lhs_it = sizes.begin(); lhs_it != sizes.end(); ++lhs_it) {
@@ -156,8 +156,8 @@
           auto lhs = ones(*lhs_it, T);
           auto rhs = ones(*rhs_it, T);
           if(*lhs_it != *rhs_it) {
-            REQUIRE(!lhs.is_same_size(rhs));
-            REQUIRE(!rhs.is_same_size(lhs));
+            CATCH_REQUIRE(!lhs.is_same_size(rhs));
+            CATCH_REQUIRE(!rhs.is_same_size(lhs));
           }
       }
       // forced size functions (resize_, resize_as, set_)
@@ -192,7 +192,7 @@
             auto storage = T.storage(rhs.numel(), false);
             lhs.set_(storage);
             // should not be dim 0 because an empty storage is dim 1; all other storages aren't scalars
-            REQUIRE(lhs.dim() != 0);
+            CATCH_REQUIRE(lhs.dim() != 0);
           }
           {
             // with storage, offset, sizes, strides
@@ -211,8 +211,8 @@
         auto rhs = ones(*rhs_it, T);
         auto rhs_size = *rhs_it;
         TRY_CATCH_ELSE(auto result = lhs.view(rhs_size),
-                       REQUIRE(lhs.numel() != rhs.numel()),
-                       REQUIRE(lhs.numel() == rhs.numel()); require_equal_size_dim(result, rhs););
+                       CATCH_REQUIRE(lhs.numel() != rhs.numel()),
+                       CATCH_REQUIRE(lhs.numel() == rhs.numel()); require_equal_size_dim(result, rhs););
       }
 
       // take
@@ -220,7 +220,7 @@
         auto lhs = ones(*lhs_it, T);
         auto rhs = zeros(*rhs_it, T).toType(ScalarType::Long);
         TRY_CATCH_ELSE(auto result = lhs.take(rhs),
-                       REQUIRE(lhs.numel() == 0); REQUIRE(rhs.numel() != 0),
+                       CATCH_REQUIRE(lhs.numel() == 0); CATCH_REQUIRE(rhs.numel() != 0),
                        require_equal_size_dim(result, rhs));
       }
 
@@ -230,7 +230,7 @@
         auto lhs = ones(*lhs_it, T);
         auto rhs = ones(*rhs_it, T);
         TRY_CATCH_ELSE(auto result = lhs.ger(rhs),
-                       REQUIRE((lhs.numel() == 0 || rhs.numel() == 0 || lhs.dim() != 1 || rhs.dim() != 1)),
+                       CATCH_REQUIRE((lhs.numel() == 0 || rhs.numel() == 0 || lhs.dim() != 1 || rhs.dim() != 1)),
                        [&]() {
                          int64_t dim0 = lhs.dim() == 0 ? 1 : lhs.size(0);
                          int64_t dim1 = rhs.dim() == 0 ? 1 : rhs.size(0);
@@ -246,8 +246,8 @@
         auto rhs_size = *rhs_it;
         bool should_pass = should_expand(lhs_size, rhs_size);
         TRY_CATCH_ELSE(auto result = lhs.expand(rhs_size),
-                       REQUIRE(!should_pass),
-                       REQUIRE(should_pass); require_equal_size_dim(result, rhs););
+                       CATCH_REQUIRE(!should_pass),
+                       CATCH_REQUIRE(should_pass); require_equal_size_dim(result, rhs););
 
         // in-place functions (would be good if we can also do a non-broadcasting one, b/c
         // broadcasting functions will always end up operating on tensors of same size;
@@ -255,21 +255,21 @@
         {
           bool should_pass_inplace = should_expand(rhs_size, lhs_size);
           TRY_CATCH_ELSE(lhs.add_(rhs),
-                         REQUIRE(!should_pass_inplace),
-                         REQUIRE(should_pass_inplace); require_equal_size_dim(lhs, ones(*lhs_it, T)););
+                         CATCH_REQUIRE(!should_pass_inplace),
+                         CATCH_REQUIRE(should_pass_inplace); require_equal_size_dim(lhs, ones(*lhs_it, T)););
         }
       }
     }
   }
 }
 
-TEST_CASE( "scalar tensor test CPU", "[cpu]" ) {
+CATCH_TEST_CASE( "scalar tensor test CPU", "[cpu]" ) {
   manual_seed(123, at::kCPU);
 
   test(CPU(kFloat));
 }
 
-TEST_CASE( "scalar tensor test CUDA", "[cuda]" ) {
+CATCH_TEST_CASE( "scalar tensor test CUDA", "[cuda]" ) {
   manual_seed(123, at::kCUDA);
 
   if (at::hasCUDA()) {
diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp
index 72ef4e4..247830c 100644
--- a/aten/src/ATen/test/scalar_test.cpp
+++ b/aten/src/ATen/test/scalar_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include <iostream>
 // define constants like M_PI and C keywords for MSVC
@@ -33,25 +33,25 @@
 
 void test_overflow() {
   auto s1 = Scalar(M_PI);
-  REQUIRE(s1.toFloat() == static_cast<float>(M_PI));
+  CATCH_REQUIRE(s1.toFloat() == static_cast<float>(M_PI));
   s1.toHalf();
 
   s1 = Scalar(100000);
-  REQUIRE(s1.toFloat() == 100000.0);
-  REQUIRE(s1.toInt() == 100000);
+  CATCH_REQUIRE(s1.toFloat() == 100000.0);
+  CATCH_REQUIRE(s1.toInt() == 100000);
 
-  REQUIRE_THROWS_AS(s1.toHalf(), std::domain_error);
+  CATCH_REQUIRE_THROWS_AS(s1.toHalf(), std::domain_error);
 
   s1 = Scalar(NAN);
-  REQUIRE(std::isnan(s1.toFloat()));
-  REQUIRE_THROWS_AS(s1.toInt(), std::domain_error);
+  CATCH_REQUIRE(std::isnan(s1.toFloat()));
+  CATCH_REQUIRE_THROWS_AS(s1.toInt(), std::domain_error);
 
   s1 = Scalar(INFINITY);
-  REQUIRE(std::isinf(s1.toFloat()));
-  REQUIRE_THROWS_AS(s1.toInt(), std::domain_error);
+  CATCH_REQUIRE(std::isinf(s1.toFloat()));
+  CATCH_REQUIRE_THROWS_AS(s1.toInt(), std::domain_error);
 }
 
-TEST_CASE( "scalar test", "[]" ) {
+CATCH_TEST_CASE( "scalar test", "[]" ) {
 
   manual_seed(123, at::kCPU);
   manual_seed(123, at::kCUDA);
@@ -62,7 +62,7 @@
   Scalar h2 = h;
   cout << "H2: " << h2.toDouble() << " " << what.toFloat() << " " << bar.toDouble() << " " << what.isIntegral() <<  "\n";
   Generator & gen = at::globalContext().defaultGenerator(at::kCPU);
-  REQUIRE_NOTHROW(gen.seed());
+  CATCH_REQUIRE_NOTHROW(gen.seed());
   auto && C = at::globalContext();
   if(at::hasCUDA()) {
     auto t2 = zeros({4,4}, at::kCUDA);
@@ -71,12 +71,12 @@
   auto t = ones({4,4});
 
   auto wha2 = zeros({4,4}).add(t).sum();
-  REQUIRE( wha2.toCDouble() == 16.0 );
+  CATCH_REQUIRE( wha2.toCDouble() == 16.0 );
 
-  REQUIRE( t.sizes()[0] == 4 );
-  REQUIRE( t.sizes()[1] == 4 );
-  REQUIRE( t.strides()[0] == 4 );
-  REQUIRE( t.strides()[1] == 1 );
+  CATCH_REQUIRE( t.sizes()[0] == 4 );
+  CATCH_REQUIRE( t.sizes()[1] == 4 );
+  CATCH_REQUIRE( t.strides()[0] == 4 );
+  CATCH_REQUIRE( t.strides()[1] == 1 );
 
   Type & T = CPU(Float);
   Tensor x = randn({1,10}, T);
@@ -88,26 +88,26 @@
   Tensor next_h = i2h.add(h2h);
   next_h = next_h.tanh();
 
-  REQUIRE_THROWS(at::_local_scalar(Tensor{}));
+  _CATCH_REQUIRE_THROWS(at::_local_scalar(Tensor{}));
 
   test_overflow();
 
   if(at::hasCUDA()) {
     auto r = CUDA(Float).copy(next_h);
-    REQUIRE(CPU(Float).copy(r).equal(next_h));
+    CATCH_REQUIRE(CPU(Float).copy(r).equal(next_h));
   }
-  REQUIRE_NOTHROW(randn({10,10,2}, T));
+  CATCH_REQUIRE_NOTHROW(randn({10,10,2}, T));
 
   // check Scalar.toTensor on Scalars backed by different data types
-  REQUIRE(scalar_to_tensor(bar).type().scalarType() == kDouble);
-  REQUIRE(scalar_to_tensor(what).type().scalarType() == kLong);
-  REQUIRE(scalar_to_tensor(ones({})._local_scalar()).type().scalarType() == kDouble);
+  CATCH_REQUIRE(scalar_to_tensor(bar).type().scalarType() == kDouble);
+  CATCH_REQUIRE(scalar_to_tensor(what).type().scalarType() == kLong);
+  CATCH_REQUIRE(scalar_to_tensor(ones({})._local_scalar()).type().scalarType() == kDouble);
 
   if (x.type().scalarType() != ScalarType::Half) {
     AT_DISPATCH_ALL_TYPES(x.type(), "foo", [&] {
       scalar_t s = 1;
       std::stringstream ss;
-      REQUIRE_NOTHROW(ss << "hello, dispatch" << x.type().toString() << s << "\n");
+      CATCH_REQUIRE_NOTHROW(ss << "hello, dispatch" << x.type().toString() << s << "\n");
       auto data = (scalar_t*)x.data_ptr();
       (void)data;
     });
@@ -116,10 +116,10 @@
   // test direct C-scalar type conversions
   {
     auto x = ones({1,2}, T);
-    REQUIRE_THROWS(x.toCFloat());
+    _CATCH_REQUIRE_THROWS(x.toCFloat());
   }
   auto float_one = ones({}, T);
-  REQUIRE(float_one.toCFloat() == 1);
-  REQUIRE(float_one.toCInt() == 1);
-  REQUIRE((float_one.toCHalf() == 1));
+  CATCH_REQUIRE(float_one.toCFloat() == 1);
+  CATCH_REQUIRE(float_one.toCInt() == 1);
+  CATCH_REQUIRE((float_one.toCHalf() == 1));
 }
diff --git a/aten/src/ATen/test/stream_test.cpp b/aten/src/ATen/test/stream_test.cpp
index 145c4f4..8dc015d 100644
--- a/aten/src/ATen/test/stream_test.cpp
+++ b/aten/src/ATen/test/stream_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/cuda/CUDAContext.h"
 #include "ATen/cuda/CUDAGuard.h"
@@ -14,7 +14,7 @@
 /*
 Tests related to ATen streams.
 */
-TEST_CASE(
+CATCH_TEST_CASE(
     "Copying and Moving Streams",
     "Verifies streams are live through copying and moving") {
   int32_t device = -1;
@@ -29,14 +29,14 @@
 
     copyStream = s;
 
-    REQUIRE(copyStream.internals() == s.internals());
-    REQUIRE(copyStream.device() == device);
-    REQUIRE(copyStream.stream() == cuda_stream);
+    CATCH_REQUIRE(copyStream.internals() == s.internals());
+    CATCH_REQUIRE(copyStream.device() == device);
+    CATCH_REQUIRE(copyStream.stream() == cuda_stream);
   }
 
-  REQUIRE(copyStream.internals());
-  REQUIRE(copyStream.device() == device);
-  REQUIRE(copyStream.stream() == cuda_stream);
+  CATCH_REQUIRE(copyStream.internals());
+  CATCH_REQUIRE(copyStream.device() == device);
+  CATCH_REQUIRE(copyStream.stream() == cuda_stream);
 
   // Tests that moving works as expected and preserves the stream
   at::cuda::CUDAStream moveStream;
@@ -47,41 +47,41 @@
 
     moveStream = std::move(s);
 
-    REQUIRE(moveStream.device() == device);
-    REQUIRE(moveStream.stream() == cuda_stream);
+    CATCH_REQUIRE(moveStream.device() == device);
+    CATCH_REQUIRE(moveStream.stream() == cuda_stream);
   }
 
-  REQUIRE(moveStream.internals());
-  REQUIRE(moveStream.device() == device);
-  REQUIRE(moveStream.stream() == cuda_stream);
+  CATCH_REQUIRE(moveStream.internals());
+  CATCH_REQUIRE(moveStream.device() == device);
+  CATCH_REQUIRE(moveStream.stream() == cuda_stream);
 }
 
-TEST_CASE("Getting and Setting Streams", "Verifies streams are set properly") {
+CATCH_TEST_CASE("Getting and Setting Streams", "Verifies streams are set properly") {
   at::cuda::CUDAStream myStream = at::cuda::createCUDAStream();
 
   // Sets and gets
   at::cuda::setCurrentCUDAStream(myStream);
   at::cuda::CUDAStream curStream = at::cuda::getCurrentCUDAStream();
 
-  REQUIRE(myStream == curStream);
+  CATCH_REQUIRE(myStream == curStream);
 
   // Gets, sets, and gets default stream
   at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
   at::cuda::setCurrentCUDAStream(defaultStream);
   curStream = at::cuda::getCurrentCUDAStream();
 
-  REQUIRE(defaultStream != myStream);
-  REQUIRE(curStream == defaultStream);
+  CATCH_REQUIRE(defaultStream != myStream);
+  CATCH_REQUIRE(curStream == defaultStream);
 }
 
 void thread_fun(at::cuda::CUDAStream& cur_thread_stream) {
   auto new_stream = at::cuda::createCUDAStream();
   at::cuda::setCurrentCUDAStream(new_stream);
   cur_thread_stream = at::cuda::getCurrentCUDAStream();
-  REQUIRE(cur_thread_stream == new_stream);
+  CATCH_REQUIRE(cur_thread_stream == new_stream);
 }
 
-TEST_CASE(
+CATCH_TEST_CASE(
     "Multithread Getting and Setting",
     "Ensures streams are thread local") {
   at::cuda::CUDAStream s0, s1;
@@ -94,25 +94,25 @@
   at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();
   at::cuda::CUDAStream default_stream = at::cuda::getDefaultCUDAStream();
 
-  REQUIRE(cur_stream == default_stream);
-  REQUIRE(cur_stream != s0);
-  REQUIRE(cur_stream != s1);
-  REQUIRE(s0 != s1);
+  CATCH_REQUIRE(cur_stream == default_stream);
+  CATCH_REQUIRE(cur_stream != s0);
+  CATCH_REQUIRE(cur_stream != s1);
+  CATCH_REQUIRE(s0 != s1);
 }
 
-TEST_CASE("CUDAGuard") {
+CATCH_TEST_CASE("CUDAGuard") {
   if (at::cuda::getNumGPUs() < 2) {
     return;
   }
 
   // -- begin setup
 
-  REQUIRE(at::cuda::current_device() == 0);
+  CATCH_REQUIRE(at::cuda::current_device() == 0);
   std::vector<at::cuda::CUDAStream> streams0 = {
       at::cuda::getDefaultCUDAStream(),
       at::cuda::createCUDAStream()};
-  REQUIRE(streams0[0].device() == 0);
-  REQUIRE(streams0[1].device() == 0);
+  CATCH_REQUIRE(streams0[0].device() == 0);
+  CATCH_REQUIRE(streams0[1].device() == 0);
   at::cuda::setCurrentCUDAStream(streams0[0]);
 
   std::vector<at::cuda::CUDAStream> streams1;
@@ -121,47 +121,47 @@
     streams1.push_back(at::cuda::getDefaultCUDAStream());
     streams1.push_back(at::cuda::createCUDAStream());
   }
-  REQUIRE(streams1[0].device() == 1);
-  REQUIRE(streams1[1].device() == 1);
+  CATCH_REQUIRE(streams1[0].device() == 1);
+  CATCH_REQUIRE(streams1[1].device() == 1);
   at::cuda::setCurrentCUDAStream(streams1[0]);
 
-  REQUIRE(at::cuda::current_device() == 0);
+  CATCH_REQUIRE(at::cuda::current_device() == 0);
 
   // -- end setup
 
   // Test that all original streams are recorded.
   {
     at::cuda::CUDAGuard guard;
-    REQUIRE(guard.original_streams().empty());
+    CATCH_REQUIRE(guard.original_streams().empty());
     guard.set_stream(streams0[0]);
-    REQUIRE(
+    CATCH_REQUIRE(
         guard.original_streams().size() == at::cuda::getNumGPUs());
-    REQUIRE(guard.original_streams()[0] == streams0[0]);
-    REQUIRE(guard.original_streams()[1] == streams1[0]);
+    CATCH_REQUIRE(guard.original_streams()[0] == streams0[0]);
+    CATCH_REQUIRE(guard.original_streams()[1] == streams1[0]);
   }
 
   // Setting a stream changes the current device and the stream on that device
   {
     at::cuda::CUDAGuard guard(streams1[1]);
-    REQUIRE(guard.last_device() == 1);
-    REQUIRE(at::cuda::current_device() == 1);
-    REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[1]);
+    CATCH_REQUIRE(guard.last_device() == 1);
+    CATCH_REQUIRE(at::cuda::current_device() == 1);
+    CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[1]);
   }
 
   // Device and stream are now reset
-  REQUIRE(at::cuda::current_device() == 0);
-  REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
+  CATCH_REQUIRE(at::cuda::current_device() == 0);
+  CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
 
   // Setting only the device changes only the current device and not the stream
   {
     at::cuda::CUDAGuard guard(/*device=*/1);
-    REQUIRE(guard.last_device() == 1);
-    REQUIRE(at::cuda::current_device() == 1);
-    REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
+    CATCH_REQUIRE(guard.last_device() == 1);
+    CATCH_REQUIRE(at::cuda::current_device() == 1);
+    CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
   }
 
-  REQUIRE(at::cuda::current_device() == 0);
-  REQUIRE(at::cuda::getCurrentCUDAStream(0) == streams0[0]);
+  CATCH_REQUIRE(at::cuda::current_device() == 0);
+  CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(0) == streams0[0]);
 
   // Setting the stream first, and then the device, first changes the devices
   // back, and then resets the stream on the initial device.
@@ -171,12 +171,12 @@
     guard.set_device(1);
   }
 
-  REQUIRE(at::cuda::current_device() == 0);
-  REQUIRE(at::cuda::getCurrentCUDAStream(0) == streams0[0]);
-  REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
+  CATCH_REQUIRE(at::cuda::current_device() == 0);
+  CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(0) == streams0[0]);
+  CATCH_REQUIRE(at::cuda::getCurrentCUDAStream(1) == streams1[0]);
 }
 
-TEST_CASE("CUDAGuardIsMovable") {
+CATCH_TEST_CASE("CUDAGuardIsMovable") {
   if (at::cuda::getNumGPUs() < 2) {
     return;
   }
@@ -185,17 +185,17 @@
   at::cuda::CUDAGuard first(stream);
   first.set_device(1);
   at::cuda::CUDAGuard second(std::move(first));
-  REQUIRE(second.original_streams().size() == device_count);
-  REQUIRE(second.original_device() == 0);
-  REQUIRE(second.last_device() == 1);
+  CATCH_REQUIRE(second.original_streams().size() == device_count);
+  CATCH_REQUIRE(second.original_device() == 0);
+  CATCH_REQUIRE(second.last_device() == 1);
   at::cuda::CUDAGuard third;
   third = std::move(second);
-  REQUIRE(third.original_streams().size() == device_count);
-  REQUIRE(third.original_device() == 0);
-  REQUIRE(third.last_device() == 1);
+  CATCH_REQUIRE(third.original_streams().size() == device_count);
+  CATCH_REQUIRE(third.original_device() == 0);
+  CATCH_REQUIRE(third.last_device() == 1);
 }
 
-TEST_CASE("Streampool Round Robin") {
+CATCH_TEST_CASE("Streampool Round Robin") {
   std::vector<at::cuda::CUDAStream> streams{};
   for (int i = 0; i < 200; ++i) {
     streams.emplace_back(at::cuda::detail::CUDAStream_createStream());
@@ -209,10 +209,10 @@
     if (!result_pair.second) hasDuplicates = true;
   }
 
-  REQUIRE(hasDuplicates);
+  CATCH_REQUIRE(hasDuplicates);
 }
 
-TEST_CASE("Multi-GPU") {
+CATCH_TEST_CASE("Multi-GPU") {
   if (at::cuda::getNumGPUs() < 2) return;
 
   at::cuda::CUDAStream s0 = at::cuda::createCUDAStream(true, 0);
@@ -221,17 +221,17 @@
   at::cuda::setCurrentCUDAStream(s0);
   at::cuda::setCurrentCUDAStream(s1);
 
-  REQUIRE(s0 == at::cuda::getCurrentCUDAStream());
+  CATCH_REQUIRE(s0 == at::cuda::getCurrentCUDAStream());
 
   at::DeviceGuard device_guard{1};
-  REQUIRE(s1 == at::cuda::getCurrentCUDAStream());
+  CATCH_REQUIRE(s1 == at::cuda::getCurrentCUDAStream());
 }
 
-TEST_CASE("CUDAEvent Syncs") {
+CATCH_TEST_CASE("CUDAEvent Syncs") {
   const auto stream = at::cuda::createCUDAStream();
   at::cuda::CUDAEvent event;
 
-  REQUIRE(!event.happened());
+  CATCH_REQUIRE(!event.happened());
 
   event.recordOnce(stream);
 
@@ -242,10 +242,10 @@
   wait_stream1.synchronize_with(event);
 
   cudaStreamSynchronize(wait_stream0);
-  REQUIRE(event.happened());
+  CATCH_REQUIRE(event.happened());
 }
 
-TEST_CASE("Cross-Device Events") {
+CATCH_TEST_CASE("Cross-Device Events") {
   if (at::cuda::getNumGPUs() < 2) return;
 
   const auto stream0 = at::cuda::createCUDAStream();
@@ -260,10 +260,10 @@
   
   event0 = std::move(event1);
   
-  REQUIRE(event0.device() == 1);
+  CATCH_REQUIRE(event0.device() == 1);
 
   stream0.synchronize_with(event0);
   
   cudaStreamSynchronize(stream0);
-  REQUIRE(event0.happened());
+  CATCH_REQUIRE(event0.happened());
 }
diff --git a/aten/src/ATen/test/test_parallel.cpp b/aten/src/ATen/test/test_parallel.cpp
index 5523280..8170173 100644
--- a/aten/src/ATen/test/test_parallel.cpp
+++ b/aten/src/ATen/test/test_parallel.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "ATen/DLConvertor.h"
@@ -11,7 +11,7 @@
 
 using namespace at;
 
-TEST_CASE( "parallel", "[cpu]" ) {
+CATCH_TEST_CASE( "parallel", "[cpu]" ) {
 
   manual_seed(123, at::kCPU);
   set_num_threads(1);
@@ -24,5 +24,5 @@
   as[0] = 1;
   as[1] = 0;
   as[2] = 0;
-  REQUIRE(a.sum(0).equal(as));
+  CATCH_REQUIRE(a.sum(0).equal(as));
 }
diff --git a/aten/src/ATen/test/undefined_tensor_test.cpp b/aten/src/ATen/test/undefined_tensor_test.cpp
index e47772a..c01dff2 100644
--- a/aten/src/ATen/test/undefined_tensor_test.cpp
+++ b/aten/src/ATen/test/undefined_tensor_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "ATen/core/UndefinedTensorImpl.h"
@@ -8,7 +8,7 @@
 
 using namespace at;
 
-TEST_CASE( "undefined tensor test", "[]" ) {
+CATCH_TEST_CASE( "undefined tensor test", "[]" ) {
   manual_seed(123, at::kCPU);
 
   // mainly test that ops on undefined tensors don't segfault and give a reasonable error message.
@@ -17,36 +17,36 @@
 
   std::stringstream ss;
   ss << und << std::endl;
-  REQUIRE(!und.defined());
-  REQUIRE(std::string("UndefinedType") == und.toString());
+  CATCH_REQUIRE(!und.defined());
+  CATCH_REQUIRE(std::string("UndefinedType") == und.toString());
 
-  REQUIRE_THROWS(und.strides());
-  REQUIRE_THROWS(und.dim());
-  REQUIRE_THROWS([]() {return Tensor();}() = Scalar(5));
-  REQUIRE_THROWS(und.add(und));
-  REQUIRE_THROWS(und.add(ft));
-  REQUIRE_THROWS(ft.add(und));
-  REQUIRE_THROWS(und.add(5));
-  REQUIRE_THROWS(und.mm(und));
+  _CATCH_REQUIRE_THROWS(und.strides());
+  _CATCH_REQUIRE_THROWS(und.dim());
+  _CATCH_REQUIRE_THROWS([]() {return Tensor();}() = Scalar(5));
+  _CATCH_REQUIRE_THROWS(und.add(und));
+  _CATCH_REQUIRE_THROWS(und.add(ft));
+  _CATCH_REQUIRE_THROWS(ft.add(und));
+  _CATCH_REQUIRE_THROWS(und.add(5));
+  _CATCH_REQUIRE_THROWS(und.mm(und));
 
   und.toType(und.type());
-  REQUIRE_THROWS(und.toType(ft.type()));
-  REQUIRE_THROWS(ft.toType(und.type()));
+  _CATCH_REQUIRE_THROWS(und.toType(ft.type()));
+  _CATCH_REQUIRE_THROWS(ft.toType(und.type()));
   und.toType(ScalarType::Undefined);
-  REQUIRE_THROWS(und.toType(ScalarType::Float));
-  REQUIRE_THROWS(ft.toType(ScalarType::Undefined));
+  _CATCH_REQUIRE_THROWS(und.toType(ScalarType::Float));
+  _CATCH_REQUIRE_THROWS(ft.toType(ScalarType::Undefined));
 
   // copy_
-  REQUIRE_THROWS(und.copy_(und));
-  REQUIRE_THROWS(und.copy_(ft));
-  REQUIRE_THROWS(ft.copy_(und));
+  _CATCH_REQUIRE_THROWS(und.copy_(und));
+  _CATCH_REQUIRE_THROWS(und.copy_(ft));
+  _CATCH_REQUIRE_THROWS(ft.copy_(und));
 
   und.toBackend(Backend::Undefined);
-  REQUIRE_THROWS(und.toBackend(Backend::CPU));
-  REQUIRE_THROWS(ft.toBackend(Backend::Undefined));
+  _CATCH_REQUIRE_THROWS(und.toBackend(Backend::CPU));
+  _CATCH_REQUIRE_THROWS(ft.toBackend(Backend::Undefined));
 
   Tensor to_move = ones({1}, CPU(kFloat));
   Tensor m(std::move(to_move));
-  REQUIRE(!to_move.defined());
-  REQUIRE(to_move.unsafeGetTensorImpl() == UndefinedTensorImpl::singleton());
+  CATCH_REQUIRE(!to_move.defined());
+  CATCH_REQUIRE(to_move.unsafeGetTensorImpl() == UndefinedTensorImpl::singleton());
 }
diff --git a/aten/src/ATen/test/weakref_test.cpp b/aten/src/ATen/test/weakref_test.cpp
index 167520b..42c9f61 100644
--- a/aten/src/ATen/test/weakref_test.cpp
+++ b/aten/src/ATen/test/weakref_test.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 
@@ -10,53 +10,53 @@
 using at::Tensor;
 using at::WeakTensor;
 
-TEST_CASE( "Weak pointer tests", "" ) {
-  SECTION("gets invalidated") {
+CATCH_TEST_CASE( "Weak pointer tests", "" ) {
+  CATCH_SECTION("gets invalidated") {
     Tensor a = at::ones({2, 2});
     WeakTensor b = a;
     a.reset();
-    REQUIRE_FALSE(b.lock().defined());
+    CATCH_REQUIRE_FALSE(b.lock().defined());
   }
 
-  SECTION("can successfully lock") {
+  CATCH_SECTION("can successfully lock") {
     Tensor a = at::ones({2, 2});
     WeakTensor b = a;
     auto c = b.lock();
-    REQUIRE(c.defined());
+    CATCH_REQUIRE(c.defined());
 
     a.reset();
-    REQUIRE(b.lock().defined());
+    CATCH_REQUIRE(b.lock().defined());
     c.reset();
-    REQUIRE_FALSE(b.lock().defined());
+    CATCH_REQUIRE_FALSE(b.lock().defined());
   }
 
-  SECTION("updates refcounts correctly") {
+  CATCH_SECTION("updates refcounts correctly") {
     Tensor a = at::ones({2, 2});
-    REQUIRE(a.use_count() == 1);
-    REQUIRE(a.weak_use_count() == 1);
+    CATCH_REQUIRE(a.use_count() == 1);
+    CATCH_REQUIRE(a.weak_use_count() == 1);
     {
       WeakTensor b = a;
-      REQUIRE(a.use_count() == 1);
-      REQUIRE(a.weak_use_count() == 2);
+      CATCH_REQUIRE(a.use_count() == 1);
+      CATCH_REQUIRE(a.weak_use_count() == 2);
     }
-    REQUIRE(a.use_count() == 1);
-    REQUIRE(a.weak_use_count() == 1);
+    CATCH_REQUIRE(a.use_count() == 1);
+    CATCH_REQUIRE(a.weak_use_count() == 1);
     {
       WeakTensor b = a;
-      REQUIRE(a.use_count() == 1);
+      CATCH_REQUIRE(a.use_count() == 1);
       auto locked = b.lock();
-      REQUIRE(locked.defined());
-      REQUIRE(a.use_count() == 2);
+      CATCH_REQUIRE(locked.defined());
+      CATCH_REQUIRE(a.use_count() == 2);
     }
-    REQUIRE(a.use_count() == 1);
-    REQUIRE(a.weak_use_count() == 1);
+    CATCH_REQUIRE(a.use_count() == 1);
+    CATCH_REQUIRE(a.weak_use_count() == 1);
     {
       WeakTensor b = a;
-      REQUIRE(a.use_count() == 1);
-      REQUIRE(a.weak_use_count() == 2);
+      CATCH_REQUIRE(a.use_count() == 1);
+      CATCH_REQUIRE(a.weak_use_count() == 2);
       a.reset();
-      REQUIRE(b.use_count() == 0);
-      REQUIRE(b.weak_use_count() == 1);
+      CATCH_REQUIRE(b.use_count() == 0);
+      CATCH_REQUIRE(b.weak_use_count() == 1);
     }
   }
 }
diff --git a/aten/src/ATen/test/wrapdim_test.cpp b/aten/src/ATen/test/wrapdim_test.cpp
index 8e813bc..f76dac2 100644
--- a/aten/src/ATen/test/wrapdim_test.cpp
+++ b/aten/src/ATen/test/wrapdim_test.cpp
@@ -1,43 +1,43 @@
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include "ATen/ATen.h"
 #include "test_seed.h"
 
 using namespace at;
 
-TEST_CASE( "wrapdim test", "[]" ) {
+CATCH_TEST_CASE( "wrapdim test", "[]" ) {
   manual_seed(123, at::kCPU);
 
   Type & T = CPU(kFloat);
 
-  SECTION( "simple case" ) {
+  CATCH_SECTION( "simple case" ) {
     auto a = randn({2, 3, 4, 5}, T);
-    REQUIRE(a.prod(-4).equal(a.prod(0)));
-    REQUIRE(a.prod(3).equal(a.prod(-1)));
+    CATCH_REQUIRE(a.prod(-4).equal(a.prod(0)));
+    CATCH_REQUIRE(a.prod(3).equal(a.prod(-1)));
   }
 
-  SECTION( "expression specification" ) {
+  CATCH_SECTION( "expression specification" ) {
     auto a = randn({2, 3, 4, 5}, T);
-    REQUIRE(a.unsqueeze(-5).equal(a.unsqueeze(0)));
-    REQUIRE(a.unsqueeze(4).equal(a.unsqueeze(-1)));
+    CATCH_REQUIRE(a.unsqueeze(-5).equal(a.unsqueeze(0)));
+    CATCH_REQUIRE(a.unsqueeze(4).equal(a.unsqueeze(-1)));
 
     // can unsqueeze scalar
     auto b = randn(1, T);
     b.unsafeGetTensorImpl()->maybe_zero_dim(true);
-    REQUIRE(b.unsqueeze(0).equal(b.unsqueeze(-1)));
+    CATCH_REQUIRE(b.unsqueeze(0).equal(b.unsqueeze(-1)));
   }
 
-  SECTION( "empty tensor" ) {
+  CATCH_SECTION( "empty tensor" ) {
     auto a = randn(0, T);
-    REQUIRE(a.prod(0).equal(at::ones({}, T)));
+    CATCH_REQUIRE(a.prod(0).equal(at::ones({}, T)));
   }
 
-  SECTION( "scalar vs 1-dim, 1-size" ) {
+  CATCH_SECTION( "scalar vs 1-dim, 1-size" ) {
     auto a = randn(1, T);
-    REQUIRE(a.prod(0).equal(a.prod(-1)));
+    CATCH_REQUIRE(a.prod(0).equal(a.prod(-1)));
     a.unsafeGetTensorImpl()->maybe_zero_dim(true);
-    REQUIRE(a.dim() == 0);
-    REQUIRE(a.prod(0).equal(a.prod(-1)));
+    CATCH_REQUIRE(a.dim() == 0);
+    CATCH_REQUIRE(a.prod(0).equal(a.prod(-1)));
   }
 }
diff --git a/test/cpp/api/any.cpp b/test/cpp/api/any.cpp
index 9368d4d..18db2f5 100644
--- a/test/cpp/api/any.cpp
+++ b/test/cpp/api/any.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/nn/modules/any.h>
 #include <torch/torch.h>
@@ -13,39 +13,39 @@
 using Catch::Contains;
 using Catch::StartsWith;
 
-TEST_CASE("any-module") {
+CATCH_TEST_CASE("any-module") {
   torch::manual_seed(0);
-  SECTION("int()") {
+  CATCH_SECTION("int()") {
     struct M : torch::nn::Module {
       int forward() {
         return 123;
       }
     };
     AnyModule any(M{});
-    REQUIRE(any.forward<int>() == 123);
+    CATCH_REQUIRE(any.forward<int>() == 123);
   }
 
-  SECTION("int(int)") {
+  CATCH_SECTION("int(int)") {
     struct M : torch::nn::Module {
       int forward(int x) {
         return x;
       }
     };
     AnyModule any(M{});
-    REQUIRE(any.forward<int>(5) == 5);
+    CATCH_REQUIRE(any.forward<int>(5) == 5);
   }
 
-  SECTION("const char*(const char*)") {
+  CATCH_SECTION("const char*(const char*)") {
     struct M : torch::nn::Module {
       const char* forward(const char* x) {
         return x;
       }
     };
     AnyModule any(M{});
-    REQUIRE(any.forward<const char*>("hello") == std::string("hello"));
+    CATCH_REQUIRE(any.forward<const char*>("hello") == std::string("hello"));
   }
 
-  SECTION("string(int, const double)") {
+  CATCH_SECTION("string(int, const double)") {
     struct M : torch::nn::Module {
       std::string forward(int x, const double f) {
         return std::to_string(static_cast<int>(x + f));
@@ -53,10 +53,10 @@
     };
     AnyModule any(M{});
     int x = 4;
-    REQUIRE(any.forward<std::string>(x, 3.14) == std::string("7"));
+    CATCH_REQUIRE(any.forward<std::string>(x, 3.14) == std::string("7"));
   }
 
-  SECTION("Tensor(string, const string&, string&&)") {
+  CATCH_SECTION("Tensor(string, const string&, string&&)") {
     struct M : torch::nn::Module {
       torch::Tensor forward(
           std::string a,
@@ -67,42 +67,42 @@
       }
     };
     AnyModule any(M{});
-    REQUIRE(
+    CATCH_REQUIRE(
         any.forward(
                std::string("a"), std::string("ab"), std::string("abc"))
             .sum()
             .toCInt() == 6);
   }
-  SECTION("wrong argument type") {
+  CATCH_SECTION("wrong argument type") {
     struct M : torch::nn::Module {
       int forward(float x) {
         return x;
       }
     };
     AnyModule any(M{});
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         any.forward(5.0),
         StartsWith("Expected argument #0 to be of type float, "
                    "but received value of type double"));
   }
-  SECTION("wrong number of arguments") {
+  CATCH_SECTION("wrong number of arguments") {
     struct M : torch::nn::Module {
       int forward(int a, int b) {
         return a + b;
       }
     };
     AnyModule any(M{});
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         any.forward(),
         Contains("M's forward() method expects 2 arguments, but received 0"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         any.forward(5),
         Contains("M's forward() method expects 2 arguments, but received 1"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         any.forward(1, 2, 3),
         Contains("M's forward() method expects 2 arguments, but received 3"));
   }
-  SECTION("get()") {
+  CATCH_SECTION("get()") {
     struct M : torch::nn::Module {
       explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
       int value;
@@ -112,16 +112,16 @@
     };
     AnyModule any(M{5});
 
-    SECTION("good cast") {
-      REQUIRE(any.get<M>().value == 5);
+    CATCH_SECTION("good cast") {
+      CATCH_REQUIRE(any.get<M>().value == 5);
     }
 
-    SECTION("bad cast") {
+    CATCH_SECTION("bad cast") {
       struct N : torch::nn::Module {};
-      REQUIRE_THROWS_WITH(any.get<N>(), StartsWith("Attempted to cast module"));
+      CATCH_REQUIRE_THROWS_WITH(any.get<N>(), StartsWith("Attempted to cast module"));
     }
   }
-  SECTION("ptr()") {
+  CATCH_SECTION("ptr()") {
     struct M : torch::nn::Module {
       explicit M(int value_) : torch::nn::Module("M"), value(value_) {}
       int value;
@@ -131,24 +131,24 @@
     };
     AnyModule any(M{5});
 
-    SECTION("base class cast") {
+    CATCH_SECTION("base class cast") {
       auto ptr = any.ptr();
-      REQUIRE(ptr != nullptr);
-      REQUIRE(ptr->name() == "M");
+      CATCH_REQUIRE(ptr != nullptr);
+      CATCH_REQUIRE(ptr->name() == "M");
     }
 
-    SECTION("good downcast") {
+    CATCH_SECTION("good downcast") {
       auto ptr = any.ptr<M>();
-      REQUIRE(ptr != nullptr);
-      REQUIRE(ptr->value == 5);
+      CATCH_REQUIRE(ptr != nullptr);
+      CATCH_REQUIRE(ptr->value == 5);
     }
 
-    SECTION("bad downcast") {
+    CATCH_SECTION("bad downcast") {
       struct N : torch::nn::Module {};
-      REQUIRE_THROWS_WITH(any.ptr<N>(), StartsWith("Attempted to cast module"));
+      CATCH_REQUIRE_THROWS_WITH(any.ptr<N>(), StartsWith("Attempted to cast module"));
     }
   }
-  SECTION("default state is empty") {
+  CATCH_SECTION("default state is empty") {
     struct M : torch::nn::Module {
       explicit M(int value_) : value(value_) {}
       int value;
@@ -157,33 +157,33 @@
       }
     };
     AnyModule any;
-    REQUIRE(any.is_empty());
+    CATCH_REQUIRE(any.is_empty());
     any = std::make_shared<M>(5);
-    REQUIRE(!any.is_empty());
-    REQUIRE(any.get<M>().value == 5);
+    CATCH_REQUIRE(!any.is_empty());
+    CATCH_REQUIRE(any.get<M>().value == 5);
   }
-  SECTION("all methods throw for empty AnyModule") {
+  CATCH_SECTION("all methods throw for empty AnyModule") {
     struct M : torch::nn::Module {
       int forward(int x) {
         return x;
       }
     };
     AnyModule any;
-    REQUIRE(any.is_empty());
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE(any.is_empty());
+    CATCH_REQUIRE_THROWS_WITH(
         any.get<M>(), StartsWith("Cannot call get() on an empty AnyModule"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         any.ptr<M>(), StartsWith("Cannot call ptr() on an empty AnyModule"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         any.ptr(), StartsWith("Cannot call ptr() on an empty AnyModule"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         any.type_info(),
         StartsWith("Cannot call type_info() on an empty AnyModule"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         any.forward<int>(5),
         StartsWith("Cannot call forward() on an empty AnyModule"));
   }
-  SECTION("can move assign different modules") {
+  CATCH_SECTION("can move assign different modules") {
     struct M : torch::nn::Module {
       std::string forward(int x) {
         return std::to_string(x);
@@ -195,15 +195,15 @@
       }
     };
     AnyModule any;
-    REQUIRE(any.is_empty());
+    CATCH_REQUIRE(any.is_empty());
     any = std::make_shared<M>();
-    REQUIRE(!any.is_empty());
-    REQUIRE(any.forward<std::string>(5) == "5");
+    CATCH_REQUIRE(!any.is_empty());
+    CATCH_REQUIRE(any.forward<std::string>(5) == "5");
     any = std::make_shared<N>();
-    REQUIRE(!any.is_empty());
-    REQUIRE(any.forward<int>(5.0f) == 8);
+    CATCH_REQUIRE(!any.is_empty());
+    CATCH_REQUIRE(any.forward<int>(5.0f) == 8);
   }
-  SECTION("constructs from ModuleHolder") {
+  CATCH_SECTION("constructs from ModuleHolder") {
     struct MImpl : torch::nn::Module {
       explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {}
       int value;
@@ -218,14 +218,14 @@
     };
 
     AnyModule any(M{5});
-    REQUIRE(any.get<MImpl>().value == 5);
-    REQUIRE(any.get<M>()->value == 5);
+    CATCH_REQUIRE(any.get<MImpl>().value == 5);
+    CATCH_REQUIRE(any.get<M>()->value == 5);
 
     AnyModule module(Linear(3, 4));
     std::shared_ptr<Module> ptr = module.ptr();
     Linear linear(module.get<Linear>());
   }
-  SECTION("converts autograd::Variable to torch::Tensor correctly") {
+  CATCH_SECTION("converts autograd::Variable to torch::Tensor correctly") {
     struct M : torch::nn::Module {
       torch::Tensor forward(torch::Tensor input) {
         return input;
@@ -236,12 +236,12 @@
       // torch::Tensor before being passed to the function (to avoid a type
       // mismatch).
       AnyModule any(M{});
-      REQUIRE(
+      CATCH_REQUIRE(
           any.forward(torch::autograd::Variable(torch::ones(5)))
               .sum()
               .toCFloat() == 5);
       // at::Tensors that are not variables work too.
-      REQUIRE(any.forward(at::ones(5)).sum().toCFloat() == 5);
+      CATCH_REQUIRE(any.forward(at::ones(5)).sum().toCFloat() == 5);
     }
   }
 }
@@ -263,92 +263,92 @@
 } // namespace nn
 } // namespace torch
 
-TEST_CASE("any-value") {
+CATCH_TEST_CASE("any-value") {
   torch::manual_seed(0);
-  SECTION("gets the correct value for the right type") {
-    SECTION("int") {
+  CATCH_SECTION("gets the correct value for the right type") {
+    CATCH_SECTION("int") {
       auto value = make_value(5);
       // const and non-const types have the same typeid()
-      REQUIRE(value.try_get<int>() != nullptr);
-      REQUIRE(value.try_get<const int>() != nullptr);
-      REQUIRE(value.get<int>() == 5);
+      CATCH_REQUIRE(value.try_get<int>() != nullptr);
+      CATCH_REQUIRE(value.try_get<const int>() != nullptr);
+      CATCH_REQUIRE(value.get<int>() == 5);
     }
-    SECTION("const int") {
+    CATCH_SECTION("const int") {
       auto value = make_value(5);
-      REQUIRE(value.try_get<const int>() != nullptr);
-      REQUIRE(value.try_get<int>() != nullptr);
-      REQUIRE(value.get<const int>() == 5);
+      CATCH_REQUIRE(value.try_get<const int>() != nullptr);
+      CATCH_REQUIRE(value.try_get<int>() != nullptr);
+      CATCH_REQUIRE(value.get<const int>() == 5);
     }
-    SECTION("const char*") {
+    CATCH_SECTION("const char*") {
       auto value = make_value("hello");
-      REQUIRE(value.try_get<const char*>() != nullptr);
-      REQUIRE(value.get<const char*>() == std::string("hello"));
+      CATCH_REQUIRE(value.try_get<const char*>() != nullptr);
+      CATCH_REQUIRE(value.get<const char*>() == std::string("hello"));
     }
-    SECTION("std::string") {
+    CATCH_SECTION("std::string") {
       auto value = make_value(std::string("hello"));
-      REQUIRE(value.try_get<std::string>() != nullptr);
-      REQUIRE(value.get<std::string>() == "hello");
+      CATCH_REQUIRE(value.try_get<std::string>() != nullptr);
+      CATCH_REQUIRE(value.get<std::string>() == "hello");
     }
-    SECTION("pointers") {
+    CATCH_SECTION("pointers") {
       std::string s("hello");
       std::string* p = &s;
       auto value = make_value(p);
-      REQUIRE(value.try_get<std::string*>() != nullptr);
-      REQUIRE(*value.get<std::string*>() == "hello");
+      CATCH_REQUIRE(value.try_get<std::string*>() != nullptr);
+      CATCH_REQUIRE(*value.get<std::string*>() == "hello");
     }
-    SECTION("references") {
+    CATCH_SECTION("references") {
       std::string s("hello");
       const std::string& t = s;
       auto value = make_value(t);
-      REQUIRE(value.try_get<std::string>() != nullptr);
-      REQUIRE(value.get<std::string>() == "hello");
+      CATCH_REQUIRE(value.try_get<std::string>() != nullptr);
+      CATCH_REQUIRE(value.get<std::string>() == "hello");
     }
   }
-  SECTION("try_get returns nullptr for the wrong type") {
+  CATCH_SECTION("try_get returns nullptr for the wrong type") {
     auto value = make_value(5);
-    REQUIRE(value.try_get<int>() != nullptr);
-    REQUIRE(value.try_get<float>() == nullptr);
-    REQUIRE(value.try_get<long>() == nullptr);
-    REQUIRE(value.try_get<std::string>() == nullptr);
+    CATCH_REQUIRE(value.try_get<int>() != nullptr);
+    CATCH_REQUIRE(value.try_get<float>() == nullptr);
+    CATCH_REQUIRE(value.try_get<long>() == nullptr);
+    CATCH_REQUIRE(value.try_get<std::string>() == nullptr);
   }
-  SECTION("get throws for the wrong type") {
+  CATCH_SECTION("get throws for the wrong type") {
     auto value = make_value(5);
-    REQUIRE(value.try_get<int>() != nullptr);
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE(value.try_get<int>() != nullptr);
+    CATCH_REQUIRE_THROWS_WITH(
         value.get<float>(),
         StartsWith("Attempted to cast Value to float, "
                    "but its actual type is int"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         value.get<long>(),
         StartsWith("Attempted to cast Value to long, "
                    "but its actual type is int"));
   }
-  SECTION("move is allowed") {
+  CATCH_SECTION("move is allowed") {
     auto value = make_value(5);
-    SECTION("construction") {
+    CATCH_SECTION("construction") {
       auto copy = make_value(std::move(value));
-      REQUIRE(copy.try_get<int>() != nullptr);
-      REQUIRE(copy.get<int>() == 5);
+      CATCH_REQUIRE(copy.try_get<int>() != nullptr);
+      CATCH_REQUIRE(copy.get<int>() == 5);
     }
-    SECTION("assignment") {
+    CATCH_SECTION("assignment") {
       auto copy = make_value(10);
       copy = std::move(value);
-      REQUIRE(copy.try_get<int>() != nullptr);
-      REQUIRE(copy.get<int>() == 5);
+      CATCH_REQUIRE(copy.try_get<int>() != nullptr);
+      CATCH_REQUIRE(copy.get<int>() == 5);
     }
   }
-  SECTION("type_info is correct") {
-    SECTION("int") {
+  CATCH_SECTION("type_info is correct") {
+    CATCH_SECTION("int") {
       auto value = make_value(5);
-      REQUIRE(value.type_info().hash_code() == typeid(int).hash_code());
+      CATCH_REQUIRE(value.type_info().hash_code() == typeid(int).hash_code());
     }
-    SECTION("const char") {
+    CATCH_SECTION("const char") {
       auto value = make_value("hello");
-      REQUIRE(value.type_info().hash_code() == typeid(const char*).hash_code());
+      CATCH_REQUIRE(value.type_info().hash_code() == typeid(const char*).hash_code());
     }
-    SECTION("std::string") {
+    CATCH_SECTION("std::string") {
       auto value = make_value(std::string("hello"));
-      REQUIRE(value.type_info().hash_code() == typeid(std::string).hash_code());
+      CATCH_REQUIRE(value.type_info().hash_code() == typeid(std::string).hash_code());
     }
   }
 }
diff --git a/test/cpp/api/catch_utils.hpp b/test/cpp/api/catch_utils.hpp
new file mode 100644
index 0000000..b9b0a87
--- /dev/null
+++ b/test/cpp/api/catch_utils.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+#define CATCH_CONFIG_PREFIX_ALL
+#include <catch.hpp>
+
+// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes a warning;
+// define our own version that doesn't warn.
+#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
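
For context, the sketch below (illustrative only, not part of the patch) shows what a converted test is expected to look like once it includes catch_utils.hpp: CATCH_CONFIG_PREFIX_ALL tells Catch to emit only CATCH_-prefixed macros, so the unprefixed TEST_CASE/SECTION/REQUIRE forms no longer exist, and throw assertions go through the local _CATCH_REQUIRE_THROWS wrapper to sidestep the warning noted in the comment above. The test name and tensor values are made up for the example.

    #define CATCH_CONFIG_MAIN
    #include "catch_utils.hpp"

    #include "ATen/ATen.h"

    // Only the CATCH_-prefixed macros are defined under CATCH_CONFIG_PREFIX_ALL,
    // so these assertions would not compile with the old unprefixed names.
    CATCH_TEST_CASE("prefix conversion example", "[cpu]") {
      auto t = at::ones({2, 2});
      CATCH_REQUIRE(t.sum().toCDouble() == 4.0);

      // Throwing assertions use the wrapper from catch_utils.hpp rather than
      // CATCH_REQUIRE_THROWS, which causes a warning in this Catch version.
      _CATCH_REQUIRE_THROWS(t.toCFloat());
    }
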
diff --git a/test/cpp/api/cursor.cpp b/test/cpp/api/cursor.cpp
index 5c99866..e08bd78 100644
--- a/test/cpp/api/cursor.cpp
+++ b/test/cpp/api/cursor.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/nn/cursor.h>
 #include <torch/nn/module.h>
@@ -58,158 +58,158 @@
   std::vector<std::shared_ptr<Module>> m;
 };
 
-TEST_CASE("cursor/module") {
+CATCH_TEST_CASE("cursor/module") {
   torch::manual_seed(0);
-  SECTION("Works for flat models (depth = 1)") {
+  CATCH_SECTION("Works for flat models (depth = 1)") {
     Container model(TestModule(1), TestModule(2), TestModule(3));
     auto cursor = model.modules();
 
-    SECTION("Iterates in the correct order") {
+    CATCH_SECTION("Iterates in the correct order") {
       auto iterator = cursor.begin();
-      REQUIRE(&iterator->value == &model[0]);
-      REQUIRE(&(++iterator)->value == &model[1]);
-      REQUIRE(&(++iterator)->value == &model[2]);
-      REQUIRE(++iterator == cursor.end());
+      CATCH_REQUIRE(&iterator->value == &model[0]);
+      CATCH_REQUIRE(&(++iterator)->value == &model[1]);
+      CATCH_REQUIRE(&(++iterator)->value == &model[2]);
+      CATCH_REQUIRE(++iterator == cursor.end());
     }
 
-    SECTION("names are flat") {
+    CATCH_SECTION("names are flat") {
       auto iterator = cursor.begin();
-      REQUIRE(iterator->key == "0");
-      REQUIRE((++iterator)->key == "1");
-      REQUIRE((++iterator)->key == "2");
+      CATCH_REQUIRE(iterator->key == "0");
+      CATCH_REQUIRE((++iterator)->key == "1");
+      CATCH_REQUIRE((++iterator)->key == "2");
     }
 
-    SECTION("Apply works") {
+    CATCH_SECTION("Apply works") {
       size_t count = 0;
       cursor.apply([&count, &model](Module& module) {
-        REQUIRE(&module == &model[count]);
+        CATCH_REQUIRE(&module == &model[count]);
         count += 1;
       });
-      REQUIRE(count == 3);
+      CATCH_REQUIRE(count == 3);
     }
 
-    SECTION("Apply_items works") {
+    CATCH_SECTION("Apply_items works") {
       size_t count = 0;
       cursor.apply_items(
           [&count, &model](const std::string& key, Module& module) {
-            REQUIRE(&module == &model[count]);
+            CATCH_REQUIRE(&module == &model[count]);
             count += 1;
           });
-      REQUIRE(count == 3);
+      CATCH_REQUIRE(count == 3);
     }
 
-    SECTION("Map works") {
+    CATCH_SECTION("Map works") {
       std::vector<Module*> vector(3);
       cursor.map(vector.begin(), [](Module& module) { return &module; });
-      REQUIRE(vector[0] == &model[0]);
-      REQUIRE(vector[1] == &model[1]);
-      REQUIRE(vector[2] == &model[2]);
+      CATCH_REQUIRE(vector[0] == &model[0]);
+      CATCH_REQUIRE(vector[1] == &model[1]);
+      CATCH_REQUIRE(vector[2] == &model[2]);
 
       std::list<Module*> list;
       cursor.map(std::inserter(list, list.end()), [](Module& module) {
         return &module;
       });
-      REQUIRE(list.size() == 3);
+      CATCH_REQUIRE(list.size() == 3);
       auto iterator = list.begin();
-      REQUIRE(*iterator++ == &model[0]);
-      REQUIRE(*iterator++ == &model[1]);
-      REQUIRE(*iterator++ == &model[2]);
-      REQUIRE(iterator == list.end());
+      CATCH_REQUIRE(*iterator++ == &model[0]);
+      CATCH_REQUIRE(*iterator++ == &model[1]);
+      CATCH_REQUIRE(*iterator++ == &model[2]);
+      CATCH_REQUIRE(iterator == list.end());
     }
 
-    SECTION("Map_items works") {
+    CATCH_SECTION("Map_items works") {
       std::map<std::string, Module*> output;
       cursor.map_items(
           std::inserter(output, output.end()),
           [](const std::string& key, Module& module) {
             return std::make_pair(key, &module);
           });
-      REQUIRE(output.size() == 3);
-      REQUIRE(output.count("0"));
-      REQUIRE(output.count("1"));
-      REQUIRE(output.count("2"));
-      REQUIRE(output["0"] == &model[0]);
-      REQUIRE(output["1"] == &model[1]);
-      REQUIRE(output["2"] == &model[2]);
+      CATCH_REQUIRE(output.size() == 3);
+      CATCH_REQUIRE(output.count("0"));
+      CATCH_REQUIRE(output.count("1"));
+      CATCH_REQUIRE(output.count("2"));
+      CATCH_REQUIRE(output["0"] == &model[0]);
+      CATCH_REQUIRE(output["1"] == &model[1]);
+      CATCH_REQUIRE(output["2"] == &model[2]);
     }
 
-    SECTION("Count works for flat models") {
-      REQUIRE(cursor.size() == model.m.size());
+    CATCH_SECTION("Count works for flat models") {
+      CATCH_REQUIRE(cursor.size() == model.m.size());
     }
 
-    SECTION("find() finds the correct modules when given a valid key") {
-      REQUIRE(cursor.find("0") == &model[0]);
-      REQUIRE(cursor.find("1") == &model[1]);
-      REQUIRE(cursor.find("2") == &model[2]);
+    CATCH_SECTION("find() finds the correct modules when given a valid key") {
+      CATCH_REQUIRE(cursor.find("0") == &model[0]);
+      CATCH_REQUIRE(cursor.find("1") == &model[1]);
+      CATCH_REQUIRE(cursor.find("2") == &model[2]);
     }
 
-    SECTION("find() returns nullptr when given an invalid key") {
-      REQUIRE(cursor.find("foo") == nullptr);
-      REQUIRE(cursor.find("bar") == nullptr);
+    CATCH_SECTION("find() returns nullptr when given an invalid key") {
+      CATCH_REQUIRE(cursor.find("foo") == nullptr);
+      CATCH_REQUIRE(cursor.find("bar") == nullptr);
     }
 
-    SECTION("at(key) returns the correct modules when given a valid key") {
-      REQUIRE(&cursor.at("0") == &model[0]);
-      REQUIRE(&cursor.at("1") == &model[1]);
-      REQUIRE(&cursor.at("2") == &model[2]);
+    CATCH_SECTION("at(key) returns the correct modules when given a valid key") {
+      CATCH_REQUIRE(&cursor.at("0") == &model[0]);
+      CATCH_REQUIRE(&cursor.at("1") == &model[1]);
+      CATCH_REQUIRE(&cursor.at("2") == &model[2]);
     }
 
-    SECTION("at(key) throws when given an invalid key") {
-      REQUIRE_THROWS_WITH(cursor.at("foo"), StartsWith("No such key: 'foo'"));
-      REQUIRE_THROWS_WITH(cursor.at("bar"), StartsWith("No such key: 'bar'"));
+    CATCH_SECTION("at(key) throws when given an invalid key") {
+      CATCH_REQUIRE_THROWS_WITH(cursor.at("foo"), StartsWith("No such key: 'foo'"));
+      CATCH_REQUIRE_THROWS_WITH(cursor.at("bar"), StartsWith("No such key: 'bar'"));
     }
 
-    SECTION(
+    CATCH_SECTION(
         "operator[key] returns the correct modules when given a valid key") {
-      REQUIRE(&cursor["0"] == &model[0]);
-      REQUIRE(&cursor["1"] == &model[1]);
-      REQUIRE(&cursor["2"] == &model[2]);
+      CATCH_REQUIRE(&cursor["0"] == &model[0]);
+      CATCH_REQUIRE(&cursor["1"] == &model[1]);
+      CATCH_REQUIRE(&cursor["2"] == &model[2]);
     }
 
-    SECTION("operator[key] throws when given an invalid key") {
-      REQUIRE_THROWS_WITH(cursor["foo"], StartsWith("No such key: 'foo'"));
-      REQUIRE_THROWS_WITH(cursor["bar"], StartsWith("No such key: 'bar'"));
+    CATCH_SECTION("operator[key] throws when given an invalid key") {
+      CATCH_REQUIRE_THROWS_WITH(cursor["foo"], StartsWith("No such key: 'foo'"));
+      CATCH_REQUIRE_THROWS_WITH(cursor["bar"], StartsWith("No such key: 'bar'"));
     }
 
-    SECTION("at(index) returns the correct modules when given a valid index") {
-      REQUIRE(&cursor.at(0).value == &model[0]);
-      REQUIRE(&cursor.at(1).value == &model[1]);
-      REQUIRE(&cursor.at(2).value == &model[2]);
+    CATCH_SECTION("at(index) returns the correct modules when given a valid index") {
+      CATCH_REQUIRE(&cursor.at(0).value == &model[0]);
+      CATCH_REQUIRE(&cursor.at(1).value == &model[1]);
+      CATCH_REQUIRE(&cursor.at(2).value == &model[2]);
     }
 
-    SECTION("at(index) throws when given an invalid index") {
-      REQUIRE_THROWS_WITH(
+    CATCH_SECTION("at(index) throws when given an invalid index") {
+      CATCH_REQUIRE_THROWS_WITH(
           cursor.at(5),
           StartsWith("Index 5 is out of range for cursor of size 3"));
-      REQUIRE_THROWS_WITH(
+      CATCH_REQUIRE_THROWS_WITH(
           cursor.at(123),
           StartsWith("Index 123 is out of range for cursor of size 3"));
     }
 
-    SECTION(
+    CATCH_SECTION(
         "operator[index] returns the correct modules when given a valid index") {
-      REQUIRE(&cursor[0].value == &model[0]);
-      REQUIRE(&cursor[1].value == &model[1]);
-      REQUIRE(&cursor[2].value == &model[2]);
+      CATCH_REQUIRE(&cursor[0].value == &model[0]);
+      CATCH_REQUIRE(&cursor[1].value == &model[1]);
+      CATCH_REQUIRE(&cursor[2].value == &model[2]);
     }
 
-    SECTION("operator[index] throws when given an invalid key") {
-      REQUIRE_THROWS_WITH(
+    CATCH_SECTION("operator[index] throws when given an invalid key") {
+      CATCH_REQUIRE_THROWS_WITH(
           cursor[5],
           StartsWith("Index 5 is out of range for cursor of size 3"));
-      REQUIRE_THROWS_WITH(
+      CATCH_REQUIRE_THROWS_WITH(
           cursor[123],
           StartsWith("Index 123 is out of range for cursor of size 3"));
     }
 
-    SECTION("contains() is correct") {
-      REQUIRE(cursor.contains("0"));
-      REQUIRE(cursor.contains("1"));
-      REQUIRE(cursor.contains("2"));
+    CATCH_SECTION("contains() is correct") {
+      CATCH_REQUIRE(cursor.contains("0"));
+      CATCH_REQUIRE(cursor.contains("1"));
+      CATCH_REQUIRE(cursor.contains("2"));
     }
   }
 
-  SECTION("Works for deeper hierarchies (depth > 1)") {
+  CATCH_SECTION("Works for deeper hierarchies (depth > 1)") {
     // clang-format off
     Container model(
         Container(
@@ -227,106 +227,106 @@
     auto cursor = model.modules();
     // This is sufficient for the hierarchical case
     // (other tests build on top)
-    SECTION("Iterates in the correct order") {
+    CATCH_SECTION("Iterates in the correct order") {
       auto iterator = cursor.begin();
 
-      REQUIRE(&iterator->value == &model[0]);
+      CATCH_REQUIRE(&iterator->value == &model[0]);
 
       auto* seq = dynamic_cast<Container*>(&model[0]);
-      REQUIRE(seq != nullptr);
-      REQUIRE(&(++iterator)->value == &(*seq)[0]);
-      REQUIRE(&(++iterator)->value == &(*seq)[1]);
+      CATCH_REQUIRE(seq != nullptr);
+      CATCH_REQUIRE(&(++iterator)->value == &(*seq)[0]);
+      CATCH_REQUIRE(&(++iterator)->value == &(*seq)[1]);
 
-      REQUIRE(&(++iterator)->value == &model[1]);
-      REQUIRE(&(++iterator)->value == &model[2]);
+      CATCH_REQUIRE(&(++iterator)->value == &model[1]);
+      CATCH_REQUIRE(&(++iterator)->value == &model[2]);
 
       seq = dynamic_cast<Container*>(&model[2]);
-      REQUIRE(seq != nullptr);
-      REQUIRE(&(++iterator)->value == &(*seq)[0]);
-      REQUIRE(&(++iterator)->value == &(*seq)[1]);
+      CATCH_REQUIRE(seq != nullptr);
+      CATCH_REQUIRE(&(++iterator)->value == &(*seq)[0]);
+      CATCH_REQUIRE(&(++iterator)->value == &(*seq)[1]);
 
       seq = dynamic_cast<Container*>(&(*seq)[1]);
-      REQUIRE(seq != nullptr);
-      REQUIRE(&(++iterator)->value == &(*seq)[0]);
-      REQUIRE(&(++iterator)->value == &(*seq)[1]);
+      CATCH_REQUIRE(seq != nullptr);
+      CATCH_REQUIRE(&(++iterator)->value == &(*seq)[0]);
+      CATCH_REQUIRE(&(++iterator)->value == &(*seq)[1]);
     }
 
-    SECTION("children() returns only the first level of submodules") {
+    CATCH_SECTION("children() returns only the first level of submodules") {
       auto children = model.children();
-      REQUIRE(children.size() == 3);
-      REQUIRE(&children.at("0") == &model[0]);
-      REQUIRE(&children.at("1") == &model[1]);
-      REQUIRE(&children.at("2") == &model[2]);
-      REQUIRE(!children.contains("0.0"));
+      CATCH_REQUIRE(children.size() == 3);
+      CATCH_REQUIRE(&children.at("0") == &model[0]);
+      CATCH_REQUIRE(&children.at("1") == &model[1]);
+      CATCH_REQUIRE(&children.at("2") == &model[2]);
+      CATCH_REQUIRE(!children.contains("0.0"));
       size_t count = 0;
       for (auto& child : children) {
-        REQUIRE(child.key == std::to_string(count));
-        REQUIRE(&child.value == &model[count]);
+        CATCH_REQUIRE(child.key == std::to_string(count));
+        CATCH_REQUIRE(&child.value == &model[count]);
         count += 1;
       }
     }
   }
 }
 
-TEST_CASE("cursor/parameter") {
+CATCH_TEST_CASE("cursor/parameter") {
   torch::manual_seed(0);
-  SECTION("Works for single models") {
+  CATCH_SECTION("Works for single models") {
     TestModule model(1);
     auto cursor = model.parameters();
 
-    SECTION("Iterates in the correct order") {
+    CATCH_SECTION("Iterates in the correct order") {
       auto iterator = cursor.begin();
-      REQUIRE(iterator->value.equal(model.tensor1));
-      REQUIRE((++iterator)->value.equal(model.tensor2));
+      CATCH_REQUIRE(iterator->value.equal(model.tensor1));
+      CATCH_REQUIRE((++iterator)->value.equal(model.tensor2));
     }
   }
 
-  SECTION("Works for flat models (depth = 1)") {
+  CATCH_SECTION("Works for flat models (depth = 1)") {
     auto first = std::make_shared<TestModule>(1);
     auto second = std::make_shared<TestModule>(2);
     Container model(first, second);
     auto cursor = model.parameters();
 
-    SECTION("Iterates in the correct order") {
+    CATCH_SECTION("Iterates in the correct order") {
       auto iterator = cursor.begin();
-      REQUIRE(iterator->value.equal(first->tensor1));
-      REQUIRE((++iterator)->value.equal(first->tensor2));
-      REQUIRE((++iterator)->value.equal(second->tensor1));
-      REQUIRE((++iterator)->value.equal(second->tensor2));
+      CATCH_REQUIRE(iterator->value.equal(first->tensor1));
+      CATCH_REQUIRE((++iterator)->value.equal(first->tensor2));
+      CATCH_REQUIRE((++iterator)->value.equal(second->tensor1));
+      CATCH_REQUIRE((++iterator)->value.equal(second->tensor2));
     }
 
-    SECTION("Apply_items works") {
+    CATCH_SECTION("Apply_items works") {
       size_t count = 0;
       cursor.apply_items([&count, &model, &first, &second](
                              const std::string& key, torch::Tensor& tensor) {
         switch (count) {
           case 0: {
-            REQUIRE(tensor.equal(first->tensor1));
+            CATCH_REQUIRE(tensor.equal(first->tensor1));
             break;
           }
           case 1: {
-            REQUIRE(tensor.equal(first->tensor2));
+            CATCH_REQUIRE(tensor.equal(first->tensor2));
             break;
           }
           case 2: {
-            REQUIRE(tensor.equal(second->tensor1));
+            CATCH_REQUIRE(tensor.equal(second->tensor1));
             break;
           }
           case 3: {
-            REQUIRE(tensor.equal(second->tensor2));
+            CATCH_REQUIRE(tensor.equal(second->tensor2));
             break;
           }
         }
         count += 1;
       });
-      REQUIRE(count == 4);
+      CATCH_REQUIRE(count == 4);
     }
 
     // Other tests are correct based on correct iteration behavior and apply
     // working.
   }
 
-  SECTION("Works for deeper hierarchies (depth > 1)") {
+  CATCH_SECTION("Works for deeper hierarchies (depth > 1)") {
     std::vector<std::shared_ptr<TestModule>> modules;
     for (size_t i = 1; i <= 6; ++i) {
       modules.push_back(std::make_shared<TestModule>(i));
@@ -346,36 +346,36 @@
     // clang-format on
     auto cursor = model.parameters();
 
-    SECTION("Iterates in the correct order") {
+    CATCH_SECTION("Iterates in the correct order") {
       auto iterator = cursor.begin();
-      REQUIRE(iterator->value.equal(modules[0]->tensor1));
-      REQUIRE((++iterator)->value.equal(modules[0]->tensor2));
+      CATCH_REQUIRE(iterator->value.equal(modules[0]->tensor1));
+      CATCH_REQUIRE((++iterator)->value.equal(modules[0]->tensor2));
       for (size_t index = 1; index < 6; ++index) {
-        REQUIRE((++iterator)->value.equal(modules[index]->tensor1));
-        REQUIRE((++iterator)->value.equal(modules[index]->tensor2));
+        CATCH_REQUIRE((++iterator)->value.equal(modules[index]->tensor1));
+        CATCH_REQUIRE((++iterator)->value.equal(modules[index]->tensor2));
       }
     }
 
-    SECTION("names are hierarchical") {
+    CATCH_SECTION("names are hierarchical") {
       auto iterator = cursor.begin();
-      REQUIRE(iterator->key == "0.0.tensor1");
-      REQUIRE((++iterator)->key == "0.0.tensor2");
-      REQUIRE((++iterator)->key == "0.1.tensor1");
-      REQUIRE((++iterator)->key == "0.1.tensor2");
-      REQUIRE((++iterator)->key == "1.tensor1");
-      REQUIRE((++iterator)->key == "1.tensor2");
-      REQUIRE((++iterator)->key == "2.0.tensor1");
-      REQUIRE((++iterator)->key == "2.0.tensor2");
-      REQUIRE((++iterator)->key == "2.1.0.tensor1");
-      REQUIRE((++iterator)->key == "2.1.0.tensor2");
-      REQUIRE((++iterator)->key == "2.1.1.tensor1");
-      REQUIRE((++iterator)->key == "2.1.1.tensor2");
-      REQUIRE(++iterator == cursor.end());
+      CATCH_REQUIRE(iterator->key == "0.0.tensor1");
+      CATCH_REQUIRE((++iterator)->key == "0.0.tensor2");
+      CATCH_REQUIRE((++iterator)->key == "0.1.tensor1");
+      CATCH_REQUIRE((++iterator)->key == "0.1.tensor2");
+      CATCH_REQUIRE((++iterator)->key == "1.tensor1");
+      CATCH_REQUIRE((++iterator)->key == "1.tensor2");
+      CATCH_REQUIRE((++iterator)->key == "2.0.tensor1");
+      CATCH_REQUIRE((++iterator)->key == "2.0.tensor2");
+      CATCH_REQUIRE((++iterator)->key == "2.1.0.tensor1");
+      CATCH_REQUIRE((++iterator)->key == "2.1.0.tensor2");
+      CATCH_REQUIRE((++iterator)->key == "2.1.1.tensor1");
+      CATCH_REQUIRE((++iterator)->key == "2.1.1.tensor2");
+      CATCH_REQUIRE(++iterator == cursor.end());
     }
   }
 }
 
-TEST_CASE("cursor/non-const-to-const-conversion") {
+CATCH_TEST_CASE("cursor/non-const-to-const-conversion") {
   torch::manual_seed(0);
   auto first = std::make_shared<TestModule>(1);
   auto second = std::make_shared<TestModule>(2);
@@ -404,11 +404,11 @@
   }
 }
 
-TEST_CASE("cursor/can-invoke-const-method-on-const-cursor") {
+CATCH_TEST_CASE("cursor/can-invoke-const-method-on-const-cursor") {
   torch::manual_seed(0);
   TestModule model(1);
 
   /// This will only compile if `Cursor` has the appropriate const methods.
   const auto cursor = model.parameters();
-  REQUIRE(cursor.contains("tensor1"));
+  CATCH_REQUIRE(cursor.contains("tensor1"));
 }
diff --git a/test/cpp/api/integration.cpp b/test/cpp/api/integration.cpp
index 8d75319..972223a 100644
--- a/test/cpp/api/integration.cpp
+++ b/test/cpp/api/integration.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/nn/modules/batchnorm.h>
 #include <torch/nn/modules/conv.h>
@@ -230,7 +230,7 @@
   return correct.sum().toCFloat() > telabel.size(0) * 0.8;
 }
 
-TEST_CASE("integration/cartpole") {
+CATCH_TEST_CASE("integration/cartpole") {
   torch::manual_seed(0);
   std::cerr << "Training episodic policy gradient with a critic for up to 3000"
                " episodes, rest your eyes for a bit!\n";
@@ -326,11 +326,11 @@
     if (running_reward > 150) {
       break;
     }
-    REQUIRE(episode < 3000);
+    CATCH_REQUIRE(episode < 3000);
   }
 }
 
-TEST_CASE("integration/mnist", "[cuda]") {
+CATCH_TEST_CASE("integration/mnist", "[cuda]") {
   torch::manual_seed(0);
   auto model = std::make_shared<SimpleContainer>();
   auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
@@ -357,7 +357,7 @@
   auto optimizer = torch::optim::SGD(
       model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));
 
-  REQUIRE(test_mnist(
+  CATCH_REQUIRE(test_mnist(
       32, // batch_size
       3, // num_epochs
       true, // useGPU
@@ -366,7 +366,7 @@
       optimizer));
 }
 
-TEST_CASE("integration/mnist/batchnorm", "[cuda]") {
+CATCH_TEST_CASE("integration/mnist/batchnorm", "[cuda]") {
   torch::manual_seed(0);
   auto model = std::make_shared<SimpleContainer>();
   auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
@@ -393,7 +393,7 @@
   auto optimizer = torch::optim::SGD(
       model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));
 
-  REQUIRE(test_mnist(
+  CATCH_REQUIRE(test_mnist(
       32, // batch_size
       3, // num_epochs
       true, // useGPU
diff --git a/test/cpp/api/jit.cpp b/test/cpp/api/jit.cpp
index c46868c..b477b11 100644
--- a/test/cpp/api/jit.cpp
+++ b/test/cpp/api/jit.cpp
@@ -1,12 +1,12 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/jit.h>
 #include <torch/tensor.h>
 
 #include <string>
 
-TEST_CASE("torch script") {
-  SECTION("multiple functions") {
+CATCH_TEST_CASE("torch script") {
+  CATCH_SECTION("multiple functions") {
     auto module = torch::jit::compile(R"JIT(
       def test_mul(a, b):
         return a * b
@@ -21,11 +21,11 @@
     auto a = torch::ones(1);
     auto b = torch::ones(1);
 
-    REQUIRE(1 == module->run_method("test_mul", a, b).toTensor().toCLong());
+    CATCH_REQUIRE(1 == module->run_method("test_mul", a, b).toTensor().toCLong());
 
-    REQUIRE(2 == module->run_method("test_relu", a, b).toTensor().toCLong());
+    CATCH_REQUIRE(2 == module->run_method("test_relu", a, b).toTensor().toCLong());
 
-    REQUIRE(
+    CATCH_REQUIRE(
         0x200 == module->run_method("test_while", a, b).toTensor().toCLong());
   }
 }
diff --git a/test/cpp/api/main.cpp b/test/cpp/api/main.cpp
index 4b1aaba..92ea356 100644
--- a/test/cpp/api/main.cpp
+++ b/test/cpp/api/main.cpp
@@ -1,5 +1,5 @@
 #define CATCH_CONFIG_RUNNER
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/cuda.h>
 
diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp
index 6d065bf..8ced0e0 100644
--- a/test/cpp/api/misc.cpp
+++ b/test/cpp/api/misc.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/detail/ordered_dict.h>
 #include <torch/expanding_array.h>
@@ -18,7 +18,7 @@
 
 using Catch::StartsWith;
 
-TEST_CASE("NoGrad") {
+CATCH_TEST_CASE("NoGrad") {
   torch::manual_seed(0);
   torch::NoGradGuard guard;
   Linear model(5, 2);
@@ -27,88 +27,88 @@
   torch::Tensor s = y.sum();
 
   s.backward();
-  REQUIRE(!model->parameters()["weight"].grad().defined());
+  CATCH_REQUIRE(!model->parameters()["weight"].grad().defined());
 }
 
-TEST_CASE("autograd") {
+CATCH_TEST_CASE("autograd") {
   torch::manual_seed(0);
   auto x = torch::randn({3, 3}, torch::requires_grad());
   auto y = torch::randn({3, 3});
   auto z = x * y;
-  SECTION("derivatives of zero-dim tensors") {
+  CATCH_SECTION("derivatives of zero-dim tensors") {
     z.sum().backward();
-    REQUIRE(x.grad().allclose(y));
+    CATCH_REQUIRE(x.grad().allclose(y));
   }
-  SECTION("derivatives of tensors") {
+  CATCH_SECTION("derivatives of tensors") {
     z.backward();
-    REQUIRE(x.grad().allclose(y));
+    CATCH_REQUIRE(x.grad().allclose(y));
   }
-  SECTION("custom gradient inputs") {
+  CATCH_SECTION("custom gradient inputs") {
     z.sum().backward(torch::ones({}) * 2);
-    REQUIRE(x.grad().allclose(y * 2));
+    CATCH_REQUIRE(x.grad().allclose(y * 2));
   }
   // Assume everything else is safe from PyTorch tests.
 }
 
-TEST_CASE("nn::init") {
+CATCH_TEST_CASE("nn::init") {
   auto tensor = torch::empty({3, 4}, torch::requires_grad());
-  REQUIRE_THROWS_WITH(
+  CATCH_REQUIRE_THROWS_WITH(
       tensor.fill_(1),
       StartsWith("a leaf Variable that requires grad "
                  "has been used in an in-place operation"));
-  REQUIRE(torch::nn::init::ones_(tensor).sum().toCInt() == 12);
+  CATCH_REQUIRE(torch::nn::init::ones_(tensor).sum().toCInt() == 12);
 }
 
-TEST_CASE("expanding-array") {
+CATCH_TEST_CASE("expanding-array") {
   torch::manual_seed(0);
-  SECTION("successful construction") {
-    SECTION("initializer_list") {
+  CATCH_SECTION("successful construction") {
+    CATCH_SECTION("initializer_list") {
       torch::ExpandingArray<5> e({1, 2, 3, 4, 5});
-      REQUIRE(e.size() == 5);
+      CATCH_REQUIRE(e.size() == 5);
       for (size_t i = 0; i < e.size(); ++i) {
-        REQUIRE((*e)[i] == i + 1);
+        CATCH_REQUIRE((*e)[i] == i + 1);
       }
     }
 
-    SECTION("vector") {
+    CATCH_SECTION("vector") {
       torch::ExpandingArray<5> e(std::vector<int64_t>{1, 2, 3, 4, 5});
-      REQUIRE(e.size() == 5);
+      CATCH_REQUIRE(e.size() == 5);
       for (size_t i = 0; i < e.size(); ++i) {
-        REQUIRE((*e)[i] == i + 1);
+        CATCH_REQUIRE((*e)[i] == i + 1);
       }
     }
 
-    SECTION("array") {
+    CATCH_SECTION("array") {
       torch::ExpandingArray<5> e(std::array<int64_t, 5>({1, 2, 3, 4, 5}));
-      REQUIRE(e.size() == 5);
+      CATCH_REQUIRE(e.size() == 5);
       for (size_t i = 0; i < e.size(); ++i) {
-        REQUIRE((*e)[i] == i + 1);
+        CATCH_REQUIRE((*e)[i] == i + 1);
       }
     }
 
-    SECTION("single value") {
+    CATCH_SECTION("single value") {
       torch::ExpandingArray<5> e(5);
-      REQUIRE(e.size() == 5);
+      CATCH_REQUIRE(e.size() == 5);
       for (size_t i = 0; i < e.size(); ++i) {
-        REQUIRE((*e)[i] == 5);
+        CATCH_REQUIRE((*e)[i] == 5);
       }
     }
   }
-  SECTION("throws for incorrect size on construction") {
-    SECTION("initializer_list") {
-      REQUIRE_THROWS_WITH(
+  CATCH_SECTION("throws for incorrect size on construction") {
+    CATCH_SECTION("initializer_list") {
+      CATCH_REQUIRE_THROWS_WITH(
           torch::ExpandingArray<5>({1, 2, 3, 4, 5, 6, 7}),
           StartsWith("Expected 5 values, but instead got 7"));
     }
-    SECTION("vector") {
-      REQUIRE_THROWS_WITH(
+    CATCH_SECTION("vector") {
+      CATCH_REQUIRE_THROWS_WITH(
           torch::ExpandingArray<5>(std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7})),
           StartsWith("Expected 5 values, but instead got 7"));
     }
   }
 }
 
-TEST_CASE("make_unique") {
+CATCH_TEST_CASE("make_unique") {
   struct Test {
     explicit Test(const int& x) : lvalue_(x) {}
     explicit Test(int&& x) : rvalue_(x) {}
@@ -117,216 +117,216 @@
     at::optional<int> rvalue_;
   };
 
-  SECTION("forwards rvalues correctly") {
+  CATCH_SECTION("forwards rvalues correctly") {
     auto ptr = torch::make_unique<Test>(123);
-    REQUIRE(!ptr->lvalue_.has_value());
-    REQUIRE(ptr->rvalue_.has_value());
-    REQUIRE(*ptr->rvalue_ == 123);
+    CATCH_REQUIRE(!ptr->lvalue_.has_value());
+    CATCH_REQUIRE(ptr->rvalue_.has_value());
+    CATCH_REQUIRE(*ptr->rvalue_ == 123);
   }
 
-  SECTION("forwards lvalues correctly") {
+  CATCH_SECTION("forwards lvalues correctly") {
     int x = 5;
     auto ptr = torch::make_unique<Test>(x);
-    REQUIRE(ptr->lvalue_.has_value());
-    REQUIRE(*ptr->lvalue_ == 5);
-    REQUIRE(!ptr->rvalue_.has_value());
+    CATCH_REQUIRE(ptr->lvalue_.has_value());
+    CATCH_REQUIRE(*ptr->lvalue_ == 5);
+    CATCH_REQUIRE(!ptr->rvalue_.has_value());
   }
 
-  SECTION("Can construct unique_ptr of array") {
+  CATCH_SECTION("Can construct unique_ptr of array") {
     auto ptr = torch::make_unique<int[]>(3);
     // Value initialization is required by the standard.
-    REQUIRE(ptr[0] == 0);
-    REQUIRE(ptr[1] == 0);
-    REQUIRE(ptr[2] == 0);
+    CATCH_REQUIRE(ptr[0] == 0);
+    CATCH_REQUIRE(ptr[1] == 0);
+    CATCH_REQUIRE(ptr[2] == 0);
   }
 }
 
-TEST_CASE("ordered-dict") {
-  SECTION("is empty after default construction") {
+CATCH_TEST_CASE("ordered-dict") {
+  CATCH_SECTION("is empty after default construction") {
     OrderedDict<int> dict;
-    REQUIRE(dict.subject() == "Key");
-    REQUIRE(dict.is_empty());
-    REQUIRE(dict.size() == 0);
+    CATCH_REQUIRE(dict.subject() == "Key");
+    CATCH_REQUIRE(dict.is_empty());
+    CATCH_REQUIRE(dict.size() == 0);
   }
 
-  SECTION("insert inserts elements when they are not yet present") {
+  CATCH_SECTION("insert inserts elements when they are not yet present") {
     OrderedDict<int> dict;
     dict.insert("a", 1);
     dict.insert("b", 2);
-    REQUIRE(dict.size() == 2);
+    CATCH_REQUIRE(dict.size() == 2);
   }
 
-  SECTION("get returns values when present") {
+  CATCH_SECTION("get returns values when present") {
     OrderedDict<int> dict;
     dict.insert("a", 1);
     dict.insert("b", 2);
-    REQUIRE(dict.get("a") == 1);
-    REQUIRE(dict.get("b") == 2);
+    CATCH_REQUIRE(dict.get("a") == 1);
+    CATCH_REQUIRE(dict.get("b") == 2);
   }
 
-  SECTION("get throws when passed keys that are not present") {
+  CATCH_SECTION("get throws when passed keys that are not present") {
     OrderedDict<int> dict;
     dict.insert("a", 1);
     dict.insert("b", 2);
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         dict.get("foo"), StartsWith("Key 'foo' is not defined"));
-    REQUIRE_THROWS_WITH(dict.get(""), StartsWith("Key '' is not defined"));
+    CATCH_REQUIRE_THROWS_WITH(dict.get(""), StartsWith("Key '' is not defined"));
   }
 
-  SECTION("can initialize from list") {
+  CATCH_SECTION("can initialize from list") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
-    REQUIRE(dict.size() == 2);
-    REQUIRE(dict.get("a") == 1);
-    REQUIRE(dict.get("b") == 2);
+    CATCH_REQUIRE(dict.size() == 2);
+    CATCH_REQUIRE(dict.get("a") == 1);
+    CATCH_REQUIRE(dict.get("b") == 2);
   }
 
-  SECTION("insert throws when passed elements that are present") {
+  CATCH_SECTION("insert throws when passed elements that are present") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         dict.insert("a", 1), StartsWith("Key 'a' already defined"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         dict.insert("b", 1), StartsWith("Key 'b' already defined"));
   }
 
-  SECTION("front() returns the first item") {
+  CATCH_SECTION("front() returns the first item") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
-    REQUIRE(dict.front().key == "a");
-    REQUIRE(dict.front().value == 1);
+    CATCH_REQUIRE(dict.front().key == "a");
+    CATCH_REQUIRE(dict.front().value == 1);
   }
 
-  SECTION("back() returns the last item") {
+  CATCH_SECTION("back() returns the last item") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
-    REQUIRE(dict.back().key == "b");
-    REQUIRE(dict.back().value == 2);
+    CATCH_REQUIRE(dict.back().key == "b");
+    CATCH_REQUIRE(dict.back().value == 2);
   }
 
-  SECTION("find returns pointers to values when present") {
+  CATCH_SECTION("find returns pointers to values when present") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
-    REQUIRE(dict.find("a") != nullptr);
-    REQUIRE(*dict.find("a") == 1);
-    REQUIRE(dict.find("b") != nullptr);
-    REQUIRE(*dict.find("b") == 2);
+    CATCH_REQUIRE(dict.find("a") != nullptr);
+    CATCH_REQUIRE(*dict.find("a") == 1);
+    CATCH_REQUIRE(dict.find("b") != nullptr);
+    CATCH_REQUIRE(*dict.find("b") == 2);
   }
 
-  SECTION("find returns null pointers when passed keys that are not present") {
+  CATCH_SECTION("find returns null pointers when passed keys that are not present") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
-    REQUIRE(dict.find("bar") == nullptr);
-    REQUIRE(dict.find("") == nullptr);
+    CATCH_REQUIRE(dict.find("bar") == nullptr);
+    CATCH_REQUIRE(dict.find("") == nullptr);
   }
 
-  SECTION("operator[] returns values when passed keys that are present") {
+  CATCH_SECTION("operator[] returns values when passed keys that are present") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
-    REQUIRE(dict["a"] == 1);
-    REQUIRE(dict["b"] == 2);
+    CATCH_REQUIRE(dict["a"] == 1);
+    CATCH_REQUIRE(dict["b"] == 2);
   }
 
-  SECTION("operator[] returns items positionally when passed integers") {
+  CATCH_SECTION("operator[] returns items positionally when passed integers") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
-    REQUIRE(dict[0].key == "a");
-    REQUIRE(dict[0].value == 1);
-    REQUIRE(dict[1].key == "b");
-    REQUIRE(dict[1].value == 2);
+    CATCH_REQUIRE(dict[0].key == "a");
+    CATCH_REQUIRE(dict[0].value == 1);
+    CATCH_REQUIRE(dict[1].key == "b");
+    CATCH_REQUIRE(dict[1].value == 2);
   }
 
-  SECTION("operator[] throws when passed keys that are not present") {
+  CATCH_SECTION("operator[] throws when passed keys that are not present") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         dict.get("foo"), StartsWith("Key 'foo' is not defined"));
-    REQUIRE_THROWS_WITH(dict.get(""), StartsWith("Key '' is not defined"));
+    CATCH_REQUIRE_THROWS_WITH(dict.get(""), StartsWith("Key '' is not defined"));
   }
 
-  SECTION("update inserts all items from another OrderedDict") {
+  CATCH_SECTION("update inserts all items from another OrderedDict") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
     OrderedDict<int> dict2 = {{"c", 3}};
     dict2.update(dict);
-    REQUIRE(dict2.size() == 3);
-    REQUIRE(dict2.find("a") != nullptr);
-    REQUIRE(dict2.find("b") != nullptr);
-    REQUIRE(dict2.find("c") != nullptr);
+    CATCH_REQUIRE(dict2.size() == 3);
+    CATCH_REQUIRE(dict2.find("a") != nullptr);
+    CATCH_REQUIRE(dict2.find("b") != nullptr);
+    CATCH_REQUIRE(dict2.find("c") != nullptr);
   }
 
-  SECTION("update also checks for duplicates") {
+  CATCH_SECTION("update also checks for duplicates") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
     OrderedDict<int> dict2 = {{"a", 1}};
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         dict2.update(dict), StartsWith("Key 'a' already defined"));
   }
 
-  SECTION("Can iterate items") {
+  CATCH_SECTION("Can iterate items") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
     auto iterator = dict.begin();
-    REQUIRE(iterator != dict.end());
-    REQUIRE(iterator->key == "a");
-    REQUIRE(iterator->value == 1);
+    CATCH_REQUIRE(iterator != dict.end());
+    CATCH_REQUIRE(iterator->key == "a");
+    CATCH_REQUIRE(iterator->value == 1);
     ++iterator;
-    REQUIRE(iterator != dict.end());
-    REQUIRE(iterator->key == "b");
-    REQUIRE(iterator->value == 2);
+    CATCH_REQUIRE(iterator != dict.end());
+    CATCH_REQUIRE(iterator->key == "b");
+    CATCH_REQUIRE(iterator->value == 2);
     ++iterator;
-    REQUIRE(iterator == dict.end());
+    CATCH_REQUIRE(iterator == dict.end());
   }
 
-  SECTION("clear makes the dict empty") {
+  CATCH_SECTION("clear makes the dict empty") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
-    REQUIRE(!dict.is_empty());
+    CATCH_REQUIRE(!dict.is_empty());
     dict.clear();
-    REQUIRE(dict.is_empty());
+    CATCH_REQUIRE(dict.is_empty());
   }
 
-  SECTION("can copy construct") {
+  CATCH_SECTION("can copy construct") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
     OrderedDict<int> copy = dict;
-    REQUIRE(copy.size() == 2);
-    REQUIRE(*copy[0] == 1);
-    REQUIRE(*copy[1] == 2);
+    CATCH_REQUIRE(copy.size() == 2);
+    CATCH_REQUIRE(*copy[0] == 1);
+    CATCH_REQUIRE(*copy[1] == 2);
   }
 
-  SECTION("can copy assign") {
+  CATCH_SECTION("can copy assign") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
     OrderedDict<int> copy = {{"c", 1}};
-    REQUIRE(copy.find("c") != nullptr);
+    CATCH_REQUIRE(copy.find("c") != nullptr);
     copy = dict;
-    REQUIRE(copy.size() == 2);
-    REQUIRE(*copy[0] == 1);
-    REQUIRE(*copy[1] == 2);
-    REQUIRE(copy.find("c") == nullptr);
+    CATCH_REQUIRE(copy.size() == 2);
+    CATCH_REQUIRE(*copy[0] == 1);
+    CATCH_REQUIRE(*copy[1] == 2);
+    CATCH_REQUIRE(copy.find("c") == nullptr);
   }
 
-  SECTION("can move construct") {
+  CATCH_SECTION("can move construct") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
     OrderedDict<int> copy = std::move(dict);
-    REQUIRE(copy.size() == 2);
-    REQUIRE(*copy[0] == 1);
-    REQUIRE(*copy[1] == 2);
+    CATCH_REQUIRE(copy.size() == 2);
+    CATCH_REQUIRE(*copy[0] == 1);
+    CATCH_REQUIRE(*copy[1] == 2);
   }
 
-  SECTION("can move assign") {
+  CATCH_SECTION("can move assign") {
     OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
     OrderedDict<int> copy = {{"c", 1}};
-    REQUIRE(copy.find("c") != nullptr);
+    CATCH_REQUIRE(copy.find("c") != nullptr);
     copy = std::move(dict);
-    REQUIRE(copy.size() == 2);
-    REQUIRE(*copy[0] == 1);
-    REQUIRE(*copy[1] == 2);
-    REQUIRE(copy.find("c") == nullptr);
+    CATCH_REQUIRE(copy.size() == 2);
+    CATCH_REQUIRE(*copy[0] == 1);
+    CATCH_REQUIRE(*copy[1] == 2);
+    CATCH_REQUIRE(copy.find("c") == nullptr);
   }
 
-  SECTION("can insert with braces") {
+  CATCH_SECTION("can insert with braces") {
     OrderedDict<std::pair<int, int>> dict;
     dict.insert("a", {1, 2});
-    REQUIRE(!dict.is_empty());
-    REQUIRE(dict["a"].first == 1);
-    REQUIRE(dict["a"].second == 2);
+    CATCH_REQUIRE(!dict.is_empty());
+    CATCH_REQUIRE(dict["a"].first == 1);
+    CATCH_REQUIRE(dict["a"].second == 2);
   }
 
-  SECTION("Error messages include the what") {
+  CATCH_SECTION("Error messages include the what") {
     OrderedDict<int> dict("Penguin");
-    REQUIRE(dict.subject() == "Penguin");
+    CATCH_REQUIRE(dict.subject() == "Penguin");
     dict.insert("a", 1);
-    REQUIRE(!dict.is_empty());
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE(!dict.is_empty());
+    CATCH_REQUIRE_THROWS_WITH(
         dict.get("b"), StartsWith("Penguin 'b' is not defined"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         dict.insert("a", 1), StartsWith("Penguin 'a' already defined"));
   }
 }
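One detail visible in the misc.cpp hunks: only Catch's preprocessor macros change names, while ordinary identifiers in the Catch namespace, such as the Catch::StartsWith matcher imported near the top of the file, keep theirs. A short usage sketch under the same wrapper assumption (the test name and exception text are illustrative only):

    #include "catch_utils.hpp"

    #include <stdexcept>

    using Catch::StartsWith;

    CATCH_TEST_CASE("matcher example (sketch)") {
      auto fail = [] { throw std::runtime_error("Key 'foo' is not defined"); };
      // The prefixed assertion macro composes with the unprefixed matcher.
      CATCH_REQUIRE_THROWS_WITH(fail(), StartsWith("Key 'foo'"));
    }
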
diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp
index d4049d6..2b9d0ad 100644
--- a/test/cpp/api/module.cpp
+++ b/test/cpp/api/module.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/nn/module.h>
 #include <torch/nn/modules/linear.h>
@@ -22,21 +22,21 @@
 };
 } // namespace test
 
-TEST_CASE("module/training-mode") {
+CATCH_TEST_CASE("module/training-mode") {
   torch::manual_seed(0);
   Linear module(3, 4);
-  REQUIRE(module->is_training());
-  SECTION("Enable eval mode") {
+  CATCH_REQUIRE(module->is_training());
+  CATCH_SECTION("Enable eval mode") {
     module->eval();
-    REQUIRE(!module->is_training());
+    CATCH_REQUIRE(!module->is_training());
   }
-  SECTION("Enable train mode") {
+  CATCH_SECTION("Enable train mode") {
     module->train();
-    REQUIRE(module->is_training());
+    CATCH_REQUIRE(module->is_training());
   }
 }
 
-TEST_CASE("module/zero-grad") {
+CATCH_TEST_CASE("module/zero-grad") {
   torch::manual_seed(0);
   Linear module(3, 4);
   auto weight = torch::ones({8, 3}, torch::requires_grad());
@@ -44,18 +44,18 @@
   loss.backward();
   for (auto& parameter : module->parameters()) {
     auto grad = parameter->grad();
-    REQUIRE(grad.defined());
-    REQUIRE(grad.sum().toCFloat() != 0);
+    CATCH_REQUIRE(grad.defined());
+    CATCH_REQUIRE(grad.sum().toCFloat() != 0);
   }
   module->zero_grad();
   for (auto& parameter : module->parameters()) {
     auto grad = parameter->grad();
-    REQUIRE(grad.defined());
-    REQUIRE(grad.sum().toCFloat() == 0);
+    CATCH_REQUIRE(grad.defined());
+    CATCH_REQUIRE(grad.sum().toCFloat() == 0);
   }
 }
 
-TEST_CASE("module/zero-grad-with-undefined") {
+CATCH_TEST_CASE("module/zero-grad-with-undefined") {
   struct TestModule : torch::nn::Module {
     TestModule() {
       x = register_parameter("x", torch::ones(5, at::requires_grad()));
@@ -68,120 +68,120 @@
   auto z = module.x * 2;
   z.sum().backward();
 
-  REQUIRE(module.x.grad().defined());
-  REQUIRE(!module.y.grad().defined());
+  CATCH_REQUIRE(module.x.grad().defined());
+  CATCH_REQUIRE(!module.y.grad().defined());
 
   module.zero_grad();
 
-  REQUIRE(module.x.grad().defined());
-  REQUIRE(!module.y.grad().defined());
+  CATCH_REQUIRE(module.x.grad().defined());
+  CATCH_REQUIRE(!module.y.grad().defined());
 
-  REQUIRE(module.x.grad().sum().toCFloat() == 0);
+  CATCH_REQUIRE(module.x.grad().sum().toCFloat() == 0);
 }
 
-TEST_CASE("module/name") {
+CATCH_TEST_CASE("module/name") {
   // CHECK instead of REQUIRE because demangling may fail.
   AGIUnit agi;
   // Call it twice just to make sure there are no bugs in the lazy
   // initialization semantics.
-  CHECK(agi.name() == "AGIUnit");
-  CHECK(agi.name() == "AGIUnit");
-  SECTION("correctly demangled") {
-    CHECK(test::AGIUnit().name() == "test::AGIUnit");
-    CHECK(test::AGIUnit2().name() == "Foo");
+  CATCH_CHECK(agi.name() == "AGIUnit");
+  CATCH_CHECK(agi.name() == "AGIUnit");
+  CATCH_SECTION("correctly demangled") {
+    CATCH_CHECK(test::AGIUnit().name() == "test::AGIUnit");
+    CATCH_CHECK(test::AGIUnit2().name() == "Foo");
   }
 }
 
-TEST_CASE("module/as") {
+CATCH_TEST_CASE("module/as") {
   Linear module(3, 4);
-  REQUIRE(module->as<Linear>() == module.get());
-  REQUIRE(module->as<LinearImpl>() == module.get());
-  REQUIRE(module->as<Module>() == module.get());
-  REQUIRE(module->as<AGIUnit>() == nullptr);
+  CATCH_REQUIRE(module->as<Linear>() == module.get());
+  CATCH_REQUIRE(module->as<LinearImpl>() == module.get());
+  CATCH_REQUIRE(module->as<Module>() == module.get());
+  CATCH_REQUIRE(module->as<AGIUnit>() == nullptr);
 
   std::shared_ptr<Module> raw = module.ptr();
-  REQUIRE(raw->as<Linear>() == module.get());
-  REQUIRE(raw->as<LinearImpl>() == module.get());
-  REQUIRE(raw->as<Module>() == module.get());
-  REQUIRE(raw->as<AGIUnit>() == nullptr);
+  CATCH_REQUIRE(raw->as<Linear>() == module.get());
+  CATCH_REQUIRE(raw->as<LinearImpl>() == module.get());
+  CATCH_REQUIRE(raw->as<Module>() == module.get());
+  CATCH_REQUIRE(raw->as<AGIUnit>() == nullptr);
 
   Module& raw_ref = *raw.get();
-  REQUIRE(raw_ref.as<Linear>() == module.get());
-  REQUIRE(raw_ref.as<LinearImpl>() == module.get());
-  REQUIRE(raw_ref.as<Module>() == module.get());
-  REQUIRE(raw_ref.as<AGIUnit>() == nullptr);
+  CATCH_REQUIRE(raw_ref.as<Linear>() == module.get());
+  CATCH_REQUIRE(raw_ref.as<LinearImpl>() == module.get());
+  CATCH_REQUIRE(raw_ref.as<Module>() == module.get());
+  CATCH_REQUIRE(raw_ref.as<AGIUnit>() == nullptr);
   if (auto* linear = raw_ref.as<Linear>()) {
-    REQUIRE(linear->weight.ndimension() == 2);
+    CATCH_REQUIRE(linear->weight.ndimension() == 2);
   }
 
   AGIUnit unit;
-  REQUIRE(unit.as<Linear>() == nullptr);
-  REQUIRE(unit.as<LinearImpl>() == nullptr);
-  REQUIRE(unit.as<AGIUnit>() == &unit);
+  CATCH_REQUIRE(unit.as<Linear>() == nullptr);
+  CATCH_REQUIRE(unit.as<LinearImpl>() == nullptr);
+  CATCH_REQUIRE(unit.as<AGIUnit>() == &unit);
 }
 
-TEST_CASE("module/conversions", "[multi-cuda]") {
+CATCH_TEST_CASE("module/conversions", "[multi-cuda]") {
   torch::manual_seed(0);
   Linear module(128, 64);
-  SECTION("starts as float on CPU") {
+  CATCH_SECTION("starts as float on CPU") {
     for (auto& parameter : module->parameters()) {
-      REQUIRE(parameter->device() == torch::Device(torch::kCPU));
-      REQUIRE(parameter->dtype() == torch::kFloat32);
+      CATCH_REQUIRE(parameter->device() == torch::Device(torch::kCPU));
+      CATCH_REQUIRE(parameter->dtype() == torch::kFloat32);
     }
   }
-  SECTION("to(CUDA)") {
+  CATCH_SECTION("to(CUDA)") {
     module->to({torch::kCUDA, 0});
     for (auto& parameter : module->parameters()) {
-      REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
-      REQUIRE(parameter->device().index() == 0);
+      CATCH_REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
+      CATCH_REQUIRE(parameter->device().index() == 0);
     }
     module->to({at::kCUDA, 1});
     for (auto& parameter : module->parameters()) {
-      REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
-      REQUIRE(parameter->device().index() == 1);
+      CATCH_REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
+      CATCH_REQUIRE(parameter->device().index() == 1);
     }
   }
-  SECTION("to(CPU)") {
+  CATCH_SECTION("to(CPU)") {
     module->to(torch::Device(torch::kCPU));
     for (auto& parameter : module->parameters()) {
-      REQUIRE(parameter->device().type() == torch::Device::Type::CPU);
+      CATCH_REQUIRE(parameter->device().type() == torch::Device::Type::CPU);
     }
   }
-  SECTION("to(Int32)") {
+  CATCH_SECTION("to(Int32)") {
     module->to(torch::kInt32);
     for (auto& parameter : module->parameters()) {
-      REQUIRE(parameter->dtype() == torch::kInt32);
+      CATCH_REQUIRE(parameter->dtype() == torch::kInt32);
     }
   }
-  SECTION("to(Float64)") {
+  CATCH_SECTION("to(Float64)") {
     module->to(torch::kFloat64);
     for (auto& parameter : module->parameters()) {
-      REQUIRE(parameter->dtype() == torch::kFloat64);
+      CATCH_REQUIRE(parameter->dtype() == torch::kFloat64);
     }
   }
-  SECTION("to(CUDA, Byte)") {
+  CATCH_SECTION("to(CUDA, Byte)") {
     module->to(torch::Device(torch::kCUDA, 1), torch::kUInt8);
     for (auto& parameter : module->parameters()) {
-      REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
-      REQUIRE(parameter->device().index() == 1);
+      CATCH_REQUIRE(parameter->device().type() == torch::Device::Type::CUDA);
+      CATCH_REQUIRE(parameter->device().index() == 1);
     }
     for (auto& parameter : module->parameters()) {
-      REQUIRE(parameter->dtype() == torch::kUInt8);
+      CATCH_REQUIRE(parameter->dtype() == torch::kUInt8);
     }
   }
 }
 
-TEST_CASE("module/clone") {
+CATCH_TEST_CASE("module/clone") {
   torch::manual_seed(0);
-  SECTION(
+  CATCH_SECTION(
       "a module that does not override clone() throws when clone() is called") {
     struct UnCloneable : Module {};
     UnCloneable module;
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         module.clone(), StartsWith("clone() has not been implemented"));
   }
 
-  SECTION(
+  CATCH_SECTION(
       "a module that overrides clone() does not throw when clone() is called ") {
     struct Cloneable : Module {
       std::shared_ptr<Module> clone(
@@ -190,10 +190,10 @@
       }
     };
     Cloneable module;
-    REQUIRE_NOTHROW(module.clone());
+    CATCH_REQUIRE_NOTHROW(module.clone());
   }
 
-  SECTION("Cloning creates distinct parameters") {
+  CATCH_SECTION("Cloning creates distinct parameters") {
     struct TestModule : public Cloneable<TestModule> {
       TestModule() {
         reset();
@@ -216,32 +216,32 @@
     auto module2 = module->clone();
     auto params1 = module->parameters();
     auto params2 = module2->parameters();
-    REQUIRE(params1.size() == 6);
-    REQUIRE(params2.size() == 6);
+    CATCH_REQUIRE(params1.size() == 6);
+    CATCH_REQUIRE(params2.size() == 6);
     for (auto& param : params1) {
-      REQUIRE(!pointer_equal(param.value, params2[param.key]));
-      REQUIRE(param->allclose(params2[param.key]));
+      CATCH_REQUIRE(!pointer_equal(param.value, params2[param.key]));
+      CATCH_REQUIRE(param->allclose(params2[param.key]));
       param->add_(2);
     }
     for (auto& param : params1) {
-      REQUIRE(!param->allclose(params2[param.key]));
+      CATCH_REQUIRE(!param->allclose(params2[param.key]));
     }
 
     auto buffers1 = module->buffers();
     auto buffers2 = module2->buffers();
-    REQUIRE(buffers1.size() == 1);
-    REQUIRE(buffers2.size() == 1);
+    CATCH_REQUIRE(buffers1.size() == 1);
+    CATCH_REQUIRE(buffers2.size() == 1);
     for (auto& buffer : buffers1) {
-      REQUIRE(!pointer_equal(buffer.value, buffers2[buffer.key]));
-      REQUIRE(buffer->allclose(buffers2[buffer.key]));
+      CATCH_REQUIRE(!pointer_equal(buffer.value, buffers2[buffer.key]));
+      CATCH_REQUIRE(buffer->allclose(buffers2[buffer.key]));
       buffer->add_(2);
     }
     for (auto& buffer : buffers1) {
-      REQUIRE(!buffer->allclose(buffers2[buffer.key]));
+      CATCH_REQUIRE(!buffer->allclose(buffers2[buffer.key]));
     }
   }
 
-  SECTION("Cloning preserves external references") {
+  CATCH_SECTION("Cloning preserves external references") {
     struct TestModule : public Cloneable<TestModule> {
       TestModule() {
         reset();
@@ -256,19 +256,19 @@
       torch::NoGradGuard no_grad;
       module->weight += 1;
     }
-    REQUIRE(pointer_equal(module->weight, module->parameters()["weight"]));
-    REQUIRE(module->weight.allclose(module->parameters()["weight"]));
+    CATCH_REQUIRE(pointer_equal(module->weight, module->parameters()["weight"]));
+    CATCH_REQUIRE(module->weight.allclose(module->parameters()["weight"]));
 
     auto module2 = std::dynamic_pointer_cast<TestModule>(
         std::shared_ptr<Module>(module->clone()));
-    REQUIRE(!pointer_equal(module2->weight, module->weight));
-    REQUIRE(pointer_equal(module2->weight, module2->parameters()["weight"]));
-    REQUIRE(module2->weight.allclose(module2->parameters()["weight"]));
-    REQUIRE(module2->weight.allclose(module->weight));
-    REQUIRE(!pointer_equal(module2->weight, module->parameters()["weight"]));
+    CATCH_REQUIRE(!pointer_equal(module2->weight, module->weight));
+    CATCH_REQUIRE(pointer_equal(module2->weight, module2->parameters()["weight"]));
+    CATCH_REQUIRE(module2->weight.allclose(module2->parameters()["weight"]));
+    CATCH_REQUIRE(module2->weight.allclose(module->weight));
+    CATCH_REQUIRE(!pointer_equal(module2->weight, module->parameters()["weight"]));
   }
 
-  SECTION("Cloning copies the values of variables of submodules") {
+  CATCH_SECTION("Cloning copies the values of variables of submodules") {
     struct TestModule : public Cloneable<TestModule> {
       TestModule() {
         reset();
@@ -299,16 +299,16 @@
 
     auto b = std::dynamic_pointer_cast<NestedModule>(a->clone());
 
-    REQUIRE(!pointer_equal(b->module->weight, a->module->weight));
-    REQUIRE(
+    CATCH_REQUIRE(!pointer_equal(b->module->weight, a->module->weight));
+    CATCH_REQUIRE(
         pointer_equal(b->module->weight, b->module->parameters()["weight"]));
-    REQUIRE(b->module->parameters()["weight"].allclose(a->module->weight));
-    REQUIRE(b->module->weight.allclose(a->module->weight));
-    REQUIRE(b->module->value == a->module->value);
+    CATCH_REQUIRE(b->module->parameters()["weight"].allclose(a->module->weight));
+    CATCH_REQUIRE(b->module->weight.allclose(a->module->weight));
+    CATCH_REQUIRE(b->module->value == a->module->value);
   }
 }
 
-TEST_CASE("module/clone-to-device", "[cuda]") {
+CATCH_TEST_CASE("module/clone-to-device", "[cuda]") {
   struct TestModule : public Cloneable<TestModule> {
     TestModule() {
       reset();
@@ -324,7 +324,7 @@
     torch::Tensor buffer;
   };
 
-  SECTION("Cloning preserves the device of parameters/buffers") {
+  CATCH_SECTION("Cloning preserves the device of parameters/buffers") {
     TestModule m;
     torch::Device device(torch::kCUDA, 0);
 
@@ -332,33 +332,33 @@
 
     auto clone = m.clone();
     for (const auto& parameter : clone->parameters()) {
-      REQUIRE(parameter->device().type() == device.type());
-      REQUIRE(parameter->device().index() == device.index());
+      CATCH_REQUIRE(parameter->device().type() == device.type());
+      CATCH_REQUIRE(parameter->device().index() == device.index());
     }
     for (const auto& buffer : clone->buffers()) {
-      REQUIRE(buffer->device().type() == device.type());
-      REQUIRE(buffer->device().index() == device.index());
+      CATCH_REQUIRE(buffer->device().type() == device.type());
+      CATCH_REQUIRE(buffer->device().index() == device.index());
     }
   }
 
-  SECTION(
+  CATCH_SECTION(
       "Cloning to a particular device places all parameters/buffers there") {
     TestModule m;
     torch::Device device(torch::kCUDA, 1);
     // everything is on CPU here
     auto clone = m.clone(device);
     for (const auto& parameter : clone->parameters()) {
-      REQUIRE(parameter->device().type() == device.type());
-      REQUIRE(parameter->device().index() == device.index());
+      CATCH_REQUIRE(parameter->device().type() == device.type());
+      CATCH_REQUIRE(parameter->device().index() == device.index());
     }
     for (const auto& buffer : clone->buffers()) {
-      REQUIRE(buffer->device().type() == device.type());
-      REQUIRE(buffer->device().index() == device.index());
+      CATCH_REQUIRE(buffer->device().type() == device.type());
+      CATCH_REQUIRE(buffer->device().index() == device.index());
     }
   }
 }
 
-TEST_CASE("module/parameters") {
+CATCH_TEST_CASE("module/parameters") {
   torch::manual_seed(0);
   struct TestModule : Module {
     TestModule() {
@@ -372,19 +372,19 @@
 
   TestModule module;
 
-  SECTION("has correct number of parameters") {
-    REQUIRE(module.parameters().size() == 3);
+  CATCH_SECTION("has correct number of parameters") {
+    CATCH_REQUIRE(module.parameters().size() == 3);
   }
 
-  SECTION("contains parameters with the correct name") {
+  CATCH_SECTION("contains parameters with the correct name") {
     auto parameters = module.parameters();
-    REQUIRE(parameters.contains("a"));
-    REQUIRE(parameters.contains("b"));
-    REQUIRE(parameters.contains("c"));
+    CATCH_REQUIRE(parameters.contains("a"));
+    CATCH_REQUIRE(parameters.contains("b"));
+    CATCH_REQUIRE(parameters.contains("c"));
   }
 }
 
-TEST_CASE("module/buffers") {
+CATCH_TEST_CASE("module/buffers") {
   torch::manual_seed(0);
   struct TestModule : Module {
     TestModule() {
@@ -398,19 +398,19 @@
 
   TestModule module;
 
-  SECTION("has correct number of buffers") {
-    REQUIRE(module.buffers().size() == 3);
+  CATCH_SECTION("has correct number of buffers") {
+    CATCH_REQUIRE(module.buffers().size() == 3);
   }
 
-  SECTION("contains buffers with the correct name") {
+  CATCH_SECTION("contains buffers with the correct name") {
     auto buffers = module.buffers();
-    REQUIRE(buffers.contains("a"));
-    REQUIRE(buffers.contains("b"));
-    REQUIRE(buffers.contains("c"));
+    CATCH_REQUIRE(buffers.contains("a"));
+    CATCH_REQUIRE(buffers.contains("b"));
+    CATCH_REQUIRE(buffers.contains("c"));
   }
 }
 
-TEST_CASE("module/default-constructor") {
+CATCH_TEST_CASE("module/default-constructor") {
   struct AImpl : torch::nn::Module {
     AImpl() : x_(123) {}
     AImpl(int x) : x_(x) {}
@@ -420,20 +420,20 @@
 
   {
     A a;
-    REQUIRE(a);
-    REQUIRE(!a.is_empty());
-    REQUIRE(a->x_ == 123);
+    CATCH_REQUIRE(a);
+    CATCH_REQUIRE(!a.is_empty());
+    CATCH_REQUIRE(a->x_ == 123);
   }
   {
     A a(5);
-    REQUIRE(a);
-    REQUIRE(!a.is_empty());
-    REQUIRE(a->x_ == 5);
+    CATCH_REQUIRE(a);
+    CATCH_REQUIRE(!a.is_empty());
+    CATCH_REQUIRE(a->x_ == 5);
   }
   {
     A a = nullptr;
-    REQUIRE(!a);
-    REQUIRE(a.is_empty());
-    REQUIRE_THROWS_WITH(a->x_, StartsWith("Accessing empty ModuleHolder"));
+    CATCH_REQUIRE(!a);
+    CATCH_REQUIRE(a.is_empty());
+    CATCH_REQUIRE_THROWS_WITH(a->x_, StartsWith("Accessing empty ModuleHolder"));
   }
 }
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 928a39f..7d4f9ab 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/nn/module.h>
 #include <torch/nn/modules/batchnorm.h>
@@ -39,92 +39,92 @@
   std::shared_ptr<TestModel> t;
 };
 
-TEST_CASE("modules") {
+CATCH_TEST_CASE("modules") {
   torch::manual_seed(0);
-  SECTION("conv") {
-    SECTION("1d") {
+  CATCH_SECTION("conv") {
+    CATCH_SECTION("1d") {
       Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
       auto x = torch::randn({2, 3, 5}, torch::requires_grad());
       auto y = model->forward(x);
       torch::Tensor s = y.sum();
 
       s.backward();
-      REQUIRE(y.ndimension() == 3);
-      REQUIRE(s.ndimension() == 0);
+      CATCH_REQUIRE(y.ndimension() == 3);
+      CATCH_REQUIRE(s.ndimension() == 0);
       for (auto i = 0; i < 3; i++) {
-        REQUIRE(y.size(i) == 2);
+        CATCH_REQUIRE(y.size(i) == 2);
       }
 
-      REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3);
+      CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3);
     }
-    SECTION("2d") {
-      SECTION("even") {
+    CATCH_SECTION("2d") {
+      CATCH_SECTION("even") {
         Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
         auto x = torch::randn({2, 3, 5, 5}, torch::requires_grad());
         auto y = model->forward(x);
         torch::Tensor s = y.sum();
 
         s.backward();
-        REQUIRE(y.ndimension() == 4);
-        REQUIRE(s.ndimension() == 0);
+        CATCH_REQUIRE(y.ndimension() == 4);
+        CATCH_REQUIRE(s.ndimension() == 0);
         for (auto i = 0; i < 4; i++) {
-          REQUIRE(y.size(i) == 2);
+          CATCH_REQUIRE(y.size(i) == 2);
         }
 
-        REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3);
+        CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3);
       }
 
-      SECTION("uneven") {
+      CATCH_SECTION("uneven") {
         Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({2, 2}));
         auto x = torch::randn({2, 3, 5, 4}, torch::requires_grad());
         auto y = model->forward(x);
         torch::Tensor s = y.sum();
 
         s.backward();
-        REQUIRE(y.ndimension() == 4);
-        REQUIRE(s.ndimension() == 0);
+        CATCH_REQUIRE(y.ndimension() == 4);
+        CATCH_REQUIRE(s.ndimension() == 0);
         for (auto i = 0; i < 4; i++) {
-          REQUIRE(y.size(i) == 2);
+          CATCH_REQUIRE(y.size(i) == 2);
         }
 
-        REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 2);
+        CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 2);
       }
     }
-    SECTION("3d") {
+    CATCH_SECTION("3d") {
       Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
       auto x = torch::randn({2, 3, 5, 5, 5}, torch::requires_grad());
       auto y = model->forward(x);
       torch::Tensor s = y.sum();
 
       s.backward();
-      REQUIRE(y.ndimension() == 5);
-      REQUIRE(s.ndimension() == 0);
+      CATCH_REQUIRE(y.ndimension() == 5);
+      CATCH_REQUIRE(s.ndimension() == 0);
       for (auto i = 0; i < 5; i++) {
-        REQUIRE(y.size(i) == 2);
+        CATCH_REQUIRE(y.size(i) == 2);
       }
 
-      REQUIRE(
+      CATCH_REQUIRE(
           model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3 * 3);
     }
   }
-  SECTION("linear") {
-    SECTION("basic1") {
+  CATCH_SECTION("linear") {
+    CATCH_SECTION("basic1") {
       Linear model(5, 2);
       auto x = torch::randn({10, 5}, torch::requires_grad());
       auto y = model->forward(x);
       torch::Tensor s = y.sum();
 
       s.backward();
-      REQUIRE(y.ndimension() == 2);
-      REQUIRE(s.ndimension() == 0);
-      REQUIRE(y.size(0) == 10);
-      REQUIRE(y.size(1) == 2);
+      CATCH_REQUIRE(y.ndimension() == 2);
+      CATCH_REQUIRE(s.ndimension() == 0);
+      CATCH_REQUIRE(y.size(0) == 10);
+      CATCH_REQUIRE(y.size(1) == 2);
 
-      REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
+      CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
     }
   }
 
-  SECTION("simple") {
+  CATCH_SECTION("simple") {
     auto model = std::make_shared<SimpleContainer>();
     auto l1 = model->add(Linear(10, 3), "l1");
     auto l2 = model->add(Linear(3, 5), "l2");
@@ -136,20 +136,20 @@
     x = l3->forward(x).clamp_min(0);
 
     x.backward();
-    REQUIRE(x.ndimension() == 2);
-    REQUIRE(x.size(0) == 1000);
-    REQUIRE(x.size(1) == 100);
-    REQUIRE(x.min().toCFloat() == 0);
+    CATCH_REQUIRE(x.ndimension() == 2);
+    CATCH_REQUIRE(x.size(0) == 1000);
+    CATCH_REQUIRE(x.size(1) == 100);
+    CATCH_REQUIRE(x.min().toCFloat() == 0);
   }
 
-  SECTION("embedding") {
-    SECTION("basic") {
+  CATCH_SECTION("embedding") {
+    CATCH_SECTION("basic") {
       const int64_t dict_size = 10;
       Embedding model(dict_size, 2);
-      REQUIRE(model->parameters().contains("weight"));
-      REQUIRE(model->weight.ndimension() == 2);
-      REQUIRE(model->weight.size(0) == dict_size);
-      REQUIRE(model->weight.size(1) == 2);
+      CATCH_REQUIRE(model->parameters().contains("weight"));
+      CATCH_REQUIRE(model->weight.ndimension() == 2);
+      CATCH_REQUIRE(model->weight.size(0) == dict_size);
+      CATCH_REQUIRE(model->weight.size(1) == 2);
 
       // Cannot get gradients to change indices (input) - only for embedding
       // params
@@ -158,65 +158,65 @@
       torch::Tensor s = y.sum();
 
       s.backward();
-      REQUIRE(y.ndimension() == 2);
-      REQUIRE(s.ndimension() == 0);
-      REQUIRE(y.size(0) == 10);
-      REQUIRE(y.size(1) == 2);
+      CATCH_REQUIRE(y.ndimension() == 2);
+      CATCH_REQUIRE(s.ndimension() == 0);
+      CATCH_REQUIRE(y.size(0) == 10);
+      CATCH_REQUIRE(y.size(1) == 2);
 
-      REQUIRE(model->parameters()["weight"].grad().numel() == 2 * dict_size);
+      CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * dict_size);
     }
 
-    SECTION("list") {
+    CATCH_SECTION("list") {
       Embedding model(6, 4);
       auto x = torch::full({2, 3}, 5, torch::kInt64);
       auto y = model->forward(x);
       torch::Tensor s = y.sum();
 
       s.backward();
-      REQUIRE(y.ndimension() == 3);
-      REQUIRE(y.size(0) == 2);
-      REQUIRE(y.size(1) == 3);
-      REQUIRE(y.size(2) == 4);
+      CATCH_REQUIRE(y.ndimension() == 3);
+      CATCH_REQUIRE(y.size(0) == 2);
+      CATCH_REQUIRE(y.size(1) == 3);
+      CATCH_REQUIRE(y.size(2) == 4);
     }
   }
 
-  SECTION("dropout") {
+  CATCH_SECTION("dropout") {
     Dropout dropout(0.5);
     torch::Tensor x = torch::ones(100, torch::requires_grad());
     torch::Tensor y = dropout->forward(x);
 
     y.backward();
-    REQUIRE(y.ndimension() == 1);
-    REQUIRE(y.size(0) == 100);
-    REQUIRE(y.sum().toCFloat() < 130); // Probably
-    REQUIRE(y.sum().toCFloat() > 70); // Probably
+    CATCH_REQUIRE(y.ndimension() == 1);
+    CATCH_REQUIRE(y.size(0) == 100);
+    CATCH_REQUIRE(y.sum().toCFloat() < 130); // Probably
+    CATCH_REQUIRE(y.sum().toCFloat() > 70); // Probably
 
     dropout->eval();
     y = dropout->forward(x);
-    REQUIRE(y.sum().toCFloat() == 100);
+    CATCH_REQUIRE(y.sum().toCFloat() == 100);
   }
 
-  SECTION("param") {
+  CATCH_SECTION("param") {
     auto model = std::make_shared<NestedModel>();
     auto parameters = model->parameters();
-    REQUIRE(parameters["param"].size(0) == 3);
-    REQUIRE(parameters["param"].size(1) == 2);
-    REQUIRE(parameters["param"].size(2) == 21);
-    REQUIRE(parameters["l1.bias"].size(0) == 20);
-    REQUIRE(parameters["l1.weight"].size(0) == 20);
-    REQUIRE(parameters["l1.weight"].size(1) == 5);
-    REQUIRE(parameters["test.l1.bias"].size(0) == 3);
-    REQUIRE(parameters["test.l1.weight"].size(0) == 3);
-    REQUIRE(parameters["test.l1.weight"].size(1) == 10);
-    REQUIRE(parameters["test.l2.bias"].size(0) == 5);
-    REQUIRE(parameters["test.l2.weight"].size(0) == 5);
-    REQUIRE(parameters["test.l2.weight"].size(1) == 3);
-    REQUIRE(parameters["test.l3.bias"].size(0) == 100);
-    REQUIRE(parameters["test.l3.weight"].size(0) == 100);
-    REQUIRE(parameters["test.l3.weight"].size(1) == 5);
+    CATCH_REQUIRE(parameters["param"].size(0) == 3);
+    CATCH_REQUIRE(parameters["param"].size(1) == 2);
+    CATCH_REQUIRE(parameters["param"].size(2) == 21);
+    CATCH_REQUIRE(parameters["l1.bias"].size(0) == 20);
+    CATCH_REQUIRE(parameters["l1.weight"].size(0) == 20);
+    CATCH_REQUIRE(parameters["l1.weight"].size(1) == 5);
+    CATCH_REQUIRE(parameters["test.l1.bias"].size(0) == 3);
+    CATCH_REQUIRE(parameters["test.l1.weight"].size(0) == 3);
+    CATCH_REQUIRE(parameters["test.l1.weight"].size(1) == 10);
+    CATCH_REQUIRE(parameters["test.l2.bias"].size(0) == 5);
+    CATCH_REQUIRE(parameters["test.l2.weight"].size(0) == 5);
+    CATCH_REQUIRE(parameters["test.l2.weight"].size(1) == 3);
+    CATCH_REQUIRE(parameters["test.l3.bias"].size(0) == 100);
+    CATCH_REQUIRE(parameters["test.l3.weight"].size(0) == 100);
+    CATCH_REQUIRE(parameters["test.l3.weight"].size(1) == 5);
   }
 
-  SECTION("functional") {
+  CATCH_SECTION("functional") {
     {
       bool was_called = false;
       auto functional = Functional([&was_called](torch::Tensor input) {
@@ -224,63 +224,63 @@
         return input;
       });
       auto output = functional->forward(torch::ones(5, torch::requires_grad()));
-      REQUIRE(was_called);
-      REQUIRE(output.equal(torch::ones(5, torch::requires_grad())));
+      CATCH_REQUIRE(was_called);
+      CATCH_REQUIRE(output.equal(torch::ones(5, torch::requires_grad())));
 
       was_called = false;
       // Use the call operator overload here.
       output = functional(torch::ones(5, torch::requires_grad()));
-      REQUIRE(was_called);
-      REQUIRE(output.equal(torch::ones(5, torch::requires_grad())));
+      CATCH_REQUIRE(was_called);
+      CATCH_REQUIRE(output.equal(torch::ones(5, torch::requires_grad())));
     }
     {
       auto functional = Functional(torch::relu);
-      REQUIRE(functional(torch::ones({})).toCFloat() == 1);
-      REQUIRE(functional(torch::ones({})).toCFloat() == 1);
-      REQUIRE(functional(torch::ones({}) * -1).toCFloat() == 0);
+      CATCH_REQUIRE(functional(torch::ones({})).toCFloat() == 1);
+      CATCH_REQUIRE(functional(torch::ones({})).toCFloat() == 1);
+      CATCH_REQUIRE(functional(torch::ones({}) * -1).toCFloat() == 0);
     }
     {
       auto functional =
           Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
-      REQUIRE(functional(torch::ones({})).toCFloat() == 0);
+      CATCH_REQUIRE(functional(torch::ones({})).toCFloat() == 0);
     }
   }
 
-  SECTION("batchnorm") {
+  CATCH_SECTION("batchnorm") {
     {
       BatchNorm bn(5);
 
       // Is stateful by default.
-      REQUIRE(bn->options.stateful());
+      CATCH_REQUIRE(bn->options.stateful());
 
-      REQUIRE(bn->running_mean.defined());
-      REQUIRE(bn->running_mean.dim() == 1);
-      REQUIRE(bn->running_mean.size(0) == 5);
+      CATCH_REQUIRE(bn->running_mean.defined());
+      CATCH_REQUIRE(bn->running_mean.dim() == 1);
+      CATCH_REQUIRE(bn->running_mean.size(0) == 5);
 
-      REQUIRE(bn->running_variance.defined());
-      REQUIRE(bn->running_variance.dim() == 1);
-      REQUIRE(bn->running_variance.size(0) == 5);
+      CATCH_REQUIRE(bn->running_variance.defined());
+      CATCH_REQUIRE(bn->running_variance.dim() == 1);
+      CATCH_REQUIRE(bn->running_variance.size(0) == 5);
 
       // Is affine by default.
-      REQUIRE(bn->options.affine());
+      CATCH_REQUIRE(bn->options.affine());
 
-      REQUIRE(bn->weight.defined());
-      REQUIRE(bn->weight.dim() == 1);
-      REQUIRE(bn->weight.size(0) == 5);
+      CATCH_REQUIRE(bn->weight.defined());
+      CATCH_REQUIRE(bn->weight.dim() == 1);
+      CATCH_REQUIRE(bn->weight.size(0) == 5);
 
-      REQUIRE(bn->bias.defined());
-      REQUIRE(bn->bias.dim() == 1);
-      REQUIRE(bn->bias.size(0) == 5);
+      CATCH_REQUIRE(bn->bias.defined());
+      CATCH_REQUIRE(bn->bias.dim() == 1);
+      CATCH_REQUIRE(bn->bias.size(0) == 5);
     }
     {
       BatchNorm bn(BatchNormOptions(5).stateful(false).affine(false));
 
-      REQUIRE(!bn->running_mean.defined());
-      REQUIRE(!bn->running_variance.defined());
-      REQUIRE(!bn->weight.defined());
-      REQUIRE(!bn->bias.defined());
+      CATCH_REQUIRE(!bn->running_mean.defined());
+      CATCH_REQUIRE(!bn->running_variance.defined());
+      CATCH_REQUIRE(!bn->weight.defined());
+      CATCH_REQUIRE(!bn->bias.defined());
 
-      REQUIRE_THROWS_WITH(
+      CATCH_REQUIRE_THROWS_WITH(
           bn->forward(torch::ones({2, 5})),
           StartsWith("Calling BatchNorm::forward is only permitted "
                      "when the 'stateful' option is true (was false). "
@@ -298,14 +298,14 @@
       auto output = bn->pure_forward(input, mean, variance);
       auto expected =
           (input - mean) / torch::sqrt(variance + bn->options.eps());
-      REQUIRE(output.allclose(expected));
+      CATCH_REQUIRE(output.allclose(expected));
     }
   }
 }
 
-TEST_CASE("modules_cuda", "[cuda]") {
+CATCH_TEST_CASE("modules_cuda", "[cuda]") {
   torch::manual_seed(0);
-  SECTION("1") {
+  CATCH_SECTION("1") {
     Linear model(5, 2);
     model->to(torch::kCUDA);
     auto x =
@@ -314,15 +314,15 @@
     torch::Tensor s = y.sum();
 
     s.backward();
-    REQUIRE(y.ndimension() == 2);
-    REQUIRE(s.ndimension() == 0);
-    REQUIRE(y.size(0) == 10);
-    REQUIRE(y.size(1) == 2);
+    CATCH_REQUIRE(y.ndimension() == 2);
+    CATCH_REQUIRE(s.ndimension() == 0);
+    CATCH_REQUIRE(y.size(0) == 10);
+    CATCH_REQUIRE(y.size(1) == 2);
 
-    REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
+    CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
   }
 
-  SECTION("2") {
+  CATCH_SECTION("2") {
     Linear model(5, 2);
     model->to(torch::kCUDA);
     model->to(torch::kCPU);
@@ -331,11 +331,11 @@
     torch::Tensor s = y.sum();
 
     s.backward();
-    REQUIRE(y.ndimension() == 2);
-    REQUIRE(s.ndimension() == 0);
-    REQUIRE(y.size(0) == 10);
-    REQUIRE(y.size(1) == 2);
+    CATCH_REQUIRE(y.ndimension() == 2);
+    CATCH_REQUIRE(s.ndimension() == 0);
+    CATCH_REQUIRE(y.size(0) == 10);
+    CATCH_REQUIRE(y.size(1) == 2);
 
-    REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
+    CATCH_REQUIRE(model->parameters()["weight"].grad().numel() == 2 * 5);
   }
 }
diff --git a/test/cpp/api/optim.cpp b/test/cpp/api/optim.cpp
index ab27818..4cb398d 100644
--- a/test/cpp/api/optim.cpp
+++ b/test/cpp/api/optim.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/nn/module.h>
 #include <torch/nn/modules/functional.h>
@@ -118,24 +118,24 @@
     optimizer.step();
 
     if (i % kSampleEvery == 0) {
-      REQUIRE(
+      CATCH_REQUIRE(
           expected_parameters.at(i / kSampleEvery).size() == parameters.size());
       for (size_t p = 0; p < parameters.size(); ++p) {
-        REQUIRE(parameters.at(p)->defined());
+        CATCH_REQUIRE(parameters.at(p)->defined());
         auto computed = parameters.at(p)->flatten();
         auto expected = expected_parameters.at(i / kSampleEvery).at(p);
         if (!computed.allclose(expected, /*rtol=*/1e-3, /*atol=*/5e-4)) {
           std::cout << "Iteration " << i << ": " << computed
                     << " != " << expected << " (parameter " << p << ")"
                     << std::endl;
-          REQUIRE(false);
+          CATCH_REQUIRE(false);
         }
       }
     }
   }
 }
 
-TEST_CASE("Optim/BasicInterface") {
+CATCH_TEST_CASE("Optim/BasicInterface") {
   struct MyOptimizer : Optimizer {
     using Optimizer::Optimizer;
     void step() override {}
@@ -144,139 +144,139 @@
       torch::ones({2, 3}), torch::zeros({2, 3}), torch::rand({2, 3})};
   {
     MyOptimizer optimizer(parameters);
-    REQUIRE(optimizer.size() == parameters.size());
+    CATCH_REQUIRE(optimizer.size() == parameters.size());
   }
   {
     MyOptimizer optimizer;
-    REQUIRE(optimizer.size() == 0);
+    CATCH_REQUIRE(optimizer.size() == 0);
     optimizer.add_parameters(parameters);
-    REQUIRE(optimizer.size() == parameters.size());
+    CATCH_REQUIRE(optimizer.size() == parameters.size());
     for (size_t p = 0; p < parameters.size(); ++p) {
-      REQUIRE(optimizer.parameters()[p].allclose(parameters[p]));
+      CATCH_REQUIRE(optimizer.parameters()[p].allclose(parameters[p]));
     }
   }
   {
     Linear linear(3, 4);
     MyOptimizer optimizer(linear->parameters());
-    REQUIRE(optimizer.size() == linear->parameters().size());
+    CATCH_REQUIRE(optimizer.size() == linear->parameters().size());
   }
 }
 
-TEST_CASE("Optim/XORConvergence/SGD") {
-  REQUIRE(test_optimizer_xor<SGD>(
+CATCH_TEST_CASE("Optim/XORConvergence/SGD") {
+  CATCH_REQUIRE(test_optimizer_xor<SGD>(
       SGDOptions(0.1).momentum(0.9).nesterov(true).weight_decay(1e-6)));
 }
 
-TEST_CASE("Optim/XORConvergence/Adagrad") {
-  REQUIRE(test_optimizer_xor<Adagrad>(
+CATCH_TEST_CASE("Optim/XORConvergence/Adagrad") {
+  CATCH_REQUIRE(test_optimizer_xor<Adagrad>(
       AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3)));
 }
 
-TEST_CASE("Optim/XORConvergence/RMSprop") {
-  REQUIRE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
+CATCH_TEST_CASE("Optim/XORConvergence/RMSprop") {
+  CATCH_REQUIRE(test_optimizer_xor<RMSprop>(RMSpropOptions(0.1).centered(true)));
 }
 
-TEST_CASE("Optim/XORConvergence/RMSpropWithMomentum") {
-  REQUIRE(test_optimizer_xor<RMSprop>(
+CATCH_TEST_CASE("Optim/XORConvergence/RMSpropWithMomentum") {
+  CATCH_REQUIRE(test_optimizer_xor<RMSprop>(
       RMSpropOptions(0.1).momentum(0.9).weight_decay(1e-6)));
 }
 
-TEST_CASE("Optim/XORConvergence/Adam") {
-  REQUIRE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
+CATCH_TEST_CASE("Optim/XORConvergence/Adam") {
+  CATCH_REQUIRE(test_optimizer_xor<Adam>(AdamOptions(0.1).weight_decay(1e-6)));
 }
 
-TEST_CASE("Optim/XORConvergence/AdamWithAmsgrad") {
-  REQUIRE(test_optimizer_xor<Adam>(
+CATCH_TEST_CASE("Optim/XORConvergence/AdamWithAmsgrad") {
+  CATCH_REQUIRE(test_optimizer_xor<Adam>(
       AdamOptions(0.1).weight_decay(1e-6).amsgrad(true)));
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/Adam") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/Adam") {
   check_exact_values<Adam>(AdamOptions(1.0), expected_parameters::Adam);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/AdamWithWeightDecay") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdamWithWeightDecay") {
   check_exact_values<Adam>(
       AdamOptions(1.0).weight_decay(1e-2),
       expected_parameters::Adam_with_weight_decay);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/AdamWithWeightDecayAndAMSGrad") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdamWithWeightDecayAndAMSGrad") {
   check_exact_values<Adam>(
       AdamOptions(1.0).weight_decay(1e-6).amsgrad(true),
       expected_parameters::Adam_with_weight_decay_and_amsgrad);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/Adagrad") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/Adagrad") {
   check_exact_values<Adagrad>(
       AdagradOptions(1.0), expected_parameters::Adagrad);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/AdagradWithWeightDecay") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdagradWithWeightDecay") {
   check_exact_values<Adagrad>(
       AdagradOptions(1.0).weight_decay(1e-2),
       expected_parameters::Adagrad_with_weight_decay);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/AdagradWithWeightDecayAndLRDecay") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/AdagradWithWeightDecayAndLRDecay") {
   check_exact_values<Adagrad>(
       AdagradOptions(1.0).weight_decay(1e-6).lr_decay(1e-3),
       expected_parameters::Adagrad_with_weight_decay_and_lr_decay);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/RMSprop") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/RMSprop") {
   check_exact_values<RMSprop>(
       RMSpropOptions(0.1), expected_parameters::RMSprop);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/RMSpropWithWeightDecay") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/RMSpropWithWeightDecay") {
   check_exact_values<RMSprop>(
       RMSpropOptions(0.1).weight_decay(1e-2),
       expected_parameters::RMSprop_with_weight_decay);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/RMSpropWithWeightDecayAndCentered") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/RMSpropWithWeightDecayAndCentered") {
   check_exact_values<RMSprop>(
       RMSpropOptions(0.1).weight_decay(1e-6).centered(true),
       expected_parameters::RMSprop_with_weight_decay_and_centered);
 }
 
-TEST_CASE(
+CATCH_TEST_CASE(
     "Optim/ProducesPyTorchValues/RMSpropWithWeightDecayAndCenteredAndMomentum") {
   check_exact_values<RMSprop>(
       RMSpropOptions(0.1).weight_decay(1e-6).centered(true).momentum(0.9),
       expected_parameters::RMSprop_with_weight_decay_and_centered_and_momentum);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/SGD") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGD") {
   check_exact_values<SGD>(SGDOptions(0.1), expected_parameters::SGD);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecay") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecay") {
   check_exact_values<SGD>(
       SGDOptions(0.1).weight_decay(1e-2),
       expected_parameters::SGD_with_weight_decay);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecayAndMomentum") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecayAndMomentum") {
   check_exact_values<SGD>(
       SGDOptions(0.1).weight_decay(1e-2).momentum(0.9),
       expected_parameters::SGD_with_weight_decay_and_momentum);
 }
 
-TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecayAndNesterovMomentum") {
+CATCH_TEST_CASE("Optim/ProducesPyTorchValues/SGDWithWeightDecayAndNesterovMomentum") {
   check_exact_values<SGD>(
       SGDOptions(0.1).weight_decay(1e-6).momentum(0.9).nesterov(true),
       expected_parameters::SGD_with_weight_decay_and_nesterov_momentum);
 }
 
-TEST_CASE("Optim/ZeroGrad") {
+CATCH_TEST_CASE("Optim/ZeroGrad") {
   torch::manual_seed(0);
 
   Linear model(2, 8);
   SGD optimizer(model->parameters(), 0.1);
 
   for (const auto& parameter : model->parameters()) {
-    REQUIRE(!parameter->grad().defined());
+    CATCH_REQUIRE(!parameter->grad().defined());
   }
 
   auto output = model->forward(torch::ones({5, 2}));
@@ -284,19 +284,19 @@
   loss.backward();
 
   for (const auto& parameter : model->parameters()) {
-    REQUIRE(parameter->grad().defined());
-    REQUIRE(parameter->grad().sum().toCFloat() > 0);
+    CATCH_REQUIRE(parameter->grad().defined());
+    CATCH_REQUIRE(parameter->grad().sum().toCFloat() > 0);
   }
 
   optimizer.zero_grad();
 
   for (const auto& parameter : model->parameters()) {
-    REQUIRE(parameter->grad().defined());
-    REQUIRE(parameter->grad().sum().toCFloat() == 0);
+    CATCH_REQUIRE(parameter->grad().defined());
+    CATCH_REQUIRE(parameter->grad().sum().toCFloat() == 0);
   }
 }
 
-TEST_CASE("Optim/ExternalVectorOfParameters") {
+CATCH_TEST_CASE("Optim/ExternalVectorOfParameters") {
   torch::manual_seed(0);
 
   std::vector<torch::Tensor> parameters = {
@@ -313,12 +313,12 @@
 
   optimizer.step();
 
-  REQUIRE(parameters[0].allclose(original_parameters[0] - 1.0));
-  REQUIRE(parameters[1].allclose(original_parameters[1] - 1.0));
-  REQUIRE(parameters[2].allclose(original_parameters[2] - 1.0));
+  CATCH_REQUIRE(parameters[0].allclose(original_parameters[0] - 1.0));
+  CATCH_REQUIRE(parameters[1].allclose(original_parameters[1] - 1.0));
+  CATCH_REQUIRE(parameters[2].allclose(original_parameters[2] - 1.0));
 }
 
-TEST_CASE("Optim/AddParameter/LBFGS") {
+CATCH_TEST_CASE("Optim/AddParameter/LBFGS") {
   torch::manual_seed(0);
 
   std::vector<torch::Tensor> parameters = {torch::randn({5, 5})};
diff --git a/test/cpp/api/parallel.cpp b/test/cpp/api/parallel.cpp
index a151758..33e3a16 100644
--- a/test/cpp/api/parallel.cpp
+++ b/test/cpp/api/parallel.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/csrc/autograd/functions/comm.h>
 #include <torch/nn/module.h>
@@ -19,92 +19,92 @@
 
 #ifdef USE_CUDA
 
-TEST_CASE("Parallel/DifferentiableScatter", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/DifferentiableScatter", "[multi-cuda]") {
   Scatter scatter(
       {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
 
   auto input = torch::ones(10, torch::requires_grad(true));
   auto output = scatter.apply({input});
 
-  REQUIRE(output.size() == 2);
-  REQUIRE(output[0].size(0) == 5);
-  REQUIRE(output[1].size(0) == 5);
+  CATCH_REQUIRE(output.size() == 2);
+  CATCH_REQUIRE(output[0].size(0) == 5);
+  CATCH_REQUIRE(output[1].size(0) == 5);
 
-  REQUIRE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)})
+  CATCH_REQUIRE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)})
               .allclose(input));
 
   auto sum = output[0].to({torch::kCUDA, 1}) + output[1];
   sum.backward();
 
-  REQUIRE(input.grad().defined());
-  REQUIRE(input.grad().device().is_cpu());
-  REQUIRE(input.grad().sum().toCInt() == 10);
+  CATCH_REQUIRE(input.grad().defined());
+  CATCH_REQUIRE(input.grad().device().is_cpu());
+  CATCH_REQUIRE(input.grad().sum().toCInt() == 10);
 }
 
-TEST_CASE("Parallel/DifferentiableGather", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/DifferentiableGather", "[multi-cuda]") {
   Gather gather(torch::Device(torch::kCUDA, 1));
 
   auto a = torch::ones(5, torch::requires_grad(true).device({torch::kCUDA, 0}));
   auto b = torch::ones(5, torch::requires_grad(true).device({torch::kCUDA, 1}));
 
   auto outputs = gather.apply({a, b});
-  REQUIRE(outputs.size() == 1);
+  CATCH_REQUIRE(outputs.size() == 1);
   auto& output = outputs.front();
 
-  REQUIRE(output.size(0) == 10);
-  REQUIRE(output.device() == torch::Device(torch::kCUDA, 1));
+  CATCH_REQUIRE(output.size(0) == 10);
+  CATCH_REQUIRE(output.device() == torch::Device(torch::kCUDA, 1));
 
   auto chunks = output.chunk(2);
-  REQUIRE(chunks[0].to({torch::kCUDA, 0}).allclose(a));
-  REQUIRE(chunks[1].allclose(b));
+  CATCH_REQUIRE(chunks[0].to({torch::kCUDA, 0}).allclose(a));
+  CATCH_REQUIRE(chunks[1].allclose(b));
 
   output.backward();
 
-  REQUIRE(a.grad().defined());
-  REQUIRE(a.grad().device() == torch::Device(torch::kCUDA, 0));
-  REQUIRE(a.grad().sum().toCInt() == 5);
+  CATCH_REQUIRE(a.grad().defined());
+  CATCH_REQUIRE(a.grad().device() == torch::Device(torch::kCUDA, 0));
+  CATCH_REQUIRE(a.grad().sum().toCInt() == 5);
 
-  REQUIRE(b.grad().defined());
-  REQUIRE(b.grad().device() == torch::Device(torch::kCUDA, 1));
-  REQUIRE(b.grad().sum().toCInt() == 5);
+  CATCH_REQUIRE(b.grad().defined());
+  CATCH_REQUIRE(b.grad().device() == torch::Device(torch::kCUDA, 1));
+  CATCH_REQUIRE(b.grad().sum().toCInt() == 5);
 }
 
-TEST_CASE("Parallel/Replicate", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/Replicate", "[multi-cuda]") {
   Linear linear(3, 4);
   auto replicas = parallel::replicate(
       linear, {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
-  REQUIRE(replicas.size() == 2);
+  CATCH_REQUIRE(replicas.size() == 2);
 
   auto original_parameters = linear->parameters();
 
   auto replica1_parameters = replicas[0]->parameters();
   for (auto& parameter : replica1_parameters) {
-    REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 0));
+    CATCH_REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 0));
   }
   replicas[0]->to(torch::kCPU);
-  REQUIRE(replica1_parameters.size() == original_parameters.size());
+  CATCH_REQUIRE(replica1_parameters.size() == original_parameters.size());
   for (size_t i = 0; i < original_parameters.size(); ++i) {
-    REQUIRE(replica1_parameters[i]->allclose(*original_parameters[i]));
-    REQUIRE(
+    CATCH_REQUIRE(replica1_parameters[i]->allclose(*original_parameters[i]));
+    CATCH_REQUIRE(
         replica1_parameters[i].data<float>() !=
         original_parameters[i].data<float>());
   }
 
   auto replica2_parameters = replicas[1]->parameters();
   for (auto& parameter : replica2_parameters) {
-    REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 1));
+    CATCH_REQUIRE(parameter->device() == torch::Device(torch::kCUDA, 1));
   }
   replicas[1]->to(torch::kCPU);
-  REQUIRE(replica2_parameters.size() == original_parameters.size());
+  CATCH_REQUIRE(replica2_parameters.size() == original_parameters.size());
   for (size_t i = 0; i < original_parameters.size(); ++i) {
-    REQUIRE(replica2_parameters[i]->allclose(*original_parameters[i]));
-    REQUIRE(
+    CATCH_REQUIRE(replica2_parameters[i]->allclose(*original_parameters[i]));
+    CATCH_REQUIRE(
         replica2_parameters[i].data<float>() !=
         original_parameters[i].data<float>());
   }
 }
 
-TEST_CASE("Parallel/ParallelApply", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/ParallelApply", "[multi-cuda]") {
   Linear a(3, 4);
 
   Linear b(std::static_pointer_cast<LinearImpl>(a->clone()));
@@ -121,17 +121,17 @@
 
   auto outputs = parallel::parallel_apply(modules, inputs);
 
-  REQUIRE(outputs.size() == 3);
-  REQUIRE(outputs[0].device().is_cpu());
+  CATCH_REQUIRE(outputs.size() == 3);
+  CATCH_REQUIRE(outputs[0].device().is_cpu());
 
-  REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0));
-  REQUIRE(outputs[1].to(torch::kCPU).allclose(outputs[0]));
+  CATCH_REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0));
+  CATCH_REQUIRE(outputs[1].to(torch::kCPU).allclose(outputs[0]));
 
-  REQUIRE(outputs[2].device() == torch::Device(torch::kCUDA, 1));
-  REQUIRE(outputs[2].to(torch::kCPU).allclose(outputs[0]));
+  CATCH_REQUIRE(outputs[2].device() == torch::Device(torch::kCUDA, 1));
+  CATCH_REQUIRE(outputs[2].to(torch::kCPU).allclose(outputs[0]));
 }
 
-TEST_CASE("Parallel/ParallelApplyWithDifferentOutputDevice", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/ParallelApplyWithDifferentOutputDevice", "[multi-cuda]") {
   struct M : torch::nn::Module {
     torch::Tensor forward(torch::Tensor input) {
       return torch::ones({5}, torch::dtype(torch::kInt32));
@@ -147,17 +147,17 @@
 
   auto outputs = parallel::parallel_apply(modules, inputs, devices);
 
-  REQUIRE(outputs.size() == 3);
-  REQUIRE(outputs[0].device().is_cuda());
-  REQUIRE(outputs[0].device() == torch::Device(torch::kCUDA, 1));
+  CATCH_REQUIRE(outputs.size() == 3);
+  CATCH_REQUIRE(outputs[0].device().is_cuda());
+  CATCH_REQUIRE(outputs[0].device() == torch::Device(torch::kCUDA, 1));
 
-  REQUIRE(outputs[1].device().is_cuda());
-  REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0));
+  CATCH_REQUIRE(outputs[1].device().is_cuda());
+  CATCH_REQUIRE(outputs[1].device() == torch::Device(torch::kCUDA, 0));
 
-  REQUIRE(outputs[2].device().is_cpu());
+  CATCH_REQUIRE(outputs[2].device().is_cpu());
 }
 
-TEST_CASE("Parallel/ParallelApplyRethrowsException", "[multi-cuda]") {
+CATCH_TEST_CASE("Parallel/ParallelApplyRethrowsException", "[multi-cuda]") {
   struct M : torch::nn::Cloneable<M> {
     void reset() override {}
     torch::Tensor forward(torch::Tensor input) {
@@ -167,11 +167,11 @@
 
   auto m = std::make_shared<M>();
   auto input = torch::ones({10, 3});
-  REQUIRE_THROWS_WITH(
+  CATCH_REQUIRE_THROWS_WITH(
       parallel::data_parallel(m, input), StartsWith("Badness!"));
 }
 
-TEST_CASE(
+CATCH_TEST_CASE(
     "Parallel/DataParallelPlacesTheOutputOnTheRequestedDevice",
     "[multi-cuda]") {
   struct M : torch::nn::Cloneable<M> {
@@ -192,9 +192,9 @@
         input,
         /*devices=*/at::nullopt,
         /*output_device=*/torch::Device(torch::kCUDA, 1));
-    REQUIRE(output.defined());
-    REQUIRE(output.device().is_cuda());
-    REQUIRE(output.device().index() == 1);
+    CATCH_REQUIRE(output.defined());
+    CATCH_REQUIRE(output.device().is_cuda());
+    CATCH_REQUIRE(output.device().index() == 1);
   }
   {
     // Verify for the single-device case (where we don't scatter/gather).
@@ -203,16 +203,16 @@
         input,
         /*devices=*/std::vector<torch::Device>{torch::Device(torch::kCUDA, 0)},
         /*output_device=*/torch::Device(torch::kCUDA, 1));
-    REQUIRE(m->intermediate_tensor.defined());
-    REQUIRE(m->intermediate_tensor.device().is_cuda());
-    REQUIRE(m->intermediate_tensor.device().index() == 0);
-    REQUIRE(output.defined());
-    REQUIRE(output.device().is_cuda());
-    REQUIRE(output.device().index() == 1);
+    CATCH_REQUIRE(m->intermediate_tensor.defined());
+    CATCH_REQUIRE(m->intermediate_tensor.device().is_cuda());
+    CATCH_REQUIRE(m->intermediate_tensor.device().index() == 0);
+    CATCH_REQUIRE(output.defined());
+    CATCH_REQUIRE(output.device().is_cuda());
+    CATCH_REQUIRE(output.device().index() == 1);
   }
 }
 
-TEST_CASE("Parallel/DataParallelUsesAllAvailableCUDADevices", "[cuda]") {
+CATCH_TEST_CASE("Parallel/DataParallelUsesAllAvailableCUDADevices", "[cuda]") {
   struct M : torch::nn::Cloneable<M> {
     void reset() override {}
     torch::Tensor forward(torch::Tensor input) {
@@ -225,9 +225,9 @@
   auto output = parallel::data_parallel(m, input);
 
   const auto device_count = torch::cuda::device_count();
-  REQUIRE(output.numel() == device_count);
+  CATCH_REQUIRE(output.numel() == device_count);
   for (size_t i = 0; i < device_count; ++i) {
-    REQUIRE(output[i].toCInt() == i);
+    CATCH_REQUIRE(output[i].toCInt() == i);
   }
 }
 
diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp
index 9668572..a307851 100644
--- a/test/cpp/api/rnn.cpp
+++ b/test/cpp/api/rnn.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/nn/modules/linear.h>
 #include <torch/nn/modules/rnn.h>
@@ -71,22 +71,22 @@
   // Expect the LSTM to have 64 outputs and 3 layers, with an input of batch
   // 10 and 16 time steps (10 x 16 x n)
 
-  REQUIRE(output.output.ndimension() == 3);
-  REQUIRE(output.output.size(0) == 10);
-  REQUIRE(output.output.size(1) == 16);
-  REQUIRE(output.output.size(2) == 64);
+  CATCH_REQUIRE(output.output.ndimension() == 3);
+  CATCH_REQUIRE(output.output.size(0) == 10);
+  CATCH_REQUIRE(output.output.size(1) == 16);
+  CATCH_REQUIRE(output.output.size(2) == 64);
 
-  REQUIRE(output.state.ndimension() == 4);
-  REQUIRE(output.state.size(0) == 2); // (hx, cx)
-  REQUIRE(output.state.size(1) == 3); // layers
-  REQUIRE(output.state.size(2) == 16); // Batchsize
-  REQUIRE(output.state.size(3) == 64); // 64 hidden dims
+  CATCH_REQUIRE(output.state.ndimension() == 4);
+  CATCH_REQUIRE(output.state.size(0) == 2); // (hx, cx)
+  CATCH_REQUIRE(output.state.size(1) == 3); // layers
+  CATCH_REQUIRE(output.state.size(2) == 16); // Batchsize
+  CATCH_REQUIRE(output.state.size(3) == 64); // 64 hidden dims
 
   // Something is in the hiddens
-  REQUIRE(output.state.norm().toCFloat() > 0);
+  CATCH_REQUIRE(output.state.norm().toCFloat() > 0);
 }
 
-TEST_CASE("RNN/CheckOutputSizes") {
+CATCH_TEST_CASE("RNN/CheckOutputSizes") {
   torch::manual_seed(0);
   LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
   // Input size is: sequence length, batch size, input size
@@ -104,10 +104,10 @@
   torch::Tensor diff = next.state - output.state;
 
   // Hiddens changed
-  REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
+  CATCH_REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
 }
 
-TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
+CATCH_TEST_CASE("RNN/CheckOutputValuesMatchPyTorch") {
   torch::manual_seed(0);
   // Make sure the outputs match pytorch outputs
   LSTM model(2, 2);
@@ -127,10 +127,10 @@
   }
 
   auto out = model->forward(x);
-  REQUIRE(out.output.ndimension() == 3);
-  REQUIRE(out.output.size(0) == 3);
-  REQUIRE(out.output.size(1) == 4);
-  REQUIRE(out.output.size(2) == 2);
+  CATCH_REQUIRE(out.output.ndimension() == 3);
+  CATCH_REQUIRE(out.output.size(0) == 3);
+  CATCH_REQUIRE(out.output.size(1) == 4);
+  CATCH_REQUIRE(out.output.size(2) == 2);
 
   auto flat = out.output.view(3 * 4 * 2);
   float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
@@ -138,14 +138,14 @@
                    0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
                    0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
   for (size_t i = 0; i < 3 * 4 * 2; i++) {
-    REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
+    CATCH_REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
   }
 
-  REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
-  REQUIRE(out.state.size(0) == 2);
-  REQUIRE(out.state.size(1) == 1);
-  REQUIRE(out.state.size(2) == 4);
-  REQUIRE(out.state.size(3) == 2);
+  CATCH_REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
+  CATCH_REQUIRE(out.state.size(0) == 2);
+  CATCH_REQUIRE(out.state.size(1) == 1);
+  CATCH_REQUIRE(out.state.size(2) == 4);
+  CATCH_REQUIRE(out.state.size(3) == 2);
   flat = out.state.view(16);
   float h_out[] = {0.7889,
                    0.9003,
@@ -164,33 +164,33 @@
                    1.0931,
                    1.4911};
   for (size_t i = 0; i < 16; i++) {
-    REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
+    CATCH_REQUIRE(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
   }
 }
 
-TEST_CASE("RNN/integration/LSTM") {
-  REQUIRE(test_RNN_xor<LSTM>(
+CATCH_TEST_CASE("RNN/integration/LSTM") {
+  CATCH_REQUIRE(test_RNN_xor<LSTM>(
       [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }));
 }
 
-TEST_CASE("RNN/integration/GRU") {
-  REQUIRE(
+CATCH_TEST_CASE("RNN/integration/GRU") {
+  CATCH_REQUIRE(
       test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).layers(2)); }));
 }
 
-TEST_CASE("RNN/integration/RNN") {
-  SECTION("relu") {
-    REQUIRE(test_RNN_xor<RNN>(
+CATCH_TEST_CASE("RNN/integration/RNN") {
+  CATCH_SECTION("relu") {
+    CATCH_REQUIRE(test_RNN_xor<RNN>(
         [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }));
   }
-  SECTION("tanh") {
-    REQUIRE(test_RNN_xor<RNN>(
+  CATCH_SECTION("tanh") {
+    CATCH_REQUIRE(test_RNN_xor<RNN>(
         [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }));
   }
 }
 
-TEST_CASE("rnn_cuda", "[cuda]") {
-  SECTION("sizes") {
+CATCH_TEST_CASE("rnn_cuda", "[cuda]") {
+  CATCH_SECTION("sizes") {
     torch::manual_seed(0);
     LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
     model->to(torch::kCUDA);
@@ -209,26 +209,26 @@
     torch::Tensor diff = next.state - output.state;
 
     // Hiddens changed
-    REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
+    CATCH_REQUIRE(diff.abs().sum().toCFloat() > 1e-3);
   }
 
-  SECTION("lstm") {
-    REQUIRE(test_RNN_xor<LSTM>(
+  CATCH_SECTION("lstm") {
+    CATCH_REQUIRE(test_RNN_xor<LSTM>(
         [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }, true));
   }
 
-  SECTION("gru") {
-    REQUIRE(test_RNN_xor<GRU>(
+  CATCH_SECTION("gru") {
+    CATCH_REQUIRE(test_RNN_xor<GRU>(
         [](int s) { return GRU(GRUOptions(s, s).layers(2)); }, true));
   }
 
-  SECTION("rnn") {
-    SECTION("relu") {
-      REQUIRE(test_RNN_xor<RNN>(
+  CATCH_SECTION("rnn") {
+    CATCH_SECTION("relu") {
+      CATCH_REQUIRE(test_RNN_xor<RNN>(
           [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }, true));
     }
-    SECTION("tanh") {
-      REQUIRE(test_RNN_xor<RNN>(
+    CATCH_SECTION("tanh") {
+      CATCH_REQUIRE(test_RNN_xor<RNN>(
           [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
     }
   }
diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp
index aef1332..777d6e2 100644
--- a/test/cpp/api/sequential.cpp
+++ b/test/cpp/api/sequential.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/nn/modules.h>
 #include <torch/nn/modules/batchnorm.h>
@@ -21,7 +21,7 @@
 
 using Catch::StartsWith;
 
-TEST_CASE("Sequential/ConstructsFromSharedPointer") {
+CATCH_TEST_CASE("Sequential/ConstructsFromSharedPointer") {
   struct M : torch::nn::Module {
     explicit M(int value_) : value(value_) {}
     int value;
@@ -31,10 +31,10 @@
   };
   Sequential sequential(
       std::make_shared<M>(1), std::make_shared<M>(2), std::make_shared<M>(3));
-  REQUIRE(sequential->size() == 3);
+  CATCH_REQUIRE(sequential->size() == 3);
 }
 
-TEST_CASE("Sequential/ConstructsFromConcreteType") {
+CATCH_TEST_CASE("Sequential/ConstructsFromConcreteType") {
   struct M : torch::nn::Module {
     explicit M(int value_) : value(value_) {}
     int value;
@@ -44,9 +44,9 @@
   };
 
   Sequential sequential(M(1), M(2), M(3));
-  REQUIRE(sequential->size() == 3);
+  CATCH_REQUIRE(sequential->size() == 3);
 }
-TEST_CASE("Sequential/ConstructsFromModuleHolder") {
+CATCH_TEST_CASE("Sequential/ConstructsFromModuleHolder") {
   struct MImpl : torch::nn::Module {
     explicit MImpl(int value_) : value(value_) {}
     int forward() {
@@ -61,10 +61,10 @@
   };
 
   Sequential sequential(M(1), M(2), M(3));
-  REQUIRE(sequential->size() == 3);
+  CATCH_REQUIRE(sequential->size() == 3);
 }
 
-TEST_CASE("Sequential/PushBackAddsAnElement") {
+CATCH_TEST_CASE("Sequential/PushBackAddsAnElement") {
   struct M : torch::nn::Module {
     explicit M(int value_) : value(value_) {}
     int forward() {
@@ -73,17 +73,17 @@
     int value;
   };
   Sequential sequential;
-  REQUIRE(sequential->size() == 0);
-  REQUIRE(sequential->is_empty());
+  CATCH_REQUIRE(sequential->size() == 0);
+  CATCH_REQUIRE(sequential->is_empty());
   sequential->push_back(Linear(3, 4));
-  REQUIRE(sequential->size() == 1);
+  CATCH_REQUIRE(sequential->size() == 1);
   sequential->push_back(std::make_shared<M>(1));
-  REQUIRE(sequential->size() == 2);
+  CATCH_REQUIRE(sequential->size() == 2);
   sequential->push_back(M(2));
-  REQUIRE(sequential->size() == 3);
+  CATCH_REQUIRE(sequential->size() == 3);
 }
 
-TEST_CASE("Sequential/AccessWithAt") {
+CATCH_TEST_CASE("Sequential/AccessWithAt") {
   struct M : torch::nn::Module {
     explicit M(int value_) : value(value_) {}
     int forward() {
@@ -98,22 +98,22 @@
   for (auto& module : modules) {
     sequential->push_back(module);
   }
-  REQUIRE(sequential->size() == 3);
+  CATCH_REQUIRE(sequential->size() == 3);
 
   // returns the correct module for a given index
   for (size_t i = 0; i < modules.size(); ++i) {
-    REQUIRE(&sequential->at<M>(i) == modules[i].get());
+    CATCH_REQUIRE(&sequential->at<M>(i) == modules[i].get());
   }
 
   // throws for a bad index
-  REQUIRE_THROWS_WITH(
+  CATCH_REQUIRE_THROWS_WITH(
       sequential->at<M>(modules.size() + 1), StartsWith("Index out of range"));
-  REQUIRE_THROWS_WITH(
+  CATCH_REQUIRE_THROWS_WITH(
       sequential->at<M>(modules.size() + 1000000),
       StartsWith("Index out of range"));
 }
 
-TEST_CASE("Sequential/AccessWithPtr") {
+CATCH_TEST_CASE("Sequential/AccessWithPtr") {
   struct M : torch::nn::Module {
     explicit M(int value_) : value(value_) {}
     int forward() {
@@ -128,46 +128,46 @@
   for (auto& module : modules) {
     sequential->push_back(module);
   }
-  REQUIRE(sequential->size() == 3);
+  CATCH_REQUIRE(sequential->size() == 3);
 
   // returns the correct module for a given index
   for (size_t i = 0; i < modules.size(); ++i) {
-    REQUIRE(sequential->ptr(i).get() == modules[i].get());
-    REQUIRE(sequential[i].get() == modules[i].get());
-    REQUIRE(sequential->ptr<M>(i).get() == modules[i].get());
+    CATCH_REQUIRE(sequential->ptr(i).get() == modules[i].get());
+    CATCH_REQUIRE(sequential[i].get() == modules[i].get());
+    CATCH_REQUIRE(sequential->ptr<M>(i).get() == modules[i].get());
   }
 
   // throws for a bad index
-  REQUIRE_THROWS_WITH(
+  CATCH_REQUIRE_THROWS_WITH(
       sequential->ptr(modules.size() + 1), StartsWith("Index out of range"));
-  REQUIRE_THROWS_WITH(
+  CATCH_REQUIRE_THROWS_WITH(
       sequential->ptr(modules.size() + 1000000),
       StartsWith("Index out of range"));
 }
 
-TEST_CASE("Sequential/CallingForwardOnEmptySequentialIsDisallowed") {
+CATCH_TEST_CASE("Sequential/CallingForwardOnEmptySequentialIsDisallowed") {
   Sequential empty;
-  REQUIRE_THROWS_WITH(
+  CATCH_REQUIRE_THROWS_WITH(
       empty->forward<int>(),
       StartsWith("Cannot call forward() on an empty Sequential"));
 }
 
-TEST_CASE("Sequential/CallingForwardChainsCorrectly") {
+CATCH_TEST_CASE("Sequential/CallingForwardChainsCorrectly") {
   struct MockModule : torch::nn::Module {
     explicit MockModule(int value) : expected(value) {}
     int expected;
     int forward(int value) {
-      REQUIRE(value == expected);
+      CATCH_REQUIRE(value == expected);
       return value + 1;
     }
   };
 
   Sequential sequential(MockModule{1}, MockModule{2}, MockModule{3});
 
-  REQUIRE(sequential->forward<int>(1) == 4);
+  CATCH_REQUIRE(sequential->forward<int>(1) == 4);
 }
 
-TEST_CASE("Sequential/CallingForwardWithTheWrongReturnTypeThrows") {
+CATCH_TEST_CASE("Sequential/CallingForwardWithTheWrongReturnTypeThrows") {
   struct M : public torch::nn::Module {
     int forward() {
       return 5;
@@ -175,14 +175,14 @@
   };
 
   Sequential sequential(M{});
-  REQUIRE(sequential->forward<int>() == 5);
-  REQUIRE_THROWS_WITH(
+  CATCH_REQUIRE(sequential->forward<int>() == 5);
+  CATCH_REQUIRE_THROWS_WITH(
       sequential->forward<float>(),
       StartsWith("The type of the return value "
                  "is int, but you asked for type float"));
 }
 
-TEST_CASE("Sequential/TheReturnTypeOfForwardDefaultsToTensor") {
+CATCH_TEST_CASE("Sequential/TheReturnTypeOfForwardDefaultsToTensor") {
   struct M : public torch::nn::Module {
     torch::Tensor forward(torch::Tensor v) {
       return v;
@@ -191,21 +191,21 @@
 
   Sequential sequential(M{});
   auto variable = torch::ones({3, 3}, torch::requires_grad());
-  REQUIRE(sequential->forward(variable).equal(variable));
+  CATCH_REQUIRE(sequential->forward(variable).equal(variable));
 }
 
-TEST_CASE("Sequential/ForwardReturnsTheLastValue") {
+CATCH_TEST_CASE("Sequential/ForwardReturnsTheLastValue") {
   torch::manual_seed(0);
   Sequential sequential(Linear(10, 3), Linear(3, 5), Linear(5, 100));
 
   auto x = torch::randn({1000, 10}, torch::requires_grad());
   auto y = sequential->forward(x);
-  REQUIRE(y.ndimension() == 2);
-  REQUIRE(y.size(0) == 1000);
-  REQUIRE(y.size(1) == 100);
+  CATCH_REQUIRE(y.ndimension() == 2);
+  CATCH_REQUIRE(y.size(0) == 1000);
+  CATCH_REQUIRE(y.size(1) == 100);
 }
 
-TEST_CASE("Sequential/SanityCheckForHoldingStandardModules") {
+CATCH_TEST_CASE("Sequential/SanityCheckForHoldingStandardModules") {
   Sequential sequential(
       Linear(10, 3),
       Conv2d(1, 2, 3),
@@ -215,7 +215,7 @@
       LSTM(4, 5));
 }
 
-TEST_CASE("Sequential/ExtendPushesModulesFromOtherSequential") {
+CATCH_TEST_CASE("Sequential/ExtendPushesModulesFromOtherSequential") {
   struct A : torch::nn::Module {
     int forward(int x) {
       return x;
@@ -240,34 +240,34 @@
   Sequential b(C{}, D{});
   a->extend(*b);
 
-  REQUIRE(a->size() == 4);
-  REQUIRE(a[0]->as<A>());
-  REQUIRE(a[1]->as<B>());
-  REQUIRE(a[2]->as<C>());
-  REQUIRE(a[3]->as<D>());
+  CATCH_REQUIRE(a->size() == 4);
+  CATCH_REQUIRE(a[0]->as<A>());
+  CATCH_REQUIRE(a[1]->as<B>());
+  CATCH_REQUIRE(a[2]->as<C>());
+  CATCH_REQUIRE(a[3]->as<D>());
 
-  REQUIRE(b->size() == 2);
-  REQUIRE(b[0]->as<C>());
-  REQUIRE(b[1]->as<D>());
+  CATCH_REQUIRE(b->size() == 2);
+  CATCH_REQUIRE(b[0]->as<C>());
+  CATCH_REQUIRE(b[1]->as<D>());
 
   std::vector<std::shared_ptr<A>> c = {std::make_shared<A>(),
                                        std::make_shared<A>()};
   b->extend(c);
 
-  REQUIRE(b->size() == 4);
-  REQUIRE(b[0]->as<C>());
-  REQUIRE(b[1]->as<D>());
-  REQUIRE(b[2]->as<A>());
-  REQUIRE(b[3]->as<A>());
+  CATCH_REQUIRE(b->size() == 4);
+  CATCH_REQUIRE(b[0]->as<C>());
+  CATCH_REQUIRE(b[1]->as<D>());
+  CATCH_REQUIRE(b[2]->as<A>());
+  CATCH_REQUIRE(b[3]->as<A>());
 }
 
-TEST_CASE("Sequential/HasReferenceSemantics") {
+CATCH_TEST_CASE("Sequential/HasReferenceSemantics") {
   Sequential first(Linear(2, 3), Linear(4, 4), Linear(4, 5));
   Sequential second(first);
 
-  REQUIRE(first.get() == second.get());
-  REQUIRE(first->size() == second->size());
-  REQUIRE(std::equal(
+  CATCH_REQUIRE(first.get() == second.get());
+  CATCH_REQUIRE(first->size() == second->size());
+  CATCH_REQUIRE(std::equal(
       first->begin(),
       first->end(),
       second->begin(),
@@ -276,17 +276,17 @@
       }));
 }
 
-TEST_CASE("Sequential/IsCloneable") {
+CATCH_TEST_CASE("Sequential/IsCloneable") {
   Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
   Sequential clone =
       std::dynamic_pointer_cast<SequentialImpl>(sequential->clone());
-  REQUIRE(sequential->size() == clone->size());
+  CATCH_REQUIRE(sequential->size() == clone->size());
 
   for (size_t i = 0; i < sequential->size(); ++i) {
     // The modules should be the same kind (type).
-    REQUIRE(sequential[i]->name() == clone[i]->name());
+    CATCH_REQUIRE(sequential[i]->name() == clone[i]->name());
     // But not pointer-equal (distinct objects).
-    REQUIRE(sequential[i] != clone[i]);
+    CATCH_REQUIRE(sequential[i] != clone[i]);
   }
 
   // Verify that the clone is deep, i.e. parameters of modules are cloned too.
@@ -295,38 +295,38 @@
 
   auto params1 = sequential->parameters();
   auto params2 = clone->parameters();
-  REQUIRE(params1.size() == params2.size());
+  CATCH_REQUIRE(params1.size() == params2.size());
   for (auto& param : params1) {
-    REQUIRE(!pointer_equal(param.value, params2[param.key]));
-    REQUIRE(param->device() == params2[param.key].device());
-    REQUIRE(param->allclose(params2[param.key]));
+    CATCH_REQUIRE(!pointer_equal(param.value, params2[param.key]));
+    CATCH_REQUIRE(param->device() == params2[param.key].device());
+    CATCH_REQUIRE(param->allclose(params2[param.key]));
     param->add_(2);
   }
   for (auto& param : params1) {
-    REQUIRE(!param->allclose(params2[param.key]));
+    CATCH_REQUIRE(!param->allclose(params2[param.key]));
   }
 }
 
-TEST_CASE("Sequential/RegistersElementsAsSubmodules") {
+CATCH_TEST_CASE("Sequential/RegistersElementsAsSubmodules") {
   Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), FeatureDropout(0.5));
 
   auto modules = sequential->modules();
-  REQUIRE(modules.size() == sequential->children().size());
+  CATCH_REQUIRE(modules.size() == sequential->children().size());
 
-  REQUIRE(modules[0]->as<Linear>());
-  REQUIRE(modules[1]->as<Conv2d>());
-  REQUIRE(modules[2]->as<FeatureDropout>());
+  CATCH_REQUIRE(modules[0]->as<Linear>());
+  CATCH_REQUIRE(modules[1]->as<Conv2d>());
+  CATCH_REQUIRE(modules[2]->as<FeatureDropout>());
 }
 
-TEST_CASE("Sequential/CloneToDevice", "[cuda]") {
+CATCH_TEST_CASE("Sequential/CloneToDevice", "[cuda]") {
   Sequential sequential(Linear(3, 4), Functional(torch::relu), BatchNorm(3));
   torch::Device device(torch::kCUDA, 0);
   Sequential clone =
       std::dynamic_pointer_cast<SequentialImpl>(sequential->clone(device));
   for (const auto& p : clone->parameters()) {
-    REQUIRE(p->device() == device);
+    CATCH_REQUIRE(p->device() == device);
   }
   for (const auto& b : clone->buffers()) {
-    REQUIRE(b->device() == device);
+    CATCH_REQUIRE(b->device() == device);
   }
 }
diff --git a/test/cpp/api/serialization.cpp b/test/cpp/api/serialization.cpp
index 3541089..fda133b 100644
--- a/test/cpp/api/serialization.cpp
+++ b/test/cpp/api/serialization.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/nn/modules/functional.h>
 #include <torch/nn/modules/linear.h>
@@ -30,12 +30,12 @@
 }
 } // namespace
 
-TEST_CASE("serialization") {
+CATCH_TEST_CASE("serialization") {
   torch::manual_seed(0);
-  SECTION("undefined") {
+  CATCH_SECTION("undefined") {
     auto x = torch::Tensor();
 
-    REQUIRE(!x.defined());
+    CATCH_REQUIRE(!x.defined());
 
     auto y = torch::randn({5});
 
@@ -43,10 +43,10 @@
     torch::save(ss, &x);
     torch::load(ss, &y);
 
-    REQUIRE(!y.defined());
+    CATCH_REQUIRE(!y.defined());
   }
 
-  SECTION("cputypes") {
+  CATCH_SECTION("cputypes") {
     for (int i = 0; i < static_cast<int>(torch::Dtype::NumOptions); i++) {
       if (i == static_cast<int>(torch::Dtype::Half)) {
         // XXX can't serialize half tensors at the moment since contiguous() is
@@ -69,17 +69,17 @@
       torch::save(ss, &x);
       torch::load(ss, &y);
 
-      REQUIRE(y.defined());
-      REQUIRE(x.sizes().vec() == y.sizes().vec());
+      CATCH_REQUIRE(y.defined());
+      CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
       if (torch::isIntegralType(static_cast<torch::Dtype>(i))) {
-        REQUIRE(x.equal(y));
+        CATCH_REQUIRE(x.equal(y));
       } else {
-        REQUIRE(x.allclose(y));
+        CATCH_REQUIRE(x.allclose(y));
       }
     }
   }
 
-  SECTION("binary") {
+  CATCH_SECTION("binary") {
     auto x = torch::randn({5, 5});
     auto y = torch::Tensor();
 
@@ -93,11 +93,11 @@
       archive(y);
     }
 
-    REQUIRE(y.defined());
-    REQUIRE(x.sizes().vec() == y.sizes().vec());
-    REQUIRE(x.allclose(y));
+    CATCH_REQUIRE(y.defined());
+    CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
+    CATCH_REQUIRE(x.allclose(y));
   }
-  SECTION("portable_binary") {
+  CATCH_SECTION("portable_binary") {
     auto x = torch::randn({5, 5});
     auto y = torch::Tensor();
 
@@ -111,12 +111,12 @@
       archive(y);
     }
 
-    REQUIRE(y.defined());
-    REQUIRE(x.sizes().vec() == y.sizes().vec());
-    REQUIRE(x.allclose(y));
+    CATCH_REQUIRE(y.defined());
+    CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
+    CATCH_REQUIRE(x.allclose(y));
   }
 
-  SECTION("resized") {
+  CATCH_SECTION("resized") {
     auto x = torch::randn({11, 5});
     x.resize_({5, 5});
     auto y = torch::Tensor();
@@ -131,11 +131,11 @@
       archive(y);
     }
 
-    REQUIRE(y.defined());
-    REQUIRE(x.sizes().vec() == y.sizes().vec());
-    REQUIRE(x.allclose(y));
+    CATCH_REQUIRE(y.defined());
+    CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
+    CATCH_REQUIRE(x.allclose(y));
   }
-  SECTION("sliced") {
+  CATCH_SECTION("sliced") {
     auto x = torch::randn({11, 5});
     x = x.slice(0, 1, 3);
     auto y = torch::Tensor();
@@ -150,12 +150,12 @@
       archive(y);
     }
 
-    REQUIRE(y.defined());
-    REQUIRE(x.sizes().vec() == y.sizes().vec());
-    REQUIRE(x.allclose(y));
+    CATCH_REQUIRE(y.defined());
+    CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
+    CATCH_REQUIRE(x.allclose(y));
   }
 
-  SECTION("noncontig") {
+  CATCH_SECTION("noncontig") {
     auto x = torch::randn({11, 5});
     x = x.slice(1, 1, 4);
     auto y = torch::Tensor();
@@ -170,12 +170,12 @@
       archive(y);
     }
 
-    REQUIRE(y.defined());
-    REQUIRE(x.sizes().vec() == y.sizes().vec());
-    REQUIRE(x.allclose(y));
+    CATCH_REQUIRE(y.defined());
+    CATCH_REQUIRE(x.sizes().vec() == y.sizes().vec());
+    CATCH_REQUIRE(x.allclose(y));
   }
 
-  SECTION("xor") {
+  CATCH_SECTION("xor") {
     // We better be able to save and load a XOR model!
     auto getLoss = [](Sequential model, uint32_t batch_size) {
       auto inputs = torch::empty({batch_size, 2});
@@ -207,7 +207,7 @@
       optimizer.step();
 
       running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
-      REQUIRE(epoch < 3000);
+      CATCH_REQUIRE(epoch < 3000);
       epoch++;
     }
 
@@ -216,10 +216,10 @@
     torch::load(ss, model2);
 
     auto loss = getLoss(model2, 100);
-    REQUIRE(loss.toCFloat() < 0.1);
+    CATCH_REQUIRE(loss.toCFloat() < 0.1);
   }
 
-  SECTION("optim") {
+  CATCH_SECTION("optim") {
     auto model1 = Linear(5, 2);
     auto model2 = Linear(5, 2);
     auto model3 = Linear(5, 2);
@@ -235,8 +235,8 @@
     auto param2 = model2->parameters();
     auto param3 = model3->parameters();
     for (const auto& p : param1) {
-      REQUIRE(param1[p.key].allclose(param2[p.key]));
-      REQUIRE(param2[p.key].allclose(param3[p.key]));
+      CATCH_REQUIRE(param1[p.key].allclose(param2[p.key]));
+      CATCH_REQUIRE(param2[p.key].allclose(param3[p.key]));
     }
 
     // Make some optimizers with momentum (and thus state)
@@ -281,13 +281,13 @@
     for (const auto& p : param1) {
       const auto& name = p.key;
       // Model 1 and 3 should be the same
-      REQUIRE(param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
-      REQUIRE(param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
+      CATCH_REQUIRE(param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
+      CATCH_REQUIRE(param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
     }
   }
 }
 
-TEST_CASE("serialization_cuda", "[cuda]") {
+CATCH_TEST_CASE("serialization_cuda", "[cuda]") {
   torch::manual_seed(0);
   // We better be able to save and load a XOR model!
   auto getLoss = [](Sequential model, uint32_t batch_size) {
@@ -318,7 +318,7 @@
     optimizer.step();
 
     running_loss = running_loss * 0.99 + loss.sum().toCFloat() * 0.01;
-    REQUIRE(epoch < 3000);
+    CATCH_REQUIRE(epoch < 3000);
     epoch++;
   }
 
@@ -327,7 +327,7 @@
   torch::load(ss, model2);
 
   auto loss = getLoss(model2, 100);
-  REQUIRE(loss.toCFloat() < 0.1);
+  CATCH_REQUIRE(loss.toCFloat() < 0.1);
 
   model2->to(torch::kCUDA);
   ss.clear();
@@ -335,5 +335,5 @@
   torch::load(ss, model3);
 
   loss = getLoss(model3, 100);
-  REQUIRE(loss.toCFloat() < 0.1);
+  CATCH_REQUIRE(loss.toCFloat() < 0.1);
 }
diff --git a/test/cpp/api/tensor.cpp b/test/cpp/api/tensor.cpp
index f08a30d..5760556 100644
--- a/test/cpp/api/tensor.cpp
+++ b/test/cpp/api/tensor.cpp
@@ -1,4 +1,4 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <torch/tensor.h>
 
@@ -19,12 +19,12 @@
 }
 
 #define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)                \
-  REQUIRE(tensor.device().type() == at::Device((device_), (index_)).type());   \
-  REQUIRE(tensor.device().index() == at::Device((device_), (index_)).index()); \
-  REQUIRE(tensor.dtype() == (type_));                                          \
-  REQUIRE(tensor.layout() == (layout_))
+  CATCH_REQUIRE(tensor.device().type() == at::Device((device_), (index_)).type());   \
+  CATCH_REQUIRE(tensor.device().index() == at::Device((device_), (index_)).index()); \
+  CATCH_REQUIRE(tensor.dtype() == (type_));                                          \
+  CATCH_REQUIRE(tensor.layout() == (layout_))
 
-TEST_CASE("Tensor/ToDtype") {
+CATCH_TEST_CASE("Tensor/ToDtype") {
   auto tensor = at::empty({3, 4});
   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
 
@@ -39,7 +39,7 @@
 }
 
 // Not currently supported.
-// TEST_CASE("Tensor/ToLayout") {
+// CATCH_TEST_CASE("Tensor/ToLayout") {
 //   auto tensor = at::empty({3, 4});
 //   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
 //
@@ -50,7 +50,7 @@
 //   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
 // }
 
-TEST_CASE("Tensor/ToDevice", "[cuda]") {
+CATCH_TEST_CASE("Tensor/ToDevice", "[cuda]") {
   auto tensor = at::empty({3, 4});
   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
 
@@ -67,7 +67,7 @@
   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
 }
 
-TEST_CASE("Tensor/ToDeviceAndDtype", "[cuda]") {
+CATCH_TEST_CASE("Tensor/ToDeviceAndDtype", "[cuda]") {
   auto tensor = at::empty({3, 4});
   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
 
@@ -75,119 +75,119 @@
   REQUIRE_TENSOR_OPTIONS(at::kCUDA, 1, at::kInt, at::kStrided);
 }
 
-TEST_CASE("Tensor/ToOptionsRespectsRequiresGrad") {
+CATCH_TEST_CASE("Tensor/ToOptionsRespectsRequiresGrad") {
   {
     auto tensor = torch::empty({3, 4}, at::requires_grad());
-    REQUIRE(tensor.requires_grad());
+    CATCH_REQUIRE(tensor.requires_grad());
 
     tensor = tensor.to(at::kDouble);
-    REQUIRE(tensor.requires_grad());
+    CATCH_REQUIRE(tensor.requires_grad());
   }
   {
     auto tensor = torch::empty({3, 4});
-    REQUIRE(!tensor.requires_grad());
+    CATCH_REQUIRE(!tensor.requires_grad());
 
     tensor = tensor.to(at::kDouble);
-    REQUIRE(!tensor.requires_grad());
+    CATCH_REQUIRE(!tensor.requires_grad());
   }
 }
 
-TEST_CASE("Tensor/ToDoesNotCopyWhenOptionsAreAllTheSame") {
+CATCH_TEST_CASE("Tensor/ToDoesNotCopyWhenOptionsAreAllTheSame") {
   auto tensor = at::empty({3, 4}, at::kFloat);
   auto hopefully_not_copy = tensor.to(at::kFloat);
-  REQUIRE(hopefully_not_copy.data<float>() == tensor.data<float>());
+  CATCH_REQUIRE(hopefully_not_copy.data<float>() == tensor.data<float>());
 }
 
-TEST_CASE("Tensor/ContainsCorrectValueForSingleValue") {
+CATCH_TEST_CASE("Tensor/ContainsCorrectValueForSingleValue") {
   auto tensor = at::tensor(123);
-  REQUIRE(tensor.numel() == 1);
-  REQUIRE(tensor.dtype() == at::kInt);
-  REQUIRE(tensor[0].toCInt() == 123);
+  CATCH_REQUIRE(tensor.numel() == 1);
+  CATCH_REQUIRE(tensor.dtype() == at::kInt);
+  CATCH_REQUIRE(tensor[0].toCInt() == 123);
 
   tensor = at::tensor(123.456f);
-  REQUIRE(tensor.numel() == 1);
-  REQUIRE(tensor.dtype() == at::kFloat);
-  REQUIRE(almost_equal(tensor[0], 123.456f));
+  CATCH_REQUIRE(tensor.numel() == 1);
+  CATCH_REQUIRE(tensor.dtype() == at::kFloat);
+  CATCH_REQUIRE(almost_equal(tensor[0], 123.456f));
 
   tensor = at::tensor(123.456);
-  REQUIRE(tensor.numel() == 1);
-  REQUIRE(tensor.dtype() == at::kDouble);
-  REQUIRE(almost_equal(tensor[0], 123.456));
+  CATCH_REQUIRE(tensor.numel() == 1);
+  CATCH_REQUIRE(tensor.dtype() == at::kDouble);
+  CATCH_REQUIRE(almost_equal(tensor[0], 123.456));
 }
 
-TEST_CASE("Tensor/ContainsCorrectValuesForManyValues") {
+CATCH_TEST_CASE("Tensor/ContainsCorrectValuesForManyValues") {
   auto tensor = at::tensor({1, 2, 3});
-  REQUIRE(tensor.numel() == 3);
-  REQUIRE(tensor.dtype() == at::kInt);
-  REQUIRE(exactly_equal(tensor[0], 1));
-  REQUIRE(exactly_equal(tensor[1], 2));
-  REQUIRE(exactly_equal(tensor[2], 3));
+  CATCH_REQUIRE(tensor.numel() == 3);
+  CATCH_REQUIRE(tensor.dtype() == at::kInt);
+  CATCH_REQUIRE(exactly_equal(tensor[0], 1));
+  CATCH_REQUIRE(exactly_equal(tensor[1], 2));
+  CATCH_REQUIRE(exactly_equal(tensor[2], 3));
 
   tensor = at::tensor({1.5, 2.25, 3.125});
-  REQUIRE(tensor.numel() == 3);
-  REQUIRE(tensor.dtype() == at::kDouble);
-  REQUIRE(almost_equal(tensor[0], 1.5));
-  REQUIRE(almost_equal(tensor[1], 2.25));
-  REQUIRE(almost_equal(tensor[2], 3.125));
+  CATCH_REQUIRE(tensor.numel() == 3);
+  CATCH_REQUIRE(tensor.dtype() == at::kDouble);
+  CATCH_REQUIRE(almost_equal(tensor[0], 1.5));
+  CATCH_REQUIRE(almost_equal(tensor[1], 2.25));
+  CATCH_REQUIRE(almost_equal(tensor[2], 3.125));
 }
 
-TEST_CASE("Tensor/ContainsCorrectValuesForManyValuesVariable") {
+CATCH_TEST_CASE("Tensor/ContainsCorrectValuesForManyValuesVariable") {
   auto tensor = torch::tensor({1, 2, 3});
-  REQUIRE(tensor.is_variable());
-  REQUIRE(tensor.numel() == 3);
-  REQUIRE(tensor.dtype() == at::kInt);
-  REQUIRE(exactly_equal(tensor[0], 1));
-  REQUIRE(exactly_equal(tensor[1], 2));
-  REQUIRE(exactly_equal(tensor[2], 3));
+  CATCH_REQUIRE(tensor.is_variable());
+  CATCH_REQUIRE(tensor.numel() == 3);
+  CATCH_REQUIRE(tensor.dtype() == at::kInt);
+  CATCH_REQUIRE(exactly_equal(tensor[0], 1));
+  CATCH_REQUIRE(exactly_equal(tensor[1], 2));
+  CATCH_REQUIRE(exactly_equal(tensor[2], 3));
 
   tensor = torch::tensor({1.5, 2.25, 3.125});
-  REQUIRE(tensor.is_variable());
-  REQUIRE(tensor.numel() == 3);
-  REQUIRE(tensor.dtype() == at::kDouble);
-  REQUIRE(almost_equal(tensor[0], 1.5));
-  REQUIRE(almost_equal(tensor[1], 2.25));
-  REQUIRE(almost_equal(tensor[2], 3.125));
+  CATCH_REQUIRE(tensor.is_variable());
+  CATCH_REQUIRE(tensor.numel() == 3);
+  CATCH_REQUIRE(tensor.dtype() == at::kDouble);
+  CATCH_REQUIRE(almost_equal(tensor[0], 1.5));
+  CATCH_REQUIRE(almost_equal(tensor[1], 2.25));
+  CATCH_REQUIRE(almost_equal(tensor[2], 3.125));
 }
 
-TEST_CASE("Tensor/ContainsCorrectValuesWhenConstructedFromVector") {
+CATCH_TEST_CASE("Tensor/ContainsCorrectValuesWhenConstructedFromVector") {
   std::vector<int> v = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
   auto tensor = at::tensor(v);
-  REQUIRE(tensor.numel() == v.size());
-  REQUIRE(tensor.dtype() == at::kInt);
+  CATCH_REQUIRE(tensor.numel() == v.size());
+  CATCH_REQUIRE(tensor.dtype() == at::kInt);
   for (size_t i = 0; i < v.size(); ++i) {
-    REQUIRE(exactly_equal(tensor[i], v.at(i)));
+    CATCH_REQUIRE(exactly_equal(tensor[i], v.at(i)));
   }
 
   std::vector<float> w = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0};
   tensor = at::tensor(w);
-  REQUIRE(tensor.numel() == w.size());
-  REQUIRE(tensor.dtype() == at::kFloat);
+  CATCH_REQUIRE(tensor.numel() == w.size());
+  CATCH_REQUIRE(tensor.dtype() == at::kFloat);
   for (size_t i = 0; i < w.size(); ++i) {
-    REQUIRE(almost_equal(tensor[i], w.at(i)));
+    CATCH_REQUIRE(almost_equal(tensor[i], w.at(i)));
   }
 }
 
-TEST_CASE("Tensor/UsesOptionsThatAreSupplied") {
+CATCH_TEST_CASE("Tensor/UsesOptionsThatAreSupplied") {
   auto tensor = at::tensor(123, dtype(at::kFloat)) + 0.5;
-  REQUIRE(tensor.numel() == 1);
-  REQUIRE(tensor.dtype() == at::kFloat);
-  REQUIRE(almost_equal(tensor[0], 123.5));
+  CATCH_REQUIRE(tensor.numel() == 1);
+  CATCH_REQUIRE(tensor.dtype() == at::kFloat);
+  CATCH_REQUIRE(almost_equal(tensor[0], 123.5));
 
   tensor = at::tensor({1.1, 2.2, 3.3}, dtype(at::kInt));
-  REQUIRE(tensor.numel() == 3);
-  REQUIRE(tensor.dtype() == at::kInt);
-  REQUIRE(tensor.layout() == at::kStrided);
-  REQUIRE(exactly_equal(tensor[0], 1));
-  REQUIRE(exactly_equal(tensor[1], 2));
-  REQUIRE(exactly_equal(tensor[2], 3));
+  CATCH_REQUIRE(tensor.numel() == 3);
+  CATCH_REQUIRE(tensor.dtype() == at::kInt);
+  CATCH_REQUIRE(tensor.layout() == at::kStrided);
+  CATCH_REQUIRE(exactly_equal(tensor[0], 1));
+  CATCH_REQUIRE(exactly_equal(tensor[1], 2));
+  CATCH_REQUIRE(exactly_equal(tensor[2], 3));
 }
 
-TEST_CASE("FromBlob") {
+CATCH_TEST_CASE("FromBlob") {
   std::vector<int32_t> v = {1, 2, 3};
   auto tensor = torch::from_blob(v.data(), v.size(), torch::kInt32);
-  REQUIRE(tensor.is_variable());
-  REQUIRE(tensor.numel() == 3);
-  REQUIRE(tensor[0].toCInt() == 1);
-  REQUIRE(tensor[1].toCInt() == 2);
-  REQUIRE(tensor[2].toCInt() == 3);
+  CATCH_REQUIRE(tensor.is_variable());
+  CATCH_REQUIRE(tensor.numel() == 3);
+  CATCH_REQUIRE(tensor[0].toCInt() == 1);
+  CATCH_REQUIRE(tensor[1].toCInt() == 2);
+  CATCH_REQUIRE(tensor[2].toCInt() == 3);
 }
diff --git a/test/cpp/api/tensor_cuda.cpp b/test/cpp/api/tensor_cuda.cpp
index 82d874e..8f85014 100644
--- a/test/cpp/api/tensor_cuda.cpp
+++ b/test/cpp/api/tensor_cuda.cpp
@@ -1,11 +1,11 @@
-#include <catch.hpp>
+#include "catch_utils.hpp"
 
 #include <ATen/ATen.h>
 
 #include <cmath>
 
-TEST_CASE("Tensor/AllocatesTensorOnTheCorrectDevice", "[multi-cuda]") {
+CATCH_TEST_CASE("Tensor/AllocatesTensorOnTheCorrectDevice", "[multi-cuda]") {
   auto tensor = at::tensor({1, 2, 3}, at::device({at::kCUDA, 1}));
-  REQUIRE(tensor.device().type() == at::Device::Type::CUDA);
-  REQUIRE(tensor.device().index() == 1);
+  CATCH_REQUIRE(tensor.device().type() == at::Device::Type::CUDA);
+  CATCH_REQUIRE(tensor.device().index() == 1);
 }
diff --git a/test/cpp/api/tensor_options.cpp b/test/cpp/api/tensor_options.cpp
index ab80c5f..7118a35 100644
--- a/test/cpp/api/tensor_options.cpp
+++ b/test/cpp/api/tensor_options.cpp
@@ -1,4 +1,4 @@
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include <torch/tensor.h>
 
@@ -14,28 +14,28 @@
 
 // A macro so we don't lose location information when an assertion fails.
 #define REQUIRE_OPTIONS(device_, index_, type_, layout_)                    \
-  REQUIRE(options.device().type() == Device((device_), (index_)).type());   \
-  REQUIRE(options.device().index() == Device((device_), (index_)).index()); \
-  REQUIRE(options.dtype() == (type_));                                      \
-  REQUIRE(options.layout() == (layout_))
+  CATCH_REQUIRE(options.device().type() == Device((device_), (index_)).type());   \
+  CATCH_REQUIRE(options.device().index() == Device((device_), (index_)).index()); \
+  CATCH_REQUIRE(options.dtype() == (type_));                                      \
+  CATCH_REQUIRE(options.layout() == (layout_))
 
 #define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)            \
-  REQUIRE(tensor.device().type() == Device((device_), (index_)).type());   \
-  REQUIRE(tensor.device().index() == Device((device_), (index_)).index()); \
-  REQUIRE(tensor.type().scalarType() == (type_));                          \
-  REQUIRE(tensor.type().layout() == (layout_))
+  CATCH_REQUIRE(tensor.device().type() == Device((device_), (index_)).type());   \
+  CATCH_REQUIRE(tensor.device().index() == Device((device_), (index_)).index()); \
+  CATCH_REQUIRE(tensor.type().scalarType() == (type_));                          \
+  CATCH_REQUIRE(tensor.type().layout() == (layout_))
 
-TEST_CASE("TensorOptions/DefaultsToTheRightValues") {
+CATCH_TEST_CASE("TensorOptions/DefaultsToTheRightValues") {
   TensorOptions options;
   REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
 }
 
-TEST_CASE("TensorOptions/ReturnsTheCorrectType") {
+CATCH_TEST_CASE("TensorOptions/ReturnsTheCorrectType") {
   auto options = TensorOptions().device(kCPU).dtype(kInt).layout(kSparse);
-  REQUIRE(at::getType(options) == getNonVariableType(Backend::SparseCPU, kInt));
+  CATCH_REQUIRE(at::getType(options) == getNonVariableType(Backend::SparseCPU, kInt));
 }
 
-TEST_CASE("TensorOptions/UtilityFunctionsReturnTheRightTensorOptions") {
+CATCH_TEST_CASE("TensorOptions/UtilityFunctionsReturnTheRightTensorOptions") {
   auto options = dtype(kInt);
   REQUIRE_OPTIONS(kCPU, -1, kInt, kStrided);
 
@@ -52,7 +52,7 @@
   REQUIRE_OPTIONS(kCUDA, 3, kByte, kSparse);
 }
 
-TEST_CASE("TensorOptions/ConstructsWellFromCPUTypes") {
+CATCH_TEST_CASE("TensorOptions/ConstructsWellFromCPUTypes") {
   TensorOptions options;
   REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
 
@@ -69,7 +69,7 @@
   REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
 }
 
-TEST_CASE("TensorOptions/ConstructsWellFromCPUTensors") {
+CATCH_TEST_CASE("TensorOptions/ConstructsWellFromCPUTensors") {
   auto options = empty(5, kDouble).options();
   REQUIRE_OPTIONS(kCPU, -1, kDouble, kStrided);
 
@@ -77,37 +77,37 @@
   REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
 }
 
-TEST_CASE("TensorOptions/ConstructsWellFromVariables") {
+CATCH_TEST_CASE("TensorOptions/ConstructsWellFromVariables") {
   auto options = torch::empty(5).options();
   REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
-  REQUIRE(!options.requires_grad());
+  CATCH_REQUIRE(!options.requires_grad());
 
   options = torch::empty(5, at::requires_grad()).options();
   REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
-  REQUIRE(!options.requires_grad());
+  CATCH_REQUIRE(!options.requires_grad());
 }
 
-TEST_CASE("Device/ParsesCorrectlyFromString") {
+CATCH_TEST_CASE("Device/ParsesCorrectlyFromString") {
   Device device("cpu:0");
-  REQUIRE(device == Device(kCPU, 0));
+  CATCH_REQUIRE(device == Device(kCPU, 0));
 
   device = Device("cpu");
-  REQUIRE(device == Device(kCPU));
+  CATCH_REQUIRE(device == Device(kCPU));
 
   device = Device("cuda:123");
-  REQUIRE(device == Device(kCUDA, 123));
+  CATCH_REQUIRE(device == Device(kCUDA, 123));
 
   device = Device("cuda");
-  REQUIRE(device == Device(kCUDA));
+  CATCH_REQUIRE(device == Device(kCUDA));
 
   std::vector<std::string> badnesses = {
       "", "cud:1", "cuda:", "cpu::1", ":1", "3", "tpu:4", "??"};
   for (const auto& badness : badnesses) {
-    REQUIRE_THROWS(Device(badness));
+    _CATCH_REQUIRE_THROWS(Device(badness));
   }
 }
 
-TEST_CASE("OptionsGuard") {
+CATCH_TEST_CASE("OptionsGuard") {
   Tensor tensor;
   {
     OptionsGuard guard(TensorOptions{});
@@ -132,5 +132,5 @@
     tensor = torch::empty({10});
   }
   REQUIRE_TENSOR_OPTIONS(kCPU, -1, kFloat, kStrided);
-  REQUIRE(tensor.requires_grad());
+  CATCH_REQUIRE(tensor.requires_grad());
 }
diff --git a/test/cpp/api/tensor_options_cuda.cpp b/test/cpp/api/tensor_options_cuda.cpp
index ea33321..edeede8 100644
--- a/test/cpp/api/tensor_options_cuda.cpp
+++ b/test/cpp/api/tensor_options_cuda.cpp
@@ -1,4 +1,4 @@
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 #include <ATen/Context.h>
 #include <ATen/DeviceGuard.h>
@@ -10,18 +10,18 @@
 
 // A macro so we don't lose location information when an assertion fails.
 #define REQUIRE_OPTIONS(device_, index_, type_, layout_)                    \
-  REQUIRE(options.device().type() == Device((device_), (index_)).type());   \
-  REQUIRE(options.device().index() == Device((device_), (index_)).index()); \
-  REQUIRE(options.dtype() == (type_));                                      \
-  REQUIRE(options.layout() == (layout_))
+  CATCH_REQUIRE(options.device().type() == Device((device_), (index_)).type());   \
+  CATCH_REQUIRE(options.device().index() == Device((device_), (index_)).index()); \
+  CATCH_REQUIRE(options.dtype() == (type_));                                      \
+  CATCH_REQUIRE(options.layout() == (layout_))
 
 #define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)            \
-  REQUIRE(tensor.device().type() == Device((device_), (index_)).type());   \
-  REQUIRE(tensor.device().index() == Device((device_), (index_)).index()); \
-  REQUIRE(tensor.type().scalarType() == (type_));                          \
-  REQUIRE(tensor.type().layout() == (layout_))
+  CATCH_REQUIRE(tensor.device().type() == Device((device_), (index_)).type());   \
+  CATCH_REQUIRE(tensor.device().index() == Device((device_), (index_)).index()); \
+  CATCH_REQUIRE(tensor.type().scalarType() == (type_));                          \
+  CATCH_REQUIRE(tensor.type().layout() == (layout_))
 
-TEST_CASE("TensorOptions/ConstructsWellFromCUDATypes", "[cuda]") {
+CATCH_TEST_CASE("TensorOptions/ConstructsWellFromCUDATypes", "[cuda]") {
   auto options = CUDA(kFloat).options();
   REQUIRE_OPTIONS(kCUDA, -1, kFloat, kStrided);
 
@@ -41,7 +41,7 @@
   REQUIRE_OPTIONS(kCUDA, 5, kFloat, kSparse);
 }
 
-TEST_CASE("TensorOptions/ConstructsWellFromCUDATensors", "[multi-cuda]") {
+CATCH_TEST_CASE("TensorOptions/ConstructsWellFromCUDATensors", "[multi-cuda]") {
   auto options = empty(5, device(kCUDA).dtype(kDouble)).options();
   REQUIRE_OPTIONS(kCUDA, 0, kDouble, kStrided);
 
@@ -66,7 +66,7 @@
   }
 }
 
-TEST_CASE("OptionsGuardCUDA", "[multi-cuda]") {
+CATCH_TEST_CASE("OptionsGuardCUDA", "[multi-cuda]") {
   Tensor tensor;
   {
     OptionsGuard guard(device(kCUDA));
@@ -87,7 +87,7 @@
   REQUIRE_TENSOR_OPTIONS(kCUDA, 0, kInt, kStrided);
 }
 
-TEST_CASE("DeviceGuardOptionsGuardInteraction", "[multi-cuda]") {
+CATCH_TEST_CASE("DeviceGuardOptionsGuardInteraction", "[multi-cuda]") {
   Tensor tensor;
   {
     // Check that OptionsGuard respects any active device before construction.
@@ -112,17 +112,17 @@
   }
 }
 
-TEST_CASE("DeviceGuardIsMovable", "[cuda]") {
+CATCH_TEST_CASE("DeviceGuardIsMovable", "[cuda]") {
   DeviceGuard first(1);
-  REQUIRE(first.original_index() == 0);
-  REQUIRE(first.last_index() == 1);
+  CATCH_REQUIRE(first.original_index() == 0);
+  CATCH_REQUIRE(first.last_index() == 1);
   DeviceGuard second(std::move(first));
-  REQUIRE(second.original_index() == 0);
-  REQUIRE(second.last_index() == 1);
-  REQUIRE(first.original_index() == -1);
+  CATCH_REQUIRE(second.original_index() == 0);
+  CATCH_REQUIRE(second.last_index() == 1);
+  CATCH_REQUIRE(first.original_index() == -1);
   DeviceGuard third;
   third = std::move(second);
-  REQUIRE(third.original_index() == 0);
-  REQUIRE(third.last_index() == 1);
-  REQUIRE(second.original_index() == -1);
+  CATCH_REQUIRE(third.original_index() == 0);
+  CATCH_REQUIRE(third.last_index() == 1);
+  CATCH_REQUIRE(second.original_index() == -1);
 }
diff --git a/torch/csrc/jit/catch_utils.hpp b/torch/csrc/jit/catch_utils.hpp
new file mode 100644
index 0000000..b9b0a87
--- /dev/null
+++ b/torch/csrc/jit/catch_utils.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+#define CATCH_CONFIG_PREFIX_ALL
+#include <catch.hpp>
+
+// CATCH_REQUIRE_THROWS is not defined identically to REQUIRE_THROWS and causes a warning;
+// define our own version that doesn't warn.
+#define _CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
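
With this header in place, test files pick up the Catch macros only under the CATCH_ prefix. A minimal sketch of a consuming test, assuming catch.hpp is on the include path (the file name, test name, and throws_badness() helper below are hypothetical, for illustration only):

// example_test.cpp -- illustrative sketch, not part of this diff.
// Exactly one translation unit per test binary defines CATCH_CONFIG_MAIN
// (as apply_test.cpp and test_jit.cpp do) so Catch supplies main().
#define CATCH_CONFIG_MAIN
#include "catch_utils.hpp"  // defines CATCH_CONFIG_PREFIX_ALL before <catch.hpp>

#include <stdexcept>

static void throws_badness() {
  throw std::runtime_error("Badness!");
}

CATCH_TEST_CASE("Example/PrefixedMacros") {
  CATCH_SECTION("plain assertions") {
    CATCH_REQUIRE(1 + 1 == 2);
  }
  CATCH_SECTION("exception assertions") {
    // Use the locally defined _CATCH_REQUIRE_THROWS rather than
    // CATCH_REQUIRE_THROWS to avoid the warning noted above.
    _CATCH_REQUIRE_THROWS(throws_badness());
    CATCH_REQUIRE_THROWS_WITH(throws_badness(), Catch::StartsWith("Badness"));
  }
}
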
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index 28bf958..3110fb2c 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -1,13 +1,13 @@
 #ifdef USE_CATCH
 
 #define CATCH_CONFIG_MAIN
-#include "catch.hpp"
+#include "catch_utils.hpp"
 
 using Catch::StartsWith;
 
 #else
 
-#define REQUIRE JIT_ASSERT
+#define CATCH_REQUIRE JIT_ASSERT
 
 #endif
 
@@ -110,9 +110,9 @@
     e.v("what",{"is","this"});
     TemplateEnv c(e);
     c.s("hi","foo2");
-    REQUIRE(e.s("hi") == "foo");
-    REQUIRE(c.s("hi") == "foo2");
-    REQUIRE(e.v("what")[0] == "is");
+    CATCH_REQUIRE(e.s("hi") == "foo");
+    CATCH_REQUIRE(c.s("hi") == "foo2");
+    CATCH_REQUIRE(e.v("what")[0] == "is");
   }
 
   {
@@ -126,7 +126,7 @@
     auto s = ct.format(e);
     //std::cout << "'" << s << "'\n";
     //std::cout << "'" << ct_expect << "'\n";
-    REQUIRE(s == ct_expect);
+    CATCH_REQUIRE(s == ct_expect);
   }
 }
 
@@ -146,11 +146,11 @@
     auto b = at::rand({4,3}, at::kCUDA).transpose(0,1);
     auto o = at::zeros({3,4}, at::kCUDA);
     auto outputs = debugLaunchGraph(graph, 0, {a,b});
-    REQUIRE(outputs.size() == 1);
+    CATCH_REQUIRE(outputs.size() == 1);
     auto o2 = a*b;
     float max_diff = (o2 - outputs[0]).abs().max().toCDouble();
     //std::cout << "max diff: " << max_diff << "\n";
-    REQUIRE(max_diff == 0);
+    CATCH_REQUIRE(max_diff == 0);
   };
   testSimple();
 
@@ -200,10 +200,10 @@
     auto out0 = t16*t5;
 
     auto outputs = debugLaunchGraph(graph, 0, inputs);
-    REQUIRE(outputs.size() == graph.outputs().size());
-    REQUIRE(out0.is_same_size(outputs.front()));
+    CATCH_REQUIRE(outputs.size() == graph.outputs().size());
+    CATCH_REQUIRE(out0.is_same_size(outputs.front()));
     float max_diff = (outputs.front() - out0).abs().max().toCDouble();
-    REQUIRE(max_diff < 1e-6);
+    CATCH_REQUIRE(max_diff < 1e-6);
 
   };
   testOne(0,0,0,0);
@@ -234,12 +234,12 @@
     auto o_r = a*b;
     auto o2_r = at::cat({a, o_r}, dim);
     auto outputs = debugLaunchGraph(graph, 0, {a,b});
-    REQUIRE(outputs.size() == 2);
+    CATCH_REQUIRE(outputs.size() == 2);
 
     float max_diff = (o_r - outputs[0]).abs().max().toCDouble();
-    REQUIRE(max_diff == 0);
+    CATCH_REQUIRE(max_diff == 0);
     float max_diff2 = (o2_r - outputs[1]).abs().max().toCDouble();
-    REQUIRE(max_diff2 == 0);
+    CATCH_REQUIRE(max_diff2 == 0);
   };
   testConcat(0);
   testConcat(1);
@@ -255,58 +255,58 @@
   auto four = attr::perm;
   Attr attr;
   attr.f_(one,3.4)->i_(two,5)->s_(three,"what");
-  REQUIRE(attr.f(one) == 3.4);
-  REQUIRE(attr.s(three) == "what");
-  REQUIRE(attr.i(two) == 5);
+  CATCH_REQUIRE(attr.f(one) == 3.4);
+  CATCH_REQUIRE(attr.s(three) == "what");
+  CATCH_REQUIRE(attr.i(two) == 5);
   attr.s_(one,"no");
-  REQUIRE(attr.s(one) == "no");
-  REQUIRE(attr.hasAttribute(three));
-  REQUIRE(!attr.hasAttribute(four));
+  CATCH_REQUIRE(attr.s(one) == "no");
+  CATCH_REQUIRE(attr.hasAttribute(three));
+  CATCH_REQUIRE(!attr.hasAttribute(four));
   attr.ss_(two, {"hi", "now"});
-  REQUIRE(attr.ss(two).at(1) == "now");
+  CATCH_REQUIRE(attr.ss(two).at(1) == "now");
 
   Attr attr2;
   attr2.copyAttributes(attr);
-  REQUIRE(attr2.s(one) == "no");
+  CATCH_REQUIRE(attr2.s(one) == "no");
   attr2.f_(one,5);
-  REQUIRE(attr.s(one) == "no");
-  REQUIRE(attr2.f(one) == 5);
+  CATCH_REQUIRE(attr.s(one) == "no");
+  CATCH_REQUIRE(attr2.f(one) == 5);
 }
 
 void internedStringsTests () {
 
-  REQUIRE(prim::Param == Symbol::prim("Param"));
-  REQUIRE(prim::Return == Symbol::prim("Return"));
-  REQUIRE(prim::Return.toUnqualString() == std::string("Return"));
-  REQUIRE(prim::Return.toQualString() == std::string("prim::Return"));
+  CATCH_REQUIRE(prim::Param == Symbol::prim("Param"));
+  CATCH_REQUIRE(prim::Return == Symbol::prim("Return"));
+  CATCH_REQUIRE(prim::Return.toUnqualString() == std::string("Return"));
+  CATCH_REQUIRE(prim::Return.toQualString() == std::string("prim::Return"));
   Symbol newsym = Symbol::aten("__NEW_SYMBOL");
   size_t symstart = newsym;
-  REQUIRE(newsym.toQualString() == std::string("aten::__NEW_SYMBOL"));
+  CATCH_REQUIRE(newsym.toQualString() == std::string("aten::__NEW_SYMBOL"));
   // TODO: This test is a bit too close to the implementation details.
-  REQUIRE(Symbol::aten("What") == symstart+1);
-  REQUIRE(Symbol::aten("What2") == symstart+2);
-  REQUIRE(Symbol::aten("What") == symstart+1);
-  REQUIRE(Symbol::aten("What2") == symstart+2);
-  REQUIRE(Symbol(symstart+2).toUnqualString() == std::string("What2"));
+  CATCH_REQUIRE(Symbol::aten("What") == symstart+1);
+  CATCH_REQUIRE(Symbol::aten("What2") == symstart+2);
+  CATCH_REQUIRE(Symbol::aten("What") == symstart+1);
+  CATCH_REQUIRE(Symbol::aten("What2") == symstart+2);
+  CATCH_REQUIRE(Symbol(symstart+2).toUnqualString() == std::string("What2"));
 }
 
 void fromQualStringTests() {
-  REQUIRE(Symbol::fromQualString("prim::Param") == Symbol::prim("Param"));
-  REQUIRE(Symbol::fromQualString("aten::mm") == Symbol::aten("mm"));
-  REQUIRE(Symbol::fromQualString("onnx::LSTM") == Symbol::onnx("LSTM"));
-  REQUIRE(Symbol::fromQualString("attr::value") == Symbol::attr("value"));
-  REQUIRE(Symbol::fromQualString("scope::") == Symbol::scope(""));
-  REQUIRE(Symbol::fromQualString("::").toUnqualString() == std::string(""));
-  REQUIRE(Symbol::fromQualString("::").ns().toQualString() == std::string("namespaces::"));
-  REQUIRE(Symbol::fromQualString("new_ns::param").toUnqualString() == std::string("param"));
-  REQUIRE(Symbol::fromQualString("new_ns::param").ns().toUnqualString() == std::string("new_ns"));
-  REQUIRE(Symbol::fromQualString("new_ns::param").ns() == Symbol::fromQualString("namespaces::new_ns"));
+  CATCH_REQUIRE(Symbol::fromQualString("prim::Param") == Symbol::prim("Param"));
+  CATCH_REQUIRE(Symbol::fromQualString("aten::mm") == Symbol::aten("mm"));
+  CATCH_REQUIRE(Symbol::fromQualString("onnx::LSTM") == Symbol::onnx("LSTM"));
+  CATCH_REQUIRE(Symbol::fromQualString("attr::value") == Symbol::attr("value"));
+  CATCH_REQUIRE(Symbol::fromQualString("scope::") == Symbol::scope(""));
+  CATCH_REQUIRE(Symbol::fromQualString("::").toUnqualString() == std::string(""));
+  CATCH_REQUIRE(Symbol::fromQualString("::").ns().toQualString() == std::string("namespaces::"));
+  CATCH_REQUIRE(Symbol::fromQualString("new_ns::param").toUnqualString() == std::string("param"));
+  CATCH_REQUIRE(Symbol::fromQualString("new_ns::param").ns().toUnqualString() == std::string("new_ns"));
+  CATCH_REQUIRE(Symbol::fromQualString("new_ns::param").ns() == Symbol::fromQualString("namespaces::new_ns"));
 
   auto bad_inputs = {"scope", ":", ""};
   for (auto input : bad_inputs) {
     try {
       Symbol::fromQualString(input);
-      REQUIRE(0);
+      CATCH_REQUIRE(0);
     } catch (std::runtime_error c) {
     }
   }
@@ -467,8 +467,8 @@
     std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
 
     //std::cout << almostEqual(outputs[0],hx) << "\n";
-    REQUIRE(exactlyEqual(outputs[0],hx));
-    REQUIRE(exactlyEqual(outputs[1],cx));
+    CATCH_REQUIRE(exactlyEqual(outputs[0],hx));
+    CATCH_REQUIRE(exactlyEqual(outputs[1],cx));
 }
 
 void interpStageTest() {
@@ -500,8 +500,8 @@
     std::tie(hx, cx) = lstm(input[0], hx, cx1, w_ih, w_hh);
 
     //std::cout << almostEqual(outputs[0],hx) << "\n";
-    REQUIRE(exactlyEqual(outputs[0],hx));
-    REQUIRE(exactlyEqual(outputs[1],cx));
+    CATCH_REQUIRE(exactlyEqual(outputs[0],hx));
+    CATCH_REQUIRE(exactlyEqual(outputs[1],cx));
 }
 
 using var_meta_type = std::vector<int64_t>;
@@ -554,10 +554,10 @@
 }
 
 void assertAllClose(const tensor_list& a, const tensor_list& b) {
-  REQUIRE(a.size() == b.size());
+  CATCH_REQUIRE(a.size() == b.size());
   for (size_t i = 0; i < a.size(); ++i) {
-    REQUIRE(a[i].is_same_size(b[i]));
-    REQUIRE(a[i].allclose(b[i]));
+    CATCH_REQUIRE(a[i].is_same_size(b[i]));
+    CATCH_REQUIRE(a[i].allclose(b[i]));
   }
 }
 
@@ -654,11 +654,11 @@
   std::vector<size_t> expected_captured_outputs = {1};
   std::vector<size_t> expected_input_vjps = {0, 1};
   std::vector<size_t> expected_output_vjps = {0, 1};
-  REQUIRE(grad_spec.f_real_outputs == 1);
-  REQUIRE(grad_spec.df_input_captured_inputs == expected_captured_inputs);
-  REQUIRE(grad_spec.df_input_captured_outputs == expected_captured_outputs);
-  REQUIRE(grad_spec.df_input_vjps == expected_input_vjps);
-  REQUIRE(grad_spec.df_output_vjps == expected_output_vjps);
+  CATCH_REQUIRE(grad_spec.f_real_outputs == 1);
+  CATCH_REQUIRE(grad_spec.df_input_captured_inputs == expected_captured_inputs);
+  CATCH_REQUIRE(grad_spec.df_input_captured_outputs == expected_captured_outputs);
+  CATCH_REQUIRE(grad_spec.df_input_vjps == expected_input_vjps);
+  CATCH_REQUIRE(grad_spec.df_output_vjps == expected_output_vjps);
   out << "testDifferentiate\n";
   out << *grad_spec.f;
   out << *grad_spec.df;
@@ -684,11 +684,11 @@
   auto grad_spec = differentiate(graph);
   std::vector<size_t> expected_input_vjps = {1, 2};  // for e and %4 = (d + a)
   std::vector<size_t> expected_output_vjps = {0};    // only a requires grad
-  REQUIRE(grad_spec.f_real_outputs == 2);              // we need one temporary %4 = (d + a)
-  REQUIRE(grad_spec.df_input_captured_inputs == std::vector<size_t>({0}));
-  REQUIRE(grad_spec.df_input_captured_outputs == std::vector<size_t>({2}));
-  REQUIRE(grad_spec.df_input_vjps == expected_input_vjps);
-  REQUIRE(grad_spec.df_output_vjps == expected_output_vjps);
+  CATCH_REQUIRE(grad_spec.f_real_outputs == 2);              // we need one temporary %4 = (d + a)
+  CATCH_REQUIRE(grad_spec.df_input_captured_inputs == std::vector<size_t>({0}));
+  CATCH_REQUIRE(grad_spec.df_input_captured_outputs == std::vector<size_t>({2}));
+  CATCH_REQUIRE(grad_spec.df_input_vjps == expected_input_vjps);
+  CATCH_REQUIRE(grad_spec.df_output_vjps == expected_output_vjps);
   out << "testDifferentiateWithRequiresGrad\n";
   out << *grad_spec.f;
   out << *grad_spec.df;
@@ -718,7 +718,7 @@
 }
 
 bool isEqual(const CompleteArgumentInfo & ti, const autograd::Variable & v) {
-  REQUIRE(ti.isTensor());
+  CATCH_REQUIRE(ti.isTensor());
   if(!ti.defined())
     return ti.defined() == v.defined();
   return
@@ -754,34 +754,34 @@
 
   CompleteArgumentSpec a(true, list);
   CompleteArgumentSpec b(true, list);
-  REQUIRE(a.hashCode() == b.hashCode());
+  CATCH_REQUIRE(a.hashCode() == b.hashCode());
 
-  REQUIRE(a == b);
+  CATCH_REQUIRE(a == b);
   CompleteArgumentSpec d(true, list2);
-  REQUIRE(d == a);
-  REQUIRE(d.hashCode() == a.hashCode());
+  CATCH_REQUIRE(d == a);
+  CATCH_REQUIRE(d.hashCode() == a.hashCode());
 
   for(size_t i = 0; i < list.size(); ++i) {
-    REQUIRE(isEqual(a.at(i), list[i].toTensor()));
+    CATCH_REQUIRE(isEqual(a.at(i), list[i].toTensor()));
   }
   CompleteArgumentSpec no_grad(/*with_grad=*/false, list);
-  REQUIRE(no_grad != a);
+  CATCH_REQUIRE(no_grad != a);
 
   std::unordered_set<CompleteArgumentSpec> spec;
   spec.insert(std::move(a));
-  REQUIRE(spec.count(b) > 0);
-  REQUIRE(spec.count(no_grad) == 0);
+  CATCH_REQUIRE(spec.count(b) > 0);
+  CATCH_REQUIRE(spec.count(no_grad) == 0);
   spec.insert(std::move(no_grad));
-  REQUIRE(spec.count(CompleteArgumentSpec(true,list)) == 1);
+  CATCH_REQUIRE(spec.count(CompleteArgumentSpec(true,list)) == 1);
 
   list2[1].toTensor().transpose_(0,1);
   CompleteArgumentSpec c(true, list2); // same as list, except for one stride
-  REQUIRE(!(c == a));
-  REQUIRE(spec.count(c) == 0);
+  CATCH_REQUIRE(!(c == a));
+  CATCH_REQUIRE(spec.count(c) == 0);
 
   Stack stack = { var(CF, {1,2}, true), 3, var(CF, {1,2}, true) };
   CompleteArgumentSpec with_const(true, stack);
-  REQUIRE(with_const.at(2).sizes().size() == 2);
+  CATCH_REQUIRE(with_const.at(2).sizes().size() == 2);
 }
 
 void testGraphExecutor() {
@@ -802,11 +802,11 @@
   GraphExecutor executor(g);
   auto stack = createStack({v(input), v(hx), v(cx), v(w_ih), v(w_hh)});
   executor.run(stack);
-  REQUIRE(stack.size() == 2);
+  CATCH_REQUIRE(stack.size() == 2);
   at::Tensor r0, r1;
   std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh);
-  REQUIRE(almostEqual(Variable(stack[0].toTensor()).data(), r0));
-  REQUIRE(almostEqual(Variable(stack[1].toTensor()).data(), r1));
+  CATCH_REQUIRE(almostEqual(Variable(stack[0].toTensor()).data(), r0));
+  CATCH_REQUIRE(almostEqual(Variable(stack[1].toTensor()).data(), r1));
 }
 
 void testBlocks(std::ostream & out) {
@@ -877,11 +877,11 @@
   auto run_binary = [&](const std::string & name, int64_t a, int64_t b) {
     return V(run(name, {L(a), L(b)})[0]);
   };
-  REQUIRE(2 == run_binary("if_test", 1, 2));
-  REQUIRE(3 == run_binary("if_test", 3, 2));
-  REQUIRE(2 == run_binary("if_one", 2, 3));
-  REQUIRE(2 == run_binary("if_one", 3, 2));
-  REQUIRE(256 == run_binary("while_test",2,0));
+  CATCH_REQUIRE(2 == run_binary("if_test", 1, 2));
+  CATCH_REQUIRE(3 == run_binary("if_test", 3, 2));
+  CATCH_REQUIRE(2 == run_binary("if_one", 2, 3));
+  CATCH_REQUIRE(2 == run_binary("if_one", 3, 2));
+  CATCH_REQUIRE(256 == run_binary("while_test",2,0));
 }
 
 void testIValue() {
@@ -939,18 +939,18 @@
     RegisterOperators reg({createOperator(
         "foo::bar", [](double a, at::Tensor b) { return a + b; })});
     auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
-    REQUIRE(ops.size() == 1);
+    CATCH_REQUIRE(ops.size() == 1);
 
     auto& op = ops.front();
-    REQUIRE(op->schema().name == "foo::bar");
+    CATCH_REQUIRE(op->schema().name == "foo::bar");
 
-    REQUIRE(op->schema().arguments.size() == 2);
-    REQUIRE(op->schema().arguments[0].name == "_0");
-    REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType);
-    REQUIRE(op->schema().arguments[1].name == "_1");
-    REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType);
+    CATCH_REQUIRE(op->schema().arguments.size() == 2);
+    CATCH_REQUIRE(op->schema().arguments[0].name == "_0");
+    CATCH_REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType);
+    CATCH_REQUIRE(op->schema().arguments[1].name == "_1");
+    CATCH_REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType);
 
-    REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType);
+    CATCH_REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType);
 
     Stack stack;
     push(stack, 2.0f, autograd::make_variable(at::ones(5)));
@@ -958,7 +958,7 @@
     at::Tensor output;
     pop(stack, output);
 
-    REQUIRE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
+    CATCH_REQUIRE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
   }
   {
     RegisterOperators reg({createOperator(
@@ -967,19 +967,19 @@
 
     auto& ops =
         getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
-    REQUIRE(ops.size() == 1);
+    CATCH_REQUIRE(ops.size() == 1);
 
     auto& op = ops.front();
-    REQUIRE(op->schema().name == "foo::bar_with_schema");
+    CATCH_REQUIRE(op->schema().name == "foo::bar_with_schema");
 
-    REQUIRE(op->schema().arguments.size() == 2);
-    REQUIRE(op->schema().arguments[0].name == "a");
-    REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType);
-    REQUIRE(op->schema().arguments[1].name == "b");
-    REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType);
+    CATCH_REQUIRE(op->schema().arguments.size() == 2);
+    CATCH_REQUIRE(op->schema().arguments[0].name == "a");
+    CATCH_REQUIRE(op->schema().arguments[0].type->kind() == TypeKind::FloatType);
+    CATCH_REQUIRE(op->schema().arguments[1].name == "b");
+    CATCH_REQUIRE(op->schema().arguments[1].type->kind() == TypeKind::DynamicType);
 
-    REQUIRE(op->schema().returns.size() == 1);
-    REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType);
+    CATCH_REQUIRE(op->schema().returns.size() == 1);
+    CATCH_REQUIRE(op->schema().returns[0].type->kind() == TypeKind::DynamicType);
 
     Stack stack;
     push(stack, 2.0f, autograd::make_variable(at::ones(5)));
@@ -987,7 +987,7 @@
     at::Tensor output;
     pop(stack, output);
 
-    REQUIRE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
+    CATCH_REQUIRE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
   }
   {
     // Check that lists work well.
@@ -999,21 +999,21 @@
 
     auto& ops =
         getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
-    REQUIRE(ops.size() == 1);
+    CATCH_REQUIRE(ops.size() == 1);
 
     auto& op = ops.front();
-    REQUIRE(op->schema().name == "foo::lists");
+    CATCH_REQUIRE(op->schema().name == "foo::lists");
 
-    REQUIRE(op->schema().arguments.size() == 3);
-    REQUIRE(op->schema().arguments[0].name == "ints");
-    REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofInts()));
-    REQUIRE(op->schema().arguments[1].name == "floats");
-    REQUIRE(op->schema().arguments[1].type->isSubtypeOf(ListType::ofFloats()));
-    REQUIRE(op->schema().arguments[2].name == "tensors");
-    REQUIRE(op->schema().arguments[2].type->isSubtypeOf(ListType::ofTensors()));
+    CATCH_REQUIRE(op->schema().arguments.size() == 3);
+    CATCH_REQUIRE(op->schema().arguments[0].name == "ints");
+    CATCH_REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofInts()));
+    CATCH_REQUIRE(op->schema().arguments[1].name == "floats");
+    CATCH_REQUIRE(op->schema().arguments[1].type->isSubtypeOf(ListType::ofFloats()));
+    CATCH_REQUIRE(op->schema().arguments[2].name == "tensors");
+    CATCH_REQUIRE(op->schema().arguments[2].type->isSubtypeOf(ListType::ofTensors()));
 
-    REQUIRE(op->schema().returns.size() == 1);
-    REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofFloats()));
+    CATCH_REQUIRE(op->schema().returns.size() == 1);
+    CATCH_REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofFloats()));
 
     Stack stack;
     push(stack, std::vector<int64_t>{1, 2});
@@ -1023,9 +1023,9 @@
     std::vector<double> output;
     pop(stack, output);
 
-    REQUIRE(output.size() == 2);
-    REQUIRE(output[0] == 1.0);
-    REQUIRE(output[1] == 2.0);
+    CATCH_REQUIRE(output.size() == 2);
+    CATCH_REQUIRE(output[0] == 1.0);
+    CATCH_REQUIRE(output[1] == 2.0);
   }
   {
     RegisterOperators reg(
@@ -1034,17 +1034,17 @@
 
     auto& ops =
         getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
-    REQUIRE(ops.size() == 1);
+    CATCH_REQUIRE(ops.size() == 1);
 
     auto& op = ops.front();
-    REQUIRE(op->schema().name == "foo::lists2");
+    CATCH_REQUIRE(op->schema().name == "foo::lists2");
 
-    REQUIRE(op->schema().arguments.size() == 1);
-    REQUIRE(op->schema().arguments[0].name == "tensors");
-    REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofTensors()));
+    CATCH_REQUIRE(op->schema().arguments.size() == 1);
+    CATCH_REQUIRE(op->schema().arguments[0].name == "tensors");
+    CATCH_REQUIRE(op->schema().arguments[0].type->isSubtypeOf(ListType::ofTensors()));
 
-    REQUIRE(op->schema().returns.size() == 1);
-    REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofTensors()));
+    CATCH_REQUIRE(op->schema().returns.size() == 1);
+    CATCH_REQUIRE(op->schema().returns[0].type->isSubtypeOf(ListType::ofTensors()));
 
     Stack stack;
     push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))});
@@ -1052,31 +1052,31 @@
     std::vector<at::Tensor> output;
     pop(stack, output);
 
-    REQUIRE(output.size() == 1);
-    REQUIRE(output[0].allclose(autograd::make_variable(at::ones(5))));
+    CATCH_REQUIRE(output.size() == 1);
+    CATCH_REQUIRE(output[0].allclose(autograd::make_variable(at::ones(5))));
   }
   {
 #ifdef USE_CATCH
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         createOperator(
             "foo::bar_with_bad_schema(Tensor a) -> Tensor",
             [](double a, at::Tensor b) { return a + b; }),
         StartsWith("Inferred 2 argument(s) for operator implementation, "
                    "but the provided schema specified 1 argument(s)."));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         createOperator(
             "foo::bar_with_bad_schema(Tensor a) -> Tensor",
             [](double a) { return a; }),
         StartsWith("Inferred type for argument #0 was float, "
                    "but the provided schema specified type Dynamic "
                    "for the argument in that position"));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         createOperator(
             "foo::bar_with_bad_schema(float a) -> (float, float)",
             [](double a) { return a; }),
         StartsWith("Inferred 1 return value(s) for operator implementation, "
                    "but the provided schema specified 2 return value(s)."));
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         createOperator(
             "foo::bar_with_bad_schema(float a) -> Tensor",
             [](double a) { return a; }),
@@ -1109,7 +1109,7 @@
         break;
       }
     }
-    REQUIRE(contains_traced_op);
+    CATCH_REQUIRE(contains_traced_op);
   }
   {
 #ifdef USE_CATCH
@@ -1124,7 +1124,7 @@
     Stack stack;
     push(stack, std::vector<double>{1.0});
 
-    REQUIRE_THROWS_WITH(
+    CATCH_REQUIRE_THROWS_WITH(
         op.getOperation()(stack),
         StartsWith("Tracing float lists currently not supported!"));
 #endif
@@ -1156,42 +1156,42 @@
 
 #ifdef USE_CATCH
 
-TEST_CASE( "jit test CPU", "[cpu]" ) {
+CATCH_TEST_CASE( "jit test CPU", "[cpu]" ) {
 
   std::stringstream out;
-  SECTION( "control flow" )
+  CATCH_SECTION( "control flow" )
     testControlFlow();
-  SECTION( "blocks" )
+  CATCH_SECTION( "blocks" )
     testBlocks(out);
-  SECTION( "create autodiff subgraphs" )
+  CATCH_SECTION( "create autodiff subgraphs" )
     testCreateAutodiffSubgraphs(out);
-  SECTION( "differentiate" )
+  CATCH_SECTION( "differentiate" )
     testDifferentiate(out);
-  SECTION( "differentiate with requires grad" )
+  CATCH_SECTION( "differentiate with requires grad" )
     testDifferentiateWithRequiresGrad(out);
-  SECTION( "AD formulas" )
+  CATCH_SECTION( "AD formulas" )
     testADFormulas();
-  SECTION( "code template" )
+  CATCH_SECTION( "code template" )
     codeTemplateTest();
-  SECTION( "attributes" )
+  CATCH_SECTION( "attributes" )
     attributesTest();
-  SECTION( "interned strings" )
+  CATCH_SECTION( "interned strings" )
     internedStringsTests();
-  SECTION( "custom operators" )
+  CATCH_SECTION( "custom operators" )
     testCustomOperators();
 }
 
-TEST_CASE( "jit test CUDA", "[cuda]" ) {
+CATCH_TEST_CASE( "jit test CUDA", "[cuda]" ) {
 
-  SECTION( "graph executor" )
+  CATCH_SECTION( "graph executor" )
     testGraphExecutor();
-  SECTION( "fusion" )
+  CATCH_SECTION( "fusion" )
     fusionTests();
-  SECTION( "interp" )
+  CATCH_SECTION( "interp" )
     interpTest();
-  SECTION( "interp stage" )
+  CATCH_SECTION( "interp stage" )
     interpStageTest();
-  SECTION( "argument spec" )
+  CATCH_SECTION( "argument spec" )
     argumentSpecTest();
 }
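
Note on the mechanism behind the rename: the CATCH_-prefixed macros used throughout these hunks (CATCH_TEST_CASE, CATCH_SECTION, CATCH_REQUIRE, CATCH_REQUIRE_THROWS_WITH) are the prefixed forms that Catch itself provides when CATCH_CONFIG_PREFIX_ALL is defined before the framework header is included. A minimal sketch of a catch_utils.hpp wrapper along those lines is shown below; it assumes the stock Catch option and is not necessarily the exact contents of the header added by this PR.

    #pragma once

    // Ask Catch to emit only CATCH_-prefixed macros, so that short names such as
    // REQUIRE/CHECK/TEST_CASE cannot collide with identically named Caffe2 macros.
    #define CATCH_CONFIG_PREFIX_ALL
    #include "catch.hpp"

With that option the unprefixed REQUIRE/TEST_CASE/SECTION forms are simply never defined, so every test translation unit that switches from catch.hpp to the wrapper has to spell out the CATCH_-prefixed variants, which is the mechanical substitution performed in the hunks above.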