blob: e6792b2a6a2737c52e48c155742fcdb22cb77fda [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define LOG_TAG "OperationResolver"
#include "OperationResolver.h"
#include "NeuralNetworks.h"
namespace android {
namespace nn {
// Forward declarations of the per-operation registration factories.
//
// Each register_<OP>() function takes no arguments and returns a pointer to the
// OperationRegistration describing that operation. The definitions are not in
// this file (presumably each lives next to the operation's implementation --
// confirm against the build files). BuiltinOperationResolver's constructor
// below calls every one of them exactly once.
//
// When adding a new operation, add its declaration here AND a matching
// registerOperation(...) call in the constructor.
//
// TODO(b/119608412): Find a way to not reference every operation here.
const OperationRegistration* register_ABS();
const OperationRegistration* register_ADD();
const OperationRegistration* register_AVERAGE_POOL_2D();
const OperationRegistration* register_AXIS_ALIGNED_BBOX_TRANSFORM();
const OperationRegistration* register_BIDIRECTIONAL_SEQUENCE_RNN();
const OperationRegistration* register_BOX_WITH_NMS_LIMIT();
const OperationRegistration* register_CHANNEL_SHUFFLE();
const OperationRegistration* register_CONCATENATION();
const OperationRegistration* register_CONV_2D();
const OperationRegistration* register_DEPTHWISE_CONV_2D();
const OperationRegistration* register_DEQUANTIZE();
const OperationRegistration* register_DETECTION_POSTPROCESSING();
const OperationRegistration* register_DIV();
const OperationRegistration* register_ELU();
const OperationRegistration* register_EQUAL();
const OperationRegistration* register_EXP();
const OperationRegistration* register_FILL();
const OperationRegistration* register_FLOOR();
const OperationRegistration* register_FULLY_CONNECTED();
const OperationRegistration* register_GATHER();
const OperationRegistration* register_GENERATE_PROPOSALS();
const OperationRegistration* register_GREATER();
const OperationRegistration* register_GREATER_EQUAL();
const OperationRegistration* register_HARD_SWISH();
const OperationRegistration* register_HEATMAP_MAX_KEYPOINT();
const OperationRegistration* register_INSTANCE_NORMALIZATION();
const OperationRegistration* register_L2_NORMALIZATION();
const OperationRegistration* register_L2_POOL_2D();
const OperationRegistration* register_LESS();
const OperationRegistration* register_LESS_EQUAL();
const OperationRegistration* register_LOCAL_RESPONSE_NORMALIZATION();
const OperationRegistration* register_LOG();
const OperationRegistration* register_LOGICAL_AND();
const OperationRegistration* register_LOGICAL_NOT();
const OperationRegistration* register_LOGICAL_OR();
const OperationRegistration* register_LOGISTIC();
const OperationRegistration* register_LOG_SOFTMAX();
const OperationRegistration* register_MAX_POOL_2D();
const OperationRegistration* register_MUL();
const OperationRegistration* register_NEG();
const OperationRegistration* register_NOT_EQUAL();
const OperationRegistration* register_PRELU();
const OperationRegistration* register_QUANTIZE();
const OperationRegistration* register_QUANTIZED_LSTM();
const OperationRegistration* register_RANK();
const OperationRegistration* register_REDUCE_ALL();
const OperationRegistration* register_REDUCE_ANY();
const OperationRegistration* register_REDUCE_MAX();
const OperationRegistration* register_REDUCE_MIN();
const OperationRegistration* register_REDUCE_PROD();
const OperationRegistration* register_REDUCE_SUM();
const OperationRegistration* register_RELU();
const OperationRegistration* register_RELU1();
const OperationRegistration* register_RELU6();
const OperationRegistration* register_RESIZE_BILINEAR();
const OperationRegistration* register_RESIZE_NEAREST_NEIGHBOR();
const OperationRegistration* register_ROI_ALIGN();
const OperationRegistration* register_ROI_POOLING();
const OperationRegistration* register_RSQRT();
const OperationRegistration* register_SELECT();
const OperationRegistration* register_SIN();
const OperationRegistration* register_SLICE();
const OperationRegistration* register_SOFTMAX();
const OperationRegistration* register_SQRT();
const OperationRegistration* register_SQUEEZE();
const OperationRegistration* register_STRIDED_SLICE();
const OperationRegistration* register_SUB();
const OperationRegistration* register_TANH();
const OperationRegistration* register_TOPK_V2();
const OperationRegistration* register_TRANSPOSE();
const OperationRegistration* register_TRANSPOSE_CONV_2D();
const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_LSTM();
const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_RNN();
// Populates the resolver with every built-in operation.
//
// Each register_<OP>() factory is invoked once, in the order listed in the
// table below, and its result is handed to registerOperation(), which stores
// it in mRegistrations. To add an operation, append its factory to the table.
BuiltinOperationResolver::BuiltinOperationResolver() {
    // Signature shared by all register_<OP>() factories.
    using RegistrationFactory = const OperationRegistration* (*)();

    // Table of factories; the order here is the order of registration.
    static constexpr RegistrationFactory kBuiltinFactories[] = {
            register_ABS,
            register_ADD,
            register_AVERAGE_POOL_2D,
            register_AXIS_ALIGNED_BBOX_TRANSFORM,
            register_BIDIRECTIONAL_SEQUENCE_RNN,
            register_BOX_WITH_NMS_LIMIT,
            register_CHANNEL_SHUFFLE,
            register_CONCATENATION,
            register_CONV_2D,
            register_DEPTHWISE_CONV_2D,
            register_DEQUANTIZE,
            register_DETECTION_POSTPROCESSING,
            register_DIV,
            register_ELU,
            register_EQUAL,
            register_EXP,
            register_FILL,
            register_FLOOR,
            register_FULLY_CONNECTED,
            register_GATHER,
            register_GENERATE_PROPOSALS,
            register_GREATER,
            register_GREATER_EQUAL,
            register_HARD_SWISH,
            register_HEATMAP_MAX_KEYPOINT,
            register_INSTANCE_NORMALIZATION,
            register_L2_NORMALIZATION,
            register_L2_POOL_2D,
            register_LESS,
            register_LESS_EQUAL,
            register_LOCAL_RESPONSE_NORMALIZATION,
            register_LOG,
            register_LOGICAL_AND,
            register_LOGICAL_NOT,
            register_LOGICAL_OR,
            register_LOGISTIC,
            register_LOG_SOFTMAX,
            register_MAX_POOL_2D,
            register_MUL,
            register_NEG,
            register_NOT_EQUAL,
            register_PRELU,
            register_QUANTIZE,
            register_QUANTIZED_LSTM,
            register_RANK,
            register_REDUCE_ALL,
            register_REDUCE_ANY,
            register_REDUCE_MAX,
            register_REDUCE_MIN,
            register_REDUCE_PROD,
            register_REDUCE_SUM,
            register_RELU,
            register_RELU1,
            register_RELU6,
            register_RESIZE_BILINEAR,
            register_RESIZE_NEAREST_NEIGHBOR,
            register_ROI_ALIGN,
            register_ROI_POOLING,
            register_RSQRT,
            register_SELECT,
            register_SIN,
            register_SLICE,
            register_SOFTMAX,
            register_SQRT,
            register_SQUEEZE,
            register_STRIDED_SLICE,
            register_SUB,
            register_TANH,
            register_TOPK_V2,
            register_TRANSPOSE,
            register_TRANSPOSE_CONV_2D,
            register_UNIDIRECTIONAL_SEQUENCE_LSTM,
            register_UNIDIRECTIONAL_SEQUENCE_RNN,
    };

    for (const RegistrationFactory makeRegistration : kBuiltinFactories) {
        registerOperation(makeRegistration());
    }
}
// Looks up the registration stored for |operationType|.
//
// The operation type is converted to its numeric value and used as an index
// into mRegistrations. Returns nullptr when that value falls outside
// [0, kNumberOfOperationTypes); otherwise returns whatever mRegistrations holds
// at that index (presumably nullptr for a type that was never registered --
// this depends on how mRegistrations is initialized, which is not shown here).
const OperationRegistration* BuiltinOperationResolver::findOperation(
        OperationType operationType) const {
    const auto slot = static_cast<int32_t>(operationType);
    const bool slotInRange = slot >= 0 && slot < kNumberOfOperationTypes;
    return slotInRange ? mRegistrations[slot] : nullptr;
}
// Records |operationRegistration| in mRegistrations, at the index given by the
// numeric value of its `type` field.
//
// All preconditions are enforced with CHECK macros (presumably the fatal
// assertion macros from Android's logging library -- confirm):
//   - |operationRegistration| must not be null;
//   - the numeric value of its type must lie in [0, kNumberOfOperationTypes);
//   - nothing may already be stored at that index, i.e. each operation type
//     can be registered at most once.
void BuiltinOperationResolver::registerOperation(
const OperationRegistration* operationRegistration) {
CHECK(operationRegistration != nullptr);
// Same type-to-index mapping that findOperation() uses for lookups.
auto index = static_cast<int32_t>(operationRegistration->type);
CHECK_LE(0, index);
CHECK_LT(index, kNumberOfOperationTypes);
// Reject duplicate registrations for the same operation type.
CHECK(mRegistrations[index] == nullptr);
mRegistrations[index] = operationRegistration;
}
} // namespace nn
} // namespace android