blob: ed72dbc4128c58fb8eb95cc6786595f55cc603fd [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @file
* Kernel Test utilities.
*/
#pragma once
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/runtime.h>
#include <executorch/test/utils/DeathTest.h>
#include <gtest/gtest.h>
#ifdef USE_ATEN_LIB
/**
* Ensure the kernel will fail when `_statement` is executed.
*
* This definition is the one selected when USE_ATEN_LIB is defined. It expands
* to gtest's EXPECT_ANY_THROW applied to `_statement`, so in this configuration
* "failure" means that executing the statement throws.
*
* @param _context Accepted so the macro has the same parameter list as the
* definition in the #else branch; this expansion does not reference it.
* @param _statement Statement to execute.
*/
#define ET_EXPECT_KERNEL_FAILURE(_context, _statement) \
EXPECT_ANY_THROW(_statement)
/**
* Variant of ET_EXPECT_KERNEL_FAILURE that additionally takes a matcher
* argument.
*
* With USE_ATEN_LIB defined the expansion is EXPECT_ANY_THROW(_statement);
* neither `_context` nor `_matcher` appears in it, so nothing is compared
* against `_matcher`.
*
* @param _context Not referenced by this expansion.
* @param _statement Statement to execute; the check passes if it throws.
* @param _matcher Not referenced by this expansion.
*/
#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _matcher) \
EXPECT_ANY_THROW(_statement)
#else
/**
 * Ensure the kernel reports a failure when `_statement` is executed.
 *
 * This is the definition used when USE_ATEN_LIB is not defined. It runs
 * `_statement`, calls `expect_failure()`, and then checks that
 * `(_context).failure_state()` is no longer `torch::executor::Error::Ok`. If
 * it is still Ok, an error is logged and a non-fatal gtest failure is
 * recorded with ADD_FAILURE().
 *
 * `expect_failure()` is called unqualified, so the macro must be expanded
 * where that name resolves (for example inside a test that uses the
 * OperatorTest fixture). The do/while(false) wrapper makes the expansion a
 * single statement, so it is safe in an unbraced if/else.
 *
 * @param _context Object exposing `failure_state()`; evaluated once, after
 *     `_statement` has run.
 * @param _statement Statement to execute.
 */
#define ET_EXPECT_KERNEL_FAILURE(_context, _statement)                 \
  do {                                                                 \
    _statement;                                                        \
    expect_failure();                                                  \
    if ((_context).failure_state() == torch::executor::Error::Ok) {    \
      ET_LOG(Error, "Expected kernel failure but found success.");     \
      ADD_FAILURE();                                                   \
    }                                                                  \
  } while (false)

/**
 * Variant of ET_EXPECT_KERNEL_FAILURE that additionally takes a message
 * argument.
 *
 * In this (non-ATen) configuration the message is not checked: `_msg` is
 * dropped and the macro forwards to ET_EXPECT_KERNEL_FAILURE, so both macros
 * share one implementation instead of two copies of the same body.
 *
 * @param _context Object exposing `failure_state()`.
 * @param _statement Statement to execute.
 * @param _msg Ignored by this expansion.
 */
#define ET_EXPECT_KERNEL_FAILURE_WITH_MSG(_context, _statement, _msg) \
  ET_EXPECT_KERNEL_FAILURE(_context, _statement)
#endif // USE_ATEN_LIB
/*
* Common test fixture for kernel / operator-level tests. Provides
* a runtime context object and verifies failure state post-execution.
*/
class OperatorTest : public ::testing::Test {
public:
OperatorTest() : expect_failure_(false) {}
void SetUp() override {
torch::executor::runtime_init();
}
void TearDown() override {
// Validate error state.
if (!expect_failure_) {
EXPECT_EQ(context_.failure_state(), torch::executor::Error::Ok);
} else {
EXPECT_NE(context_.failure_state(), torch::executor::Error::Ok);
}
}
void expect_failure() {
expect_failure_ = true;
}
protected:
exec_aten::RuntimeContext context_;
bool expect_failure_;
};