blob: 986f599da6d11b66b72ec63a98a5fc190ce4a97f [file] [log] [blame]
#define CATCH_CONFIG_MAIN
#include "catch.hpp"
#include "cuda.h"
#include "cuda_runtime.h"
#include "ATen/cuda/detail/TensorInfo.cuh"
/*
Tests related to tensor indexing and applying operations.
*/
#ifndef _WIN32
TEST_CASE("2D Contiguous", "Collapses a 2D contiguous tensor to 1D contiguous") {
int sizes[] = {4, 4};
int strides[] = {4, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 2, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 1);
REQUIRE(ti.sizes[0] == (4 * 4));
}
TEST_CASE("3D Contiguous", "Collapses a 3D contiguous tensor to a 1D contiguous") {
int sizes[] = {6, 3, 7};
int strides[] = {3 * 7, 7, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 1);
REQUIRE(ti.sizes[0] == (6 * 3 * 7));
}
TEST_CASE("3D Partial Collapse", "Collapses a 3D noncontiguous tensor to a 2D tensor") {
int sizes[] = {4, 3, 2};
int strides[] = {3 * 3, 3, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 2);
REQUIRE(ti.sizes[0] == (4 * 3));
REQUIRE(ti.sizes[1] == 2);
}
TEST_CASE("2D Strided Collapse", "Collapses a 2D skip contiguous tensor to a 1D skip contiguous tensor") {
int sizes[] = {3, 2};
int strides[] = {2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 2, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 1);
REQUIRE(ti.sizes[0] == (3 * 2));
REQUIRE(ti.strides[0] == 2);
}
TEST_CASE("4D Partial Strided Collapse", "Collapses a 4D tensor to a 2D tensor"){
int sizes[] = {3, 6, 5, 2};
int strides[] = {6 * 22, 22, 2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 2);
REQUIRE(ti.sizes[0] == (3 * 6));
REQUIRE(ti.strides[0] == 22);
REQUIRE(ti.sizes[1] == (5 * 2));
REQUIRE(ti.strides[1] == 2);
}
TEST_CASE("Collapsing Zeros and Ones", "Collapses a 5D tensor to a 1D tensor") {
int sizes[] = {1, 10, 1, 5, 4};
int strides[] = {4, 0, 16, 0, 1};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 5, sizes, strides};
ti.collapseDims();
REQUIRE(ti.dims == 2);
REQUIRE(ti.sizes[0] == (10 * 5));
REQUIRE(ti.strides[0] == 0);
REQUIRE(ti.sizes[1] == 4);
REQUIRE(ti.strides[1] == 1);
}
TEST_CASE("Collapsing to a Point Tensor", "Collapses a 3D tensor to a point tensor") {
int sizes[] = {1, 1, 1};
int strides[] = {17, 12, 3};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
REQUIRE(ti.collapseDims() == 0);
REQUIRE(ti.dims == 1);
REQUIRE(ti.sizes[0] == 1);
REQUIRE(ti.strides[0] == 1);
}
TEST_CASE("Excluding in a 4D Contiguous", "Collapses a 4D tensor to a 3D tensor") {
int sizes[] = {3, 6, 5, 2};
int strides[] = {6 * 22, 22, 2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
REQUIRE(ti.collapseDims(1) == 1);
REQUIRE(ti.dims == 3);
REQUIRE(ti.sizes[0] == 3);
REQUIRE(ti.strides[0] == (6 * 22));
REQUIRE(ti.sizes[1] == 6);
REQUIRE(ti.strides[1] == 22);
REQUIRE(ti.sizes[2] == (5 * 2));
REQUIRE(ti.strides[2] == 2);
}
TEST_CASE("Roving Exclusion", "Collapses a 4D tensor to a 3D tensor") {
int sizes[] = {3, 6, 5, 2};
int strides[] = {6 * 22, 22, 2 * 2, 2};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 4, sizes, strides};
REQUIRE(ti.collapseDims(2) == 1);
REQUIRE(ti.dims == 3);
REQUIRE(ti.sizes[0] == (3 * 6));
REQUIRE(ti.strides[0] == 22);
REQUIRE(ti.sizes[1] == 5);
REQUIRE(ti.strides[1] == 4);
REQUIRE(ti.sizes[2] == 2);
REQUIRE(ti.strides[2] == 2);
}
TEST_CASE("Invalid Exclusion", "Attempts to exclude a nonexisting dimension") {
int sizes[] = {1, 1, 1};
int strides[] = {17, 12, 3};
::at::cuda::detail::TensorInfo<void, int> ti{nullptr, 3, sizes, strides};
REQUIRE_THROWS(ti.collapseDims(5));
}
#endif