// blob: 71ed1a857b869626707de16507b96058d7ab01d7 [file] [log] [blame]
/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "PreparedModel.h"

#include "Burst.h"

#include <aidl/android/hardware/neuralnetworks/BnFencedExecutionCallback.h>
#include <aidl/android/hardware/neuralnetworks/BnPreparedModel.h>
#include <aidl/android/hardware/neuralnetworks/ExecutionResult.h>
#include <aidl/android/hardware/neuralnetworks/FencedExecutionResult.h>
#include <aidl/android/hardware/neuralnetworks/IBurst.h>
#include <aidl/android/hardware/neuralnetworks/Request.h>
#include <android-base/logging.h>
#include <android/binder_auto_utils.h>
#include <nnapi/IPreparedModel.h>
#include <nnapi/Result.h>
#include <nnapi/SharedMemory.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>
#include <nnapi/hal/aidl/Conversions.h>
#include <nnapi/hal/aidl/Utils.h>

#include <algorithm>
#include <chrono>
#include <memory>
#include <utility>
#include <vector>
namespace aidl::android::hardware::neuralnetworks::adapter {
namespace {
// Adapts an nn::ExecuteFencedInfoCallback to the AIDL IFencedExecutionCallback interface so a
// client can query the timing/status of a fenced execution after its sync fence signals.
class FencedExecutionCallback : public BnFencedExecutionCallback {
  public:
    // Takes ownership of `callback`; it is invoked on every getExecutionInfo call.
    // Marked explicit: a callback should never silently convert into a binder object.
    explicit FencedExecutionCallback(nn::ExecuteFencedInfoCallback callback)
        : kCallback(std::move(callback)) {}

    // Invokes the stored callback and translates its result to AIDL out-parameters. Always
    // returns an OK binder status; failures are reported through `errorStatus`, with both timing
    // out-parameters set to the "-1" no-timing sentinel.
    ndk::ScopedAStatus getExecutionInfo(Timing* timingLaunched, Timing* timingFenced,
                                        ErrorStatus* errorStatus) override {
        const auto result = kCallback();
        if (result.ok()) {
            const auto& [nnTimingLaunched, nnTimingFenced] = result.value();
            *timingLaunched = utils::convert(nnTimingLaunched).value();
            *timingFenced = utils::convert(nnTimingFenced).value();
            *errorStatus = ErrorStatus::NONE;
        } else {
            // -1 is the AIDL sentinel meaning "no timing information available".
            constexpr auto kNoTiming = Timing{.timeOnDeviceNs = -1, .timeInDriverNs = -1};
            const auto& [message, code] = result.error();
            LOG(ERROR) << "getExecutionInfo failed with " << code << ": " << message;
            const auto aidlStatus = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
            *timingLaunched = kNoTiming;
            *timingFenced = kNoTiming;
            *errorStatus = aidlStatus;
        }
        return ndk::ScopedAStatus::ok();
    }

  private:
    const nn::ExecuteFencedInfoCallback kCallback;
};
// Converts a caller-supplied AIDL object to its canonical NN representation. Any conversion
// failure is reported as INVALID_ARGUMENT, since a malformed input argument is the caller's
// fault regardless of the underlying conversion error code.
template <typename Type>
auto convertInput(const Type& object) -> decltype(nn::convert(std::declval<Type>())) {
    auto converted = nn::convert(object);
    if (!converted.ok()) {
        converted.error().code = nn::ErrorStatus::INVALID_ARGUMENT;
    }
    return converted;
}
// Converts a vector of AIDL sync-fence file descriptors into nn::SyncFence objects.
// Returns INVALID_ARGUMENT if any descriptor is null or does not hold a valid fd.
nn::GeneralResult<std::vector<nn::SyncFence>> convertSyncFences(
        const std::vector<ndk::ScopedFileDescriptor>& waitFor) {
    auto handles = NN_TRY(convertInput(waitFor));

    std::vector<nn::SyncFence> syncFences;
    syncFences.reserve(handles.size());
    for (auto& handle : handles) {
        // Reject null or closed handles before wrapping them in a SyncFence.
        if (handle == nullptr || !handle->ok()) {
            return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid sync fence";
        }
        syncFences.push_back(nn::SyncFence::create(std::move(handle)).value());
    }
    return syncFences;
}
// Wraps a raw nanosecond count in the canonical nn::Duration type.
nn::Duration makeDuration(int64_t durationNs) {
    const std::chrono::nanoseconds nanos(durationNs);
    return nn::Duration(nanos);
}
// Translates an AIDL duration (nanoseconds) to nn::OptionalDuration. A value of -1 means
// "no duration"; any other negative value is rejected as INVALID_ARGUMENT.
nn::GeneralResult<nn::OptionalDuration> makeOptionalDuration(int64_t durationNs) {
    if (durationNs == -1) {
        return nn::OptionalDuration{};
    }
    if (durationNs < 0) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid duration " << durationNs;
    }
    return nn::OptionalDuration{makeDuration(durationNs)};
}
// Translates an AIDL time point (nanoseconds since epoch) to nn::OptionalTimePoint. A value of
// -1 means "no deadline"; any other negative value is rejected as INVALID_ARGUMENT.
nn::GeneralResult<nn::OptionalTimePoint> makeOptionalTimePoint(int64_t timePointNs) {
    if (timePointNs == -1) {
        return nn::OptionalTimePoint{};
    }
    if (timePointNs < 0) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT) << "Invalid time point " << timePointNs;
    }
    return nn::OptionalTimePoint{nn::TimePoint(makeDuration(timePointNs))};
}
// Converts the AIDL inputs, runs a synchronous execution on `preparedModel`, and converts the
// outcome back to the AIDL ExecutionResult representation.
//
// An OUTPUT_INSUFFICIENT_SIZE failure is not an error at this layer: it is reported through
// outputSufficientSize=false together with the actual output shapes so the client can resize its
// buffers and retry. All other failures propagate as errors.
nn::ExecutionResult<ExecutionResult> executeSynchronously(const nn::IPreparedModel& preparedModel,
                                                          const Request& request,
                                                          bool measureTiming, int64_t deadlineNs,
                                                          int64_t loopTimeoutDurationNs) {
    const auto nnRequest = NN_TRY(convertInput(request));
    const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
    const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
    const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));

    const auto result =
            preparedModel.execute(nnRequest, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration);

    if (!result.ok() && result.error().code == nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        const auto& [message, code, outputShapes] = result.error();
        LOG(ERROR) << "executeSynchronously failed with " << code << ": " << message;
        // Designated initializers must follow Timing's declaration order
        // (timeOnDeviceNs before timeInDriverNs); out-of-order designators are ill-formed.
        return ExecutionResult{.outputSufficientSize = false,
                               .outputShapes = utils::convert(outputShapes).value(),
                               .timing = {.timeOnDeviceNs = -1, .timeInDriverNs = -1}};
    }

    const auto& [outputShapes, timing] = NN_TRY(result);
    return ExecutionResult{.outputSufficientSize = true,
                           .outputShapes = utils::convert(outputShapes).value(),
                           .timing = utils::convert(timing).value()};
}
// Converts the AIDL inputs, starts a fenced execution on `preparedModel`, and packages the
// resulting sync fence and info callback into an AIDL FencedExecutionResult.
// All input conversion failures surface as INVALID_ARGUMENT via convertInput/convertSyncFences.
nn::GeneralResult<FencedExecutionResult> executeFenced(
        const nn::IPreparedModel& preparedModel, const Request& request,
        const std::vector<ndk::ScopedFileDescriptor>& waitFor, bool measureTiming,
        int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs) {
    const auto nnRequest = NN_TRY(convertInput(request));
    const auto nnWaitFor = NN_TRY(convertSyncFences(waitFor));
    const auto nnMeasureTiming = measureTiming ? nn::MeasureTiming::YES : nn::MeasureTiming::NO;
    const auto nnDeadline = NN_TRY(makeOptionalTimePoint(deadlineNs));
    const auto nnLoopTimeoutDuration = NN_TRY(makeOptionalDuration(loopTimeoutDurationNs));
    const auto nnDuration = NN_TRY(makeOptionalDuration(durationNs));

    auto [syncFence, executeFencedInfoCallback] = NN_TRY(preparedModel.executeFenced(
            nnRequest, nnWaitFor, nnMeasureTiming, nnDeadline, nnLoopTimeoutDuration, nnDuration));

    // The returned fence may have no fd — NOTE(review): presumably when the driver completed
    // without needing a fence; confirm against nn::IPreparedModel::executeFenced. When one
    // exists, dup it so the ScopedFileDescriptor returned to the client owns its own copy.
    ndk::ScopedFileDescriptor fileDescriptor;
    if (syncFence.hasFd()) {
        auto uniqueFd = NN_TRY(nn::dupFd(syncFence.getFd()));
        fileDescriptor = ndk::ScopedFileDescriptor(uniqueFd.release());
    }

    // The callback lets the client poll for timing/status once the fence has signaled.
    return FencedExecutionResult{.callback = ndk::SharedRefBase::make<FencedExecutionCallback>(
                                         std::move(executeFencedInfoCallback)),
                                 .syncFence = std::move(fileDescriptor)};
}
} // namespace
// Wraps the canonical prepared model. Aborts (via CHECK) when `preparedModel` is null, since
// every method of this class dereferences kPreparedModel unconditionally.
PreparedModel::PreparedModel(nn::SharedPreparedModel preparedModel)
    : kPreparedModel(std::move(preparedModel)) {
    CHECK(kPreparedModel != nullptr);
}
// Runs a synchronous execution and reports the outcome through `executionResult`. On failure,
// the canonical error is translated to a service-specific binder status carrying the AIDL
// ErrorStatus code and the error message.
ndk::ScopedAStatus PreparedModel::executeSynchronously(const Request& request, bool measureTiming,
                                                       int64_t deadlineNs,
                                                       int64_t loopTimeoutDurationNs,
                                                       ExecutionResult* executionResult) {
    auto result = adapter::executeSynchronously(*kPreparedModel, request, measureTiming, deadlineNs,
                                                loopTimeoutDurationNs);
    if (result.has_value()) {
        *executionResult = std::move(result).value();
        return ndk::ScopedAStatus::ok();
    }
    const auto& error = result.error();
    const auto aidlCode = utils::convert(error.code).value_or(ErrorStatus::GENERAL_FAILURE);
    return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(static_cast<int32_t>(aidlCode),
                                                                   error.message.c_str());
}
// Starts a fenced execution and reports the outcome through `executionResult`. On failure, the
// canonical error is translated to a service-specific binder status carrying the AIDL
// ErrorStatus code and the error message.
ndk::ScopedAStatus PreparedModel::executeFenced(
        const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
        bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
        FencedExecutionResult* executionResult) {
    auto result = adapter::executeFenced(*kPreparedModel, request, waitFor, measureTiming,
                                         deadlineNs, loopTimeoutDurationNs, durationNs);
    if (result.has_value()) {
        *executionResult = std::move(result).value();
        return ndk::ScopedAStatus::ok();
    }
    const auto& [message, code] = result.error();
    const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
    return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(static_cast<int32_t>(aidlCode),
                                                                   message.c_str());
}
// Creates an execution burst on the underlying prepared model and wraps it in the AIDL Burst
// adapter. On failure, the canonical error is translated to a service-specific binder status.
ndk::ScopedAStatus PreparedModel::configureExecutionBurst(std::shared_ptr<IBurst>* burst) {
    auto result = kPreparedModel->configureExecutionBurst();
    if (result.has_value()) {
        *burst = ndk::SharedRefBase::make<Burst>(std::move(result).value());
        return ndk::ScopedAStatus::ok();
    }
    const auto& [message, code] = result.error();
    const auto aidlCode = utils::convert(code).value_or(ErrorStatus::GENERAL_FAILURE);
    return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(static_cast<int32_t>(aidlCode),
                                                                   message.c_str());
}
// Returns the wrapped canonical prepared model (shared ownership; never null, enforced by the
// constructor's CHECK).
nn::SharedPreparedModel PreparedModel::getUnderlyingPreparedModel() const {
    return kPreparedModel;
}
} // namespace aidl::android::hardware::neuralnetworks::adapter