// blob: 8bebb7d8fdd907a27f27d38ef35f27f04fee30c2 [file] [log] [blame]
#include "gtest/gtest.h"
#include "ATen/ATen.h"
#include "test_seed.h"
using namespace at;
// A tensor created with sizes {0} cannot be expanded to sizes {3}: the call
// is expected to throw.
void TestEmptyTensor(Type& T) {
  auto noElements = randn({0}, T);
  ASSERT_ANY_THROW(noElements.expand({3}));
}
// Out-of-place op with 2 args: adding a {3, 1} tensor to a {5} tensor must
// give the same result as adding both operands explicitly expanded to {3, 5}.
void TestOut2Basic(Type& T) {
  const std::vector<int64_t> broadcastShape = {3, 5};
  auto column = randn({3, 1}, T);
  auto row = randn({5}, T);

  auto implicitSum = column + row;
  auto explicitSum =
      column.expand(broadcastShape) + row.expand(broadcastShape);
  ASSERT_TRUE(implicitSum.equal(explicitSum));
}
// Out-of-place op with 2 args where the first operand is a scalar tensor.
// The scalar is built from ones({1}) and then passed through
// maybe_zero_dim(true) (presumably collapsing it to zero dimensions).
void TestOut2WithScalar(Type& T) {
  auto scalarOne = ones({1}, T);
  scalarOne.unsafeGetTensorImpl()->maybe_zero_dim(true);
  auto matrix = randn({3, 5}, T);

  auto implicitSum = scalarOne + matrix;
  auto explicitSum =
      scalarOne.expand(matrix.sizes()) + matrix.expand(matrix.sizes());
  ASSERT_TRUE(implicitSum.equal(explicitSum));
}
// Old fallback behavior yields an error: {3, 5} and {5, 3} hold the same
// number of elements but adding them is expected to throw.
void TestOut2OldFallback(Type& T) {
  auto lhs = randn({3, 5}, T);
  auto rhs = randn({5, 3}, T);
  ASSERT_ANY_THROW(lhs + rhs);
}
// Mismatched sizes: adding a {3, 5} tensor to a {7, 5} tensor is expected to
// throw.
void TestOut2MismatchedSizes(Type& T) {
  auto lhs = randn({3, 5}, T);
  auto rhs = randn({7, 5}, T);
  ASSERT_ANY_THROW(lhs + rhs);
}
// Out-of-place op with 3 args: summing {3, 1, 1}, {1, 2, 1} and {1, 1, 5}
// tensors must match the sum of all three explicitly expanded to {3, 2, 5}.
void TestOut3Basic(Type& T) {
  const std::vector<int64_t> broadcastShape = {3, 2, 5};
  auto x = randn({3, 1, 1}, T);
  auto y = randn({1, 2, 1}, T);
  auto z = randn({1, 1, 5}, T);

  auto implicitSum = x + y + z;
  auto explicitSum = x.expand(broadcastShape) + y.expand(broadcastShape) +
      z.expand(broadcastShape);
  ASSERT_TRUE(implicitSum.equal(explicitSum));
}
// Out-of-place op with 3 args where self is a scalar tensor: addcmul on the
// scalar must match addcmul with every operand explicitly expanded to
// {3, 2, 5}.
void TestOut3WithScalar(Type& T) {
  const std::vector<int64_t> broadcastShape = {3, 2, 5};
  auto scalarOne = ones({1}, T);
  scalarOne.unsafeGetTensorImpl()->maybe_zero_dim(true);
  auto factor1 = randn({3, 2, 1}, T);
  auto factor2 = randn({1, 2, 5}, T);

  auto implicitResult = scalarOne.addcmul(factor1, factor2);
  auto explicitResult = scalarOne.expand(broadcastShape)
                            .addcmul(
                                factor1.expand(broadcastShape),
                                factor2.expand(broadcastShape));
  ASSERT_TRUE(implicitResult.equal(explicitResult));
}
// Old fallback behavior yields an error: {3, 2, 5}, {2, 3, 5} and {5, 3, 2}
// all hold the same number of elements, yet addcmul is expected to throw.
void TestOut3OldFallback(Type& T) {
  auto self = randn({3, 2, 5}, T);
  auto factor1 = randn({2, 3, 5}, T);
  auto factor2 = randn({5, 3, 2}, T);
  ASSERT_ANY_THROW(self.addcmul(factor1, factor2));
}
// Mismatched sizes: addcmul over {3, 2, 5}, {2, 3, 5} and {5, 5, 5} tensors is
// expected to throw.
void TestOut3MismatchedSizes(Type& T) {
  auto self = randn({3, 2, 5}, T);
  auto factor1 = randn({2, 3, 5}, T);
  auto factor2 = randn({5, 5, 5}, T);
  ASSERT_ANY_THROW(self.addcmul(factor1, factor2));
}
// In-place function with 2 args: a.add_(b) with b of sizes {3, 1} must give
// the same result as add_ with b explicitly expanded to a's sizes {3, 5}.
void TestIn2Basic(Type& T) {
  auto a = randn({3, 5}, T);
  auto b = randn({3, 1}, T);
  // add_ mutates its receiver, so the reference result is computed on a clone
  // taken before either in-place call runs.
  auto aClone = a.clone();
  ASSERT_TRUE(a.add_(b).equal(aClone.add_(b.expand({3, 5}))));
}
// In-place function with 2 args where the second operand is a scalar tensor:
// a.add_(bScalar) must match add_ with bScalar explicitly expanded to a's
// sizes.
void TestIn2WithScalar(Type& T) {
  auto a = randn({3, 5}, T);
  auto bScalar = ones({1}, T);
  bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
  // add_ mutates its receiver, so the reference result is computed on a clone
  // taken before either in-place call runs.
  auto aClone = a.clone();
  ASSERT_TRUE(a.add_(bScalar).equal(aClone.add_(bScalar.expand(a.sizes()))));
}
// Error case: combining {1, 5} with {3, 1} in place would require expanding
// the in-place argument itself, so add_ is expected to throw.
void TestIn2ExpandError(Type& T) {
  auto target = randn({1, 5}, T);
  auto other = randn({3, 1}, T);
  ASSERT_ANY_THROW(target.add_(other));
}
// In-place function with 3 args: addcmul_ on a {3, 5, 2} tensor with {3, 1, 2}
// and {1, 5, 1} operands must match addcmul_ on a clone with both operands
// explicitly expanded to the target's sizes.
void TestIn3Basic(Type& T) {
  auto target = randn({3, 5, 2}, T);
  auto factor1 = randn({3, 1, 2}, T);
  auto factor2 = randn({1, 5, 1}, T);
  auto reference = target.clone();

  target.addcmul_(factor1, factor2);
  reference.addcmul_(
      factor1.expand(target.sizes()), factor2.expand(target.sizes()));
  ASSERT_TRUE(target.equal(reference));
}
// In-place function with 3 args where the first operand is a scalar tensor:
// a.addcmul_(bScalar, c) must match addcmul_ on a clone with bScalar and c
// explicitly expanded to a's sizes.
void TestIn3WithScalar(Type& T) {
  auto a = randn({3, 5, 2}, T);
  auto c = randn({1, 5, 1}, T);
  // addcmul_ mutates its receiver, so keep an untouched copy for the
  // explicitly-expanded reference computation.
  auto aClone = a.clone();
  auto bScalar = ones({1}, T);
  bScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
  ASSERT_TRUE(a.addcmul_(bScalar, c)
                  .equal(aClone.addcmul_(
                      bScalar.expand(a.sizes()), c.expand(a.sizes()))));
}
// Error case: addcmul_ on a {1, 3, 5} tensor with {4, 1, 1} and {1, 3, 1}
// operands would require expanding the in-place argument, so it is expected
// to throw.
void TestIn3ExpandError(Type& T) {
  auto target = randn({1, 3, 5}, T);
  auto factor1 = randn({4, 1, 1}, T);
  auto factor2 = randn({1, 3, 1}, T);
  ASSERT_ANY_THROW(target.addcmul_(factor1, factor2));
}
// Explicit dim specification: addmm on a {1} tensor with {5, 3} and {3, 7}
// matrices must match addmm on that tensor explicitly expanded to {5, 7}.
void TestExplicitDimBasic(Type& T) {
  auto bias = randn({1}, T);
  auto mat1 = randn({5, 3}, T);
  auto mat2 = randn({3, 7}, T);

  auto implicitResult = bias.addmm(mat1, mat2);
  auto explicitResult = bias.expand({5, 7}).addmm(mat1, mat2);
  ASSERT_TRUE(implicitResult.equal(explicitResult));
}
// Explicit dim specification with a scalar tensor as self: addmm on the scalar
// must match addmm on the scalar explicitly expanded to {5, 7}.
void TestExplicitDimWithScalar(Type& T) {
  auto b = randn({5, 3}, T);
  auto c = randn({3, 7}, T);
  Tensor aScalar = ones({1}, T);
  aScalar.unsafeGetTensorImpl()->maybe_zero_dim(true);
  ASSERT_TRUE(aScalar.addmm(b, c).equal(aScalar.expand({5, 7}).addmm(b, c)));
}
// Mismatched sizes: {5, 3} times {3, 7} is combined with a {3, 3} self tensor,
// and addmm is expected to throw.
void TestExplicitDimWithMismatchedSizes(Type& T) {
  auto mat1 = randn({5, 3}, T);
  auto mat2 = randn({3, 7}, T);
  auto badSelf = randn({3, 3}, T);
  ASSERT_ANY_THROW(badSelf.addmm(mat1, mat2));
}
// Runs every broadcast scenario against the CPU float type. manual_seed is
// called first (presumably so the randn inputs are reproducible), then the
// scenarios run in a fixed order.
TEST(BroadcastTest, Broadcast) {
  manual_seed(123, at::kCPU);
  Type& cpuFloat = CPU(kFloat);

  using Scenario = void (*)(Type&);
  const Scenario scenarios[] = {
      // empty tensor
      TestEmptyTensor,
      // out-of-place, 2 args
      TestOut2Basic,
      TestOut2WithScalar,
      TestOut2OldFallback,
      TestOut2MismatchedSizes,
      // out-of-place, 3 args
      TestOut3Basic,
      TestOut3WithScalar,
      TestOut3OldFallback,
      TestOut3MismatchedSizes,
      // in-place, 2 args
      TestIn2Basic,
      TestIn2WithScalar,
      TestIn2ExpandError,
      // in-place, 3 args
      TestIn3Basic,
      TestIn3WithScalar,
      TestIn3ExpandError,
      // explicit dim specification
      TestExplicitDimBasic,
      TestExplicitDimWithScalar,
      TestExplicitDimWithMismatchedSizes,
  };
  for (const Scenario runScenario : scenarios) {
    runScenario(cpuFloat);
  }
}