blob: 8553125c7d990441087deb10d6fd837d03d29129 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <gtest/gtest.h>
using namespace ::testing;
using exec_aten::Scalar;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::testing::TensorFactory;
class OpPowTest : public OperatorTest {
protected:
Tensor&
op_pow_scalar_out(const Scalar& self, const Tensor& exponent, Tensor& out) {
return torch::executor::aten::pow_outf(context_, self, exponent, out);
}
Tensor& op_pow_tensor_scalar_out(
const Tensor& self,
const Scalar& exponent,
Tensor& out) {
return torch::executor::aten::pow_outf(context_, self, exponent, out);
}
Tensor& op_pow_tensor_tensor_out(
const Tensor& self,
const Tensor& exponent,
Tensor& out) {
return torch::executor::aten::pow_outf(context_, self, exponent, out);
}
};
TEST_F(OpPowTest, TensorTensorSanityCheck) {
TensorFactory<ScalarType::Byte> tf;
Tensor self = tf.make({2, 2}, {2, 2, 2, 2});
Tensor exp = tf.make({2, 1}, {4, 4});
Tensor out = tf.make({2, 2}, {16, 16, 16, 16});
Tensor ret = op_pow_tensor_tensor_out(self, exp, out);
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {16, 16, 16, 16}));
}
TEST_F(OpPowTest, TensorTensorSanityCheck2) {
TensorFactory<ScalarType::Float> tf1;
TensorFactory<ScalarType::Int> tf2;
TensorFactory<ScalarType::Double> tf3;
Tensor self = tf1.make({2, 2}, {2, 3, 4, 5});
Tensor exp = tf2.make({2, 1}, {2, 2});
Tensor out = tf3.zeros({2, 2});
Tensor ret = op_pow_tensor_tensor_out(self, exp, out);
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, tf3.make({2, 2}, {4, 9, 16, 25}));
}
TEST_F(OpPowTest, TensorTensorHalfSupport) {
TensorFactory<ScalarType::Half> tf;
Tensor self = tf.make({2, 2}, {2.0, 3.0, 4.0, 5.0});
Tensor exp = tf.make({2, 1}, {3.0, 2.0});
Tensor out = tf.zeros({2, 2});
Tensor ret = op_pow_tensor_tensor_out(self, exp, out);
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {8.0, 27.0, 16.0, 25.0}));
}
TEST_F(OpPowTest, TensorScalarSanityCheck) {
TensorFactory<ScalarType::Byte> tf;
Tensor self = tf.make({2, 2}, {2, 2, 2, 2});
Tensor out = tf.make({2, 2}, {16, 16, 16, 16});
Tensor ret = op_pow_tensor_scalar_out(self, 4, out);
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {16, 16, 16, 16}));
}
TEST_F(OpPowTest, TensorScalarHalfSupport) {
TensorFactory<ScalarType::Half> tf;
Tensor self = tf.make({2, 2}, {2.0, 2.0, 2.0, 2.0});
Tensor out = tf.zeros({2, 2});
Tensor ret = op_pow_tensor_scalar_out(self, 4, out);
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {16.0, 16.0, 16.0, 16.0}));
}
TEST_F(OpPowTest, ScalarSanityCheck) {
TensorFactory<ScalarType::Byte> tf;
Tensor exp = tf.make({2, 2}, {2, 2, 2, 2});
Tensor out = tf.make({2, 2}, {16, 16, 16, 16});
Tensor ret = op_pow_scalar_out(4, exp, out);
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {16, 16, 16, 16}));
}
TEST_F(OpPowTest, ScalarHalfSupport) {
TensorFactory<ScalarType::Half> tf;
Tensor exp = tf.make({2, 2}, {2, 2, 2, 2});
Tensor out = tf.zeros({2, 2});
Tensor ret = op_pow_scalar_out(4, exp, out);
EXPECT_TENSOR_EQ(out, ret);
EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {16, 16, 16, 16}));
}