blob: 8b51ca5ed08c6292e373aa970ff5d39f5b230dbd [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/event_tracer_hooks.h>
#include <executorch/runtime/platform/compiler.h>
namespace torch {
namespace executor {

/**
 * Runtime state and functionality for kernel implementations.
 *
 * NOTE: Will not be passed to operators if running in ATen mode as those
 * operators do not expect to receive a KernelRuntimeContext argument.
 */
class KernelRuntimeContext {
 public:
  /**
   * Constructs a new kernel runtime context along with an optional event
   * tracer.
   *
   * @param[in] event_tracer Optional tracer used by the codegen layer for
   *     profiling/debugging; may be nullptr (the default). The context only
   *     stores the pointer and never deletes it, so it does not take
   *     ownership. Presumably the caller must keep the tracer alive for as
   *     long as this context is used — TODO confirm.
   *
   * NOTE(review): intentionally not `explicit`; making a defaulted
   * constructor explicit would break copy-list-initialization
   * (`KernelRuntimeContext ctx = {};`) and implicit EventTracer* conversions
   * at existing call sites. Revisit once call sites are audited.
   */
  KernelRuntimeContext(EventTracer* event_tracer = nullptr)
      : event_tracer_(event_tracer) {}

  /**
   * Tells the runtime that the kernel call has failed. Prefer this over
   * ET_CHECK_*(), which fatally panics the process/system.
   *
   * If this is not called, the runtime will treat the kernel call as
   * successful.
   *
   * This unusual error-propagation path is required because kernel signatures
   * do not have a natural way to return errors directly. They are generally
   * compatible with core PyTorch ATen kernel signatures, which use exceptions
   * to report errors. But, ExecuTorch does not use exceptions.
   *
   * @param[in] error The error to record. Each call overwrites any previously
   *     recorded error; only the most recent value is reported by
   *     failure_state().
   */
  void fail(Error error) {
    failure_state_ = error;
  }

  /**
   * Returns the current failure state: the error most recently passed to
   * fail(), or Error::Ok if fail() has never been called on this context.
   */
  __ET_NODISCARD Error failure_state() const {
    return failure_state_;
  }

  /**
   * INTERNAL ONLY
   *
   * Returns a pointer to an instance of EventTracer to do profiling/debugging
   * logging inside the codegen layer. This is only for internal usage inside
   * the codegen layer and users should not be accessing this.
   *
   * Returns the tracer this context was constructed with, which may be
   * nullptr. The method is const because it does not modify the context, so
   * it can be queried through a `const KernelRuntimeContext&`. The returned
   * tracer stays mutable because the context does not own it.
   */
  EventTracer* internal_event_tracer() const {
    return event_tracer_;
  }

  // TODO(T147221312): Add a way to allocate temporary memory.

  // TODO(T147221312): Add a way to resize a tensor.

 private:
  // Non-owning; nullptr when no tracer was supplied at construction.
  EventTracer* event_tracer_ = nullptr;
  // Last error reported through fail(); Error::Ok means "no failure".
  Error failure_state_ = Error::Ok;
};

} // namespace executor
} // namespace torch
// TODO(T147221312): Remove these aliases once all code uses
// KernelRuntimeContext.
namespace exec_aten {
using RuntimeContext = torch::executor::KernelRuntimeContext;
} // namespace exec_aten
namespace torch {
namespace executor {
using RuntimeContext = torch::executor::KernelRuntimeContext;
} // namespace executor
} // namespace torch