blob: 4d7a4c97cb2148176220bbf2a9d67b58521e2f3b [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "RNN.h"
#include "CpuExecutor.h"
#include "HalInterfaces.h"
#include "Tracing.h"
namespace android {
namespace nn {
RNN::RNN(const Operation& operation,
std::vector<RunTimeOperandInfo>& operands) {
NNTRACE_TRANS("RNN::RNN");
input_ = GetInput(operation, operands, kInputTensor);
weights_ = GetInput(operation, operands, kWeightsTensor);
recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
bias_ = GetInput(operation, operands, kBiasTensor);
activation_ = static_cast<ActivationFn>(
getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
output_ = GetOutput(operation, operands, kOutputTensor);
}
bool RNN::Prepare(const Operation &operation,
std::vector<RunTimeOperandInfo> &operands,
Shape *hiddenStateShape,
Shape *outputShape) {
NNTRACE_TRANS("RNN::Prepare");
// Check we have all the inputs and outputs we need.
const int num_inputs = NumInputsWithValues(operation, operands);
NN_CHECK(num_inputs == 5 || num_inputs == 6);
NN_CHECK_EQ(NumOutputs(operation), 2);
const RunTimeOperandInfo *input =
GetInput(operation, operands, kInputTensor);
const RunTimeOperandInfo *input_weights =
GetInput(operation, operands, kWeightsTensor);
const RunTimeOperandInfo *recurrent_weights =
GetInput(operation, operands, kRecurrentWeightsTensor);
const RunTimeOperandInfo *bias =
GetInput(operation, operands, kBiasTensor);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
const uint32_t batch_size = SizeOfDimension(input, 0);
const uint32_t num_units = SizeOfDimension(input_weights, 0);
NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
const Shape &inputShape = input->shape();
// Resize state.
hiddenStateShape->type = inputShape.type;
hiddenStateShape->dimensions = { batch_size, num_units };
// Resize output.
outputShape->type = inputShape.type;
outputShape->dimensions = { batch_size, num_units };
return true;
}
bool RNN::Eval() {
NNTRACE_COMP("RNN::Eval");
const float* bias_ptr = reinterpret_cast<float*>(bias_->buffer);
const uint32_t batch_size = input_->shape().dimensions[0];
const uint32_t num_units = weights_->shape().dimensions[0];
const uint32_t input_size = input_->shape().dimensions[1];
const uint32_t input_weights_stride = weights_->shape().dimensions[1];
const uint32_t recurrent_weights_stride =
recurrent_weights_->shape().dimensions[1];
// For each batch
for (uint32_t b = 0; b < batch_size; b++) {
// Initialize the pointer to input, output and bias.
const float* input_ptr_batch =
reinterpret_cast<float*>(input_->buffer) + b * input_size;
const float* hidden_state_in_ptr_batch =
reinterpret_cast<float*>(hidden_state_in_->buffer) + b * num_units;
float* output_ptr_batch =
reinterpret_cast<float*>(output_->buffer) + b * num_units;
float* hidden_state_out_ptr_batch =
reinterpret_cast<float*>(hidden_state_out_->buffer) + b * num_units;
// Initialize input_weights and recurrent_weights.
const float* input_weights_ptr = reinterpret_cast<float*>(weights_->buffer);
const float* recurrent_weights_ptr =
reinterpret_cast<float*>(recurrent_weights_->buffer);
// Output = bias
for (uint32_t o = 0; o < num_units; o++) {
output_ptr_batch[o] = bias_ptr[o];
}
// Output += input * input_weights
for (uint32_t o = 0; o < num_units; o++) {
for (uint32_t i = 0; i < input_size; i++) {
output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
}
input_weights_ptr += input_weights_stride;
}
// Output += recurrent_weights * hidden_state
for (uint32_t o = 0; o < num_units; o++) {
for (uint32_t h = 0; h < num_units; h++) {
output_ptr_batch[o] +=
hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
}
recurrent_weights_ptr += recurrent_weights_stride;
}
// Output = activation(Output) and update hidden_state
for (uint32_t o = 0; o < num_units; o++) {
output_ptr_batch[o] =
(ActivationFunctor(activation_))(output_ptr_batch[o]);
hidden_state_out_ptr_batch[o] = output_ptr_batch[o];
}
}
return true;
}
} // namespace nn
} // namespace android