// blob: 8d782d3c2a99a6ce9a237ba74fef0277e6ce8653
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @file
* Kernel test utilities: failure-checking macros and a common gtest fixture
* for kernel/operator-level tests.
*/
#pragma once
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>
#include <gtest/gtest.h>
#ifdef USE_ATEN_LIB
/**
 * Ensure the kernel will fail when `_statement` is executed.
 *
 * In this (ATen) variant, failure means "`_statement` throws": the check is
 * `EXPECT_ANY_THROW(_statement)`. `_context` is accepted only so call sites are
 * source-compatible with the non-ATen variant; it is never evaluated here.
 *
 * @param _context Kernel runtime context (unused in ATen mode).
 * @param _statement Statement to execute.
 */
#define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \
  EXPECT_ANY_THROW(_statement)
/**
 * Same check as `ET_EXPECT_KERNEL_FAILURE`: expect `_statement` to throw.
 *
 * NOTE: `_msg` is currently NOT checked against the thrown exception; it is
 * accepted only for source compatibility with the non-ATen variant. The
 * parameter was previously named `_matcher`; macro parameter names are
 * positional, so callers are unaffected by the rename.
 *
 * @param _context Kernel runtime context (unused in ATen mode).
 * @param _statement Statement to execute.
 * @param _msg Expected failure message (ignored).
 */
#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _msg) \
  ET_EXPECT_KERNEL_FAILURE(_context, _statement)
/**
 * Check how the out-variant `op` handles channels-last tensors (ATen mode).
 *
 * Uses `tf` to build channels-last copies of `input_contiguous` and
 * `expected_contiguous`. It also builds a zeroed channels-last output, then
 * runs `op(input, out)`. The output must equal the channels-last expected
 * tensor when `channels_last_support` is true, and must differ from it
 * otherwise. The tensor returned by `op` must equal the output tensor.
 *
 * The expansion is wrapped in its own `{ }` scope. The helper tensors it
 * declares therefore neither leak into nor collide with the caller's locals,
 * and the macro can be used more than once in the same test. A trailing `;`
 * at the call site is still accepted (it becomes an empty statement).
 *
 * @param tf Tensor factory providing `zeros_like` and `channels_last_like`.
 * @param op Callable invoked as `op(input, out)`, returning the output tensor.
 * @param input_contiguous Contiguous input tensor.
 * @param expected_contiguous Contiguous expected output tensor.
 * @param channels_last_support Whether `op` is expected to produce the
 *     expected values for channels-last tensors.
 */
#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(                                   \
    tf, op, input_contiguous, expected_contiguous, channels_last_support)     \
  {                                                                           \
    Tensor input_channels_last = (tf).channels_last_like(input_contiguous);   \
    Tensor expected_channel_last =                                            \
        (tf).channels_last_like(expected_contiguous);                         \
    Tensor output_contiguous = (tf).zeros_like(expected_contiguous);          \
    Tensor output_channels_last =                                             \
        (tf).channels_last_like(output_contiguous);                           \
    Tensor ret = op(input_channels_last, output_channels_last);               \
    if (channels_last_support) {                                              \
      EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last);          \
    } else {                                                                  \
      EXPECT_TENSOR_NE(output_channels_last, expected_channel_last);          \
    }                                                                         \
    EXPECT_TENSOR_EQ(output_channels_last, ret);                              \
  }
#else
/**
 * Ensure the kernel reports a failure through `_context` when `_statement`
 * is executed.
 *
 * Runs `_statement`, then calls the fixture's `expect_failure()` (defined on
 * `OperatorTest`, so this macro must be used inside an `OperatorTest`-derived
 * test). If `_context.failure_state()` is still `Error::Ok` afterwards, it logs
 * an error and records a non-fatal gtest failure via `ADD_FAILURE()`.
 *
 * NOTE(review): the failure state is only read after `_statement` and is
 * never reset beforehand. A failure recorded by an earlier call in the same
 * test would therefore presumably also satisfy this check; confirm how the
 * context stores its failure state.
 *
 * @param _context Kernel runtime context passed to the kernel.
 * @param _statement Statement that invokes the kernel.
 */
#define ET_EXPECT_KERNEL_FAILURE(_context, _statement)                \
  do {                                                                \
    _statement;                                                       \
    expect_failure();                                                 \
    const bool et_kernel_reported_failure_ =                          \
        (_context).failure_state() != torch::executor::Error::Ok;     \
    if (!et_kernel_reported_failure_) {                               \
      ET_LOG(Error, "Expected kernel failure but found success.");    \
      ADD_FAILURE();                                                  \
    }                                                                 \
  } while (false)
/**
 * Same check as `ET_EXPECT_KERNEL_FAILURE`: run `_statement`, mark the test as
 * expecting a failure, and record a gtest failure if `_context` still reports
 * `Error::Ok`.
 *
 * NOTE: `_msg` is currently NOT checked against anything; it is accepted only
 * for source compatibility with call sites that supply an expected message.
 * Delegating (instead of duplicating the body) keeps both macros in sync.
 *
 * @param _context Kernel runtime context passed to the kernel.
 * @param _statement Statement that invokes the kernel.
 * @param _msg Expected failure message (ignored).
 */
#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _msg) \
  ET_EXPECT_KERNEL_FAILURE(_context, _statement)
/**
 * Check how the out-variant `op` handles channels-last tensors.
 *
 * Uses `tf` to build channels-last copies of `input_contiguous` and
 * `expected_contiguous`. It also builds a zeroed output in both contiguous and
 * channels-last layouts, then:
 *   - runs `op(input_cl, output_cl)` and expects the output to equal the
 *     channels-last expected tensor when `channels_last_support` is true (and
 *     to differ from it otherwise), and the returned tensor to equal the
 *     output;
 *   - expects `op` to fail when exactly one of input/output is channels-last.
 *
 * The failure checks use `ET_EXPECT_KERNEL_FAILURE` with the fixture member
 * `context_`, so this macro must be used inside an `OperatorTest`-derived test.
 *
 * The expansion is wrapped in its own `{ }` scope. The helper tensors it
 * declares therefore neither leak into nor collide with the caller's locals,
 * and the macro can be used more than once in the same test. A trailing `;`
 * at the call site is still accepted (it becomes an empty statement).
 *
 * @param tf Tensor factory providing `zeros_like` and `channels_last_like`.
 * @param op Callable invoked as `op(input, out)`, returning the output tensor.
 * @param input_contiguous Contiguous input tensor.
 * @param expected_contiguous Contiguous expected output tensor.
 * @param channels_last_support Whether `op` is expected to produce the
 *     expected values for channels-last tensors.
 */
#define ET_TEST_OP_SUPPORTS_MEMORY_FORMATS(                                   \
    tf, op, input_contiguous, expected_contiguous, channels_last_support)     \
  {                                                                           \
    Tensor input_channels_last = (tf).channels_last_like(input_contiguous);   \
    Tensor expected_channel_last =                                            \
        (tf).channels_last_like(expected_contiguous);                         \
    Tensor output_contiguous = (tf).zeros_like(expected_contiguous);          \
    Tensor output_channels_last =                                             \
        (tf).channels_last_like(output_contiguous);                           \
    Tensor ret = op(input_channels_last, output_channels_last);               \
    if (channels_last_support) {                                              \
      EXPECT_TENSOR_EQ(output_channels_last, expected_channel_last);          \
    } else {                                                                  \
      EXPECT_TENSOR_NE(output_channels_last, expected_channel_last);          \
    }                                                                         \
    EXPECT_TENSOR_EQ(output_channels_last, ret);                              \
    ET_EXPECT_KERNEL_FAILURE(                                                 \
        context_, op(input_channels_last, output_contiguous));                \
    ET_EXPECT_KERNEL_FAILURE(                                                 \
        context_, op(input_contiguous, output_channels_last));                \
  }
#endif // USE_ATEN_LIB
/*
 * Shared gtest fixture for kernel/operator-level tests.
 *
 * Holds the runtime context handed to the kernels under test. After each test
 * (TearDown), it checks that context's failure state: the state must be
 * `Error::Ok` unless the test called `expect_failure()`, in which case it must
 * be anything other than `Error::Ok`.
 */
class OperatorTest : public ::testing::Test {
 public:
  OperatorTest() = default;

  // Runs before every test body.
  void SetUp() override {
    torch::executor::runtime_init();
  }

  // Runs after every test body; validates the recorded failure state.
  void TearDown() override {
    if (expect_failure_) {
      EXPECT_NE(context_.failure_state(), torch::executor::Error::Ok);
    } else {
      EXPECT_EQ(context_.failure_state(), torch::executor::Error::Ok);
    }
  }

  // Declares that this test expects the context to end in a failure state.
  void expect_failure() {
    expect_failure_ = true;
  }

 protected:
  exec_aten::RuntimeContext context_;
  // Set by expect_failure(); selects which check TearDown() performs.
  bool expect_failure_ = false;
};