// blob: 9388ec778e6fdee1bba9b9ae823e48cb70db05e8 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define LOG_TAG "OperationResolver"
#include "OperationResolver.h"
#include "NeuralNetworks.h"
namespace android {
namespace nn {
// Forward declarations of the per-operation registration accessors.
//
// Each register_<OPERATION>() takes no arguments and returns a pointer to a
// const OperationRegistration. None of them is defined in the visible part of
// this file (presumably each lives next to the corresponding operation's
// implementation -- confirm against the build files). The OperationResolver
// constructor calls every function declared here exactly once, so a new
// entry must be added both to this list and to the constructor.
//
// TODO(b/119608412): Find a way to not reference every operation here.
const OperationRegistration* register_ABS();
const OperationRegistration* register_AXIS_ALIGNED_BBOX_TRANSFORM();
const OperationRegistration* register_CHANNEL_SHUFFLE();
const OperationRegistration* register_EMBEDDING_LOOKUP_SPARSE();
const OperationRegistration* register_EQUAL();
const OperationRegistration* register_EXP();
const OperationRegistration* register_GATHER();
const OperationRegistration* register_GREATER();
const OperationRegistration* register_GREATER_EQUAL();
const OperationRegistration* register_HEATMAP_MAX_KEYPOINT();
const OperationRegistration* register_INSTANCE_NORMALIZATION();
const OperationRegistration* register_LESS();
const OperationRegistration* register_LESS_EQUAL();
const OperationRegistration* register_LOG();
const OperationRegistration* register_LOGICAL_AND();
const OperationRegistration* register_LOGICAL_NOT();
const OperationRegistration* register_LOGICAL_OR();
const OperationRegistration* register_LOG_SOFTMAX();
const OperationRegistration* register_NEG();
const OperationRegistration* register_NOT_EQUAL();
const OperationRegistration* register_PRELU();
const OperationRegistration* register_REDUCE_ALL();
const OperationRegistration* register_REDUCE_ANY();
const OperationRegistration* register_REDUCE_MAX();
const OperationRegistration* register_REDUCE_MIN();
const OperationRegistration* register_REDUCE_PROD();
const OperationRegistration* register_REDUCE_SUM();
const OperationRegistration* register_ROI_ALIGN();
const OperationRegistration* register_ROI_POOLING();
const OperationRegistration* register_RSQRT();
const OperationRegistration* register_SELECT();
const OperationRegistration* register_SIN();
const OperationRegistration* register_SPARSE_TO_DENSE();
const OperationRegistration* register_SQRT();
// Builds the resolver by registering every operation known to this file.
//
// The accessors are collected in a table and invoked in table order, which
// is the same order in which they are declared above. Each returned
// registration is handed to registerOperation().
OperationResolver::OperationResolver() {
    // Pointer-to-function type shared by all register_<OPERATION>() accessors.
    using RegistrationGetter = decltype(&register_ABS);

    static constexpr RegistrationGetter kRegistrationGetters[] = {
            register_ABS,
            register_AXIS_ALIGNED_BBOX_TRANSFORM,
            register_CHANNEL_SHUFFLE,
            register_EMBEDDING_LOOKUP_SPARSE,
            register_EQUAL,
            register_EXP,
            register_GATHER,
            register_GREATER,
            register_GREATER_EQUAL,
            register_HEATMAP_MAX_KEYPOINT,
            register_INSTANCE_NORMALIZATION,
            register_LESS,
            register_LESS_EQUAL,
            register_LOG,
            register_LOGICAL_AND,
            register_LOGICAL_NOT,
            register_LOGICAL_OR,
            register_LOG_SOFTMAX,
            register_NEG,
            register_NOT_EQUAL,
            register_PRELU,
            register_REDUCE_ALL,
            register_REDUCE_ANY,
            register_REDUCE_MAX,
            register_REDUCE_MIN,
            register_REDUCE_PROD,
            register_REDUCE_SUM,
            register_ROI_ALIGN,
            register_ROI_POOLING,
            register_RSQRT,
            register_SELECT,
            register_SIN,
            register_SPARSE_TO_DENSE,
            register_SQRT,
    };

    for (const RegistrationGetter getRegistration : kRegistrationGetters) {
        registerOperation(getRegistration());
    }
}
// Looks up the registration stored for `operationType`.
//
// The operation type is converted to its int32_t value and used as an index
// into mRegistrations. Returns nullptr when that value lies outside
// [0, kNumberOfOperationTypes); otherwise returns whatever the table holds in
// that slot.
const OperationRegistration* OperationResolver::findOperation(OperationType operationType) const {
    const auto slot = static_cast<int32_t>(operationType);
    const bool slotInRange = !(slot < 0) && (slot < kNumberOfOperationTypes);
    return slotInRange ? mRegistrations[slot] : nullptr;
}
// Records `operationRegistration` in the lookup table.
//
// The table slot is the registration's `type` field converted to int32_t.
// Preconditions, each enforced with a CHECK macro in this order:
//   1. the registration pointer is not null;
//   2. the slot index is >= 0;
//   3. the slot index is < kNumberOfOperationTypes;
//   4. the slot is still empty, i.e. no registration was stored for this
//      type before.
// The pointer itself is stored; nothing here copies the registration, so the
// pointee has to stay valid for as long as the table is used.
void OperationResolver::registerOperation(const OperationRegistration* operationRegistration) {
CHECK(operationRegistration != nullptr);
// Index of the table slot that belongs to this operation type.
auto index = static_cast<int32_t>(operationRegistration->type);
CHECK_LE(0, index);
CHECK_LT(index, kNumberOfOperationTypes);
// Registering the same operation type twice is treated as an error.
CHECK(mRegistrations[index] == nullptr);
mRegistrations[index] = operationRegistration;
}
} // namespace nn
} // namespace android