blob: 86ff4d07a3ed683d00e0bce969d1428c42133bdc [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/lstm_eval.h"
#include <algorithm>
#include <cstdint>
#ifdef GEMMLOWP_PROFILING
#include "profiling/profiler.h"
#endif
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/op_macros.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace lstm_eval {
namespace {
// Small float to avoid divergence during calculation of deviation for layer
// norm lstm.
const float kLayerNormEpsilon = 1e-8;
// Performs an LSTM batch inference step for input specified by input_ptr_batch.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
// parameters:
// - params: various LSTM params including activation, clipping, etc.,
// - n_batch: size of batch,
// - n_cell: number of cells (or units),
// - n_input: the input size,
// - n_aux_input: the auxiliary input size.
// - n_output: the output size.
// - output_batch_leading_dim: the leading dimension of the output buffer.
//
// LSTM weights:
// Input weights of size 'n_cell * n_input':
// input_to_input_weights - optional (can be nullptr)
// input_to_forget_weights
// input_to_cell_weights
// input_to_output_weights
// Auxiliary input weights of size 'n_cell * n_aux_input':
// aux_input_to_input_weights - optional
// aux_input_to_forget_weights - optional
// aux_input_to_cell_weights - optional
// aux_input_to_output_weights - optional
// Recurrent weights of size 'n_cell * n_output':
// recurrent_to_input_weights - optional
// recurrent_to_forget_weights
// recurrent_to_cell_weights
// recurrent_to_input_weights
// Peephole weights of size 'n_cell', representing diagonal matrices.
// cell_to_input_weights - optional
// cell_to_cell_weights - optional
// cell_to_output_weights - optional
// Projection weights of size 'n_output * n_cell'
// projection_weights_ptr - optional
// Gate biases of size 'n_cell':
// input_gate_bias_ptr - optional
// forget_gate_bias_ptr
// cell_gate_bias_ptr
// output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
// input_layer_norm_coefficients_ptr - optional
// forget_layer_norm_coefficients_ptr - optional
// cell_layer_norm_coefficients_ptr - optional
// output_layer_norm_coefficients_ptr - optional
//
// The pointers to the cell and output state and the output are updated.
//
// The pointers with the suffix "_batch" point to data aligned in batch_major
// order, and each step processes batch_size many inputs from input_ptr_batch,
// and updates batch_size many cell and output states.
//
// The output_batch_dim is output.shape[-1], i.e. the outermost dimension of the
// output tensor, and in most cases will be equal to n_output. It is usually not
// when we want to store the LSTM output into a slice of the output tensor, e.g.
// for bidirectional LSTMs with merge_outputs. In this case, the batched
// operations cannot be used since they assume that the batched outputs are
// contiguous, and we manually loop over the batched outputs.
inline void LstmStepWithAuxInput(
const float* input_ptr_batch, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
const float* input_to_cell_weights_ptr,
const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
const float* aux_input_to_input_weights_ptr,
const float* aux_input_to_forget_weights_ptr,
const float* aux_input_to_cell_weights_ptr,
const float* aux_input_to_output_weights_ptr,
const float* recurrent_to_input_weights_ptr,
const float* recurrent_to_forget_weights_ptr,
const float* recurrent_to_cell_weights_ptr,
const float* recurrent_to_output_weights_ptr,
const float* cell_to_input_weights_ptr,
const float* cell_to_forget_weights_ptr,
const float* cell_to_output_weights_ptr,
const float* input_layer_norm_coefficients_ptr,
const float* forget_layer_norm_coefficients_ptr,
const float* cell_layer_norm_coefficients_ptr,
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
const float* cell_bias_ptr, const float* output_gate_bias_ptr,
const float* projection_weights_ptr, const float* projection_bias_ptr,
const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
int n_aux_input, int n_output, int output_batch_leading_dim,
float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch) {
#ifdef GEMMLOWP_PROFILING
gemmlowp::ScopedProfilingLabel label("LstmStepWithAuxInputFloat");
#endif
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to the get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
const bool is_layer_norm_lstm =
(forget_layer_norm_coefficients_ptr != nullptr);
// Initialize scratch buffers with bias for regular lstm or initialize with
// zero for layer norm lstm.
if (is_layer_norm_lstm) {
if (!use_cifg) {
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
}
std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
} else {
if (!use_cifg) {
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
n_batch, input_gate_scratch);
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
// For each batch and cell: compute input_weight * input.
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
input_gate_scratch, /*result_stride=*/1);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
forget_gate_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
cell_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
output_gate_scratch, /*result_stride=*/1);
// If auxiliary input is available then compute aux_input_weight * aux_input
if (aux_input_ptr_batch != nullptr) {
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_aux_input,
aux_input_ptr_batch, n_batch, input_gate_scratch,
/*result_stride=*/1);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
n_batch, cell_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_aux_input,
aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
}
// For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, input_gate_scratch, /*result_stride=*/1);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, forget_gate_scratch,
/*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, cell_scratch, /*result_stride=*/1);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
n_batch, output_gate_scratch,
/*result_stride=*/1);
// For each batch and cell: update input gate.
if (!use_cifg) {
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(input_gate_scratch,
input_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
n_batch, input_gate_scratch);
tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
input_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
}
// For each batch and cell: update forget gate.
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
forget_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
n_batch, forget_gate_scratch);
tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);
// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
n_batch, kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
}
tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
params->activation, cell_scratch);
if (use_cifg) {
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
} else {
tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
}
if (params->cell_clip > 0.0) {
tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
params->cell_clip, cell_state_ptr);
}
// For each batch and cell: update the output gate.
if (use_peephole) {
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(output_gate_scratch,
output_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
n_batch, output_gate_scratch);
tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
params->activation, cell_scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
n_batch * n_cell, output_gate_scratch);
const bool use_projection_weight = (projection_weights_ptr != nullptr);
const bool use_projection_bias = (projection_bias_ptr != nullptr);
// For each batch: update the projection and output_state. Note that since
// the output batch rows may not be contiguous (output_batch_leading_dim !=
// n_output), we unroll the batched operations where this is the case.
if (output_batch_leading_dim == n_output) {
if (use_projection_weight) {
if (use_projection_bias) {
tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
n_batch, output_ptr_batch);
} else {
std::fill_n(output_ptr_batch, n_batch * n_output, 0.0f);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell, output_gate_scratch,
n_batch, output_ptr_batch, /*result_stride=*/1);
if (params->proj_clip > 0.0) {
tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
params->proj_clip, output_ptr_batch);
}
} else {
std::copy_n(output_gate_scratch, n_batch * n_output, output_ptr_batch);
}
std::copy_n(output_ptr_batch, n_batch * n_output, output_state_ptr);
} else {
if (use_projection_weight) {
if (use_projection_bias) {
for (int k = 0; k < n_batch; k++) {
std::copy_n(projection_bias_ptr, n_output,
output_ptr_batch + k * output_batch_leading_dim);
}
} else {
for (int k = 0; k < n_batch; k++) {
std::fill_n(output_ptr_batch + k * output_batch_leading_dim, n_output,
0.0f);
}
}
for (int k = 0; k < n_batch; k++) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell,
output_gate_scratch + k * n_cell,
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
/*result_stride=*/1);
if (params->proj_clip > 0.0) {
tensor_utils::ClipVector(
output_ptr_batch + k * output_batch_leading_dim, n_output,
params->proj_clip,
output_ptr_batch + k * output_batch_leading_dim);
}
}
} else {
for (int k = 0; k < n_batch; k++) {
std::copy_n(output_gate_scratch + k * n_output, n_output,
output_ptr_batch + k * output_batch_leading_dim);
}
}
for (int k = 0; k < n_batch; k++) {
std::copy_n(output_ptr_batch + k * output_batch_leading_dim, n_output,
output_state_ptr + k * n_output);
}
}
}
void ApplyActivationsToVector(float* input, int input_size,
TfLiteFusedActivation activation_type,
float* output) {
using VectorMap = Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, 1>>;
VectorMap input_map(input, input_size, 1);
VectorMap output_map(output, input_size, 1);
switch (activation_type) {
case kTfLiteActSigmoid: {
output_map.array() = input_map.array().logistic();
break;
}
case kTfLiteActTanh: {
output_map.array() = input_map.array().tanh();
break;
}
default: {
tensor_utils::ApplyActivationToVector(input, input_size, activation_type,
output);
}
}
}
// Same as above but with quantized weight matrices. In detail:
// Input of size 'n_batch * n_input':
// input_ptr_batch
//
// LSTM weights:
// Quantized input weights of size 'n_cell * n_input':
// input_to_input_weights - optional (can be nullptr)
// input_to_forget_weights
// input_to_cell_weights
// input_to_input_weights
// Quantized auxiliary input weights of size 'n_cell * n_aux_input':
// aux_input_to_input_weights - optional
// aux_input_to_forget_weights - optional
// aux_input_to_cell_weights - optional
// aux_input_to_output_weights - optional
// Quantized recurrent weights of size 'n_cell * n_output':
// recurrent_to_input_weights - optional
// recurrent_to_forget_weights
// recurrent_to_cell_weights
// recurrent_to_input_weights
// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
// cell_to_input_weights - optional
// cell_to_cell_weights - optional
// cell_to_output_weights - optional
// Quantized projection weights of size 'n_output * n_cell'
// projection_weights_ptr - optional
// Weight scales (scalars) for each of the weights above.
// input_to_input_weights_scale - optional
// input_to_forget_weights_scale
// input_to_cell_weights_scale
// input_to_output_weights_scale
// aux_input_to_input_weights_scale - optional
// aux_input_to_forget_weights_scale - optional
// aux_input_to_cell_weights_scale - optional
// aux_input_to_output_weights_scale - optional
// recurrent_to_input_weights_scale - optional
// recurrent_to_forget_weights_scale
// recurrent_to_cell_weights_scale
// recurrent_to_output_weights_scale
// cell_to_input_weights_scale,
// cell_to_forget_weights_scale,
// cell_to_output_weights_scale,
// projection_weights_scale - optional
// Gate biases of size 'n_cell':
// input_gate_bias_ptr - optional
// forget_gate_bias_ptr
// cell_gate_bias_ptr
// output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
// input_layer_norm_coefficients_ptr - optional
// forget_layer_norm_coefficients_ptr - optional
// cell_layer_norm_coefficients_ptr - optional
// output_layer_norm_coefficients_ptr - optional
//
// Temporary pre-allocated storage for quantized values:
// quantized_input_ptr_batch (same size as input_ptr_batch)
// quantized_output_state_ptr (same size as output_state_ptr)
// quantized_cell_state_ptr (same size as cell_state_ptr)
// Temporary pre-allocated storage for recovered values:
// recovered_cell_weights (same size as cell_to_*_weights)
//
// Outputs:
// output_state_ptr - size 'n_batch * n_output'
// cell_state_ptr - size 'n_batch * n_cell'
// output_ptr_batch - size 'n_batch * output_batch_leading_dim'
inline void LstmStepWithAuxInput(
const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
float input_to_input_weights_scale,
const int8_t* input_to_forget_weights_ptr,
float input_to_forget_weights_scale,
const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
const int8_t* input_to_output_weights_ptr,
float input_to_output_weights_scale, const float* aux_input_ptr_batch,
const int8_t* aux_input_to_input_weights_ptr,
float aux_input_to_input_weights_scale,
const int8_t* aux_input_to_forget_weights_ptr,
float aux_input_to_forget_weights_scale,
const int8_t* aux_input_to_cell_weights_ptr,
float aux_input_to_cell_weights_scale,
const int8_t* aux_input_to_output_weights_ptr,
float aux_input_to_output_weights_scale,
const int8_t* recurrent_to_input_weights_ptr,
float recurrent_to_input_weights_scale,
const int8_t* recurrent_to_forget_weights_ptr,
float recurrent_to_forget_weights_scale,
const int8_t* recurrent_to_cell_weights_ptr,
float recurrent_to_cell_weights_scale,
const int8_t* recurrent_to_output_weights_ptr,
float recurrent_to_output_weights_scale,
const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
const int8_t* cell_to_forget_weights_ptr,
float cell_to_forget_weights_scale,
const int8_t* cell_to_output_weights_ptr,
float cell_to_output_weights_scale,
const float* input_layer_norm_coefficients_ptr,
const float* forget_layer_norm_coefficients_ptr,
const float* cell_layer_norm_coefficients_ptr,
const float* output_layer_norm_coefficients_ptr,
const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
const float* cell_bias_ptr, const float* output_gate_bias_ptr,
const int8_t* projection_weights_ptr, float projection_weights_scale,
const float* projection_bias_ptr, const TfLiteLSTMParams* params,
int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
int output_batch_leading_dim, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* scaling_factors, float* product_scaling_factors,
float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, float* output_ptr_batch) {
#ifdef GEMMLOWP_PROFILING
gemmlowp::ScopedProfilingLabel label("LstmStepWithAuxInputHybrid");
#endif
// Since we have already checked that weights are all there or none, we
// can check the existence of only one to the get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
const bool is_layer_norm_lstm =
(forget_layer_norm_coefficients_ptr != nullptr);
// Initialize scratch buffers with bias.
if (is_layer_norm_lstm) {
if (!use_cifg) {
std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
}
std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
} else {
if (!use_cifg) {
tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
n_batch, input_gate_scratch);
}
tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
// Save quantization and matmul computation for all zero input.
float unused_min, unused_max;
for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_input;
tensor_utils::SymmetricQuantizeFloats(
input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
&unused_min, &unused_max, &scaling_factors[b]);
}
// For each batch and cell: compute input_weight * input.
if (!use_cifg) {
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_ptr, n_cell, n_input,
quantized_input_ptr_batch, product_scaling_factors, n_batch,
input_gate_scratch, /*result_stride=*/1);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
product_scaling_factors, n_batch, forget_gate_scratch,
/*result_stride=*/1);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * input_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
product_scaling_factors, n_batch, output_gate_scratch,
/*result_stride=*/1);
}
if (aux_input_ptr_batch != nullptr &&
!tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
// Save quantization and matmul computation for all zero input.
float unused_min, unused_max;
for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_input;
tensor_utils::SymmetricQuantizeFloats(
aux_input_ptr_batch + offset, n_input,
quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
// For each batch and cell: compute input_weight * input.
if (!use_cifg) {
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_input_weights_ptr, n_cell, n_input,
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
input_gate_scratch, /*result_stride=*/1);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_forget_weights_ptr, n_cell, n_input,
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
forget_gate_scratch, /*result_stride=*/1);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_cell_weights_ptr, n_cell, n_input,
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
cell_scratch, /*result_stride=*/1);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * aux_input_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
aux_input_to_output_weights_ptr, n_cell, n_input,
quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
output_gate_scratch, /*result_stride=*/1);
}
if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
// Save quantization and matmul computation for all zero input.
float unused_min, unused_max;
for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_output;
tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
quantized_output_state_ptr + offset,
&unused_min, &unused_max,
&scaling_factors[b]);
}
// For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) {
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_input_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
input_gate_scratch, /*result_stride=*/1);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_forget_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
forget_gate_scratch, /*result_stride=*/1);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_cell_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
cell_scratch, /*result_stride=*/1);
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * recurrent_to_output_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_ptr, n_cell, n_output,
quantized_output_state_ptr, product_scaling_factors, n_batch,
output_gate_scratch, /*result_stride=*/1);
}
// Save quantization and matmul computation for all zero input.
bool is_cell_state_all_zeros =
tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
// For each batch and cell: update input gate.
if (!use_cifg) {
if (use_peephole && !is_cell_state_all_zeros) {
tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
cell_to_input_weights_scale,
recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
input_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(input_gate_scratch,
input_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
input_layer_norm_coefficients_ptr, n_cell, input_gate_scratch,
n_batch, input_gate_scratch);
tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
input_gate_scratch);
}
ApplyActivationsToVector(input_gate_scratch, n_cell * n_batch,
kTfLiteActSigmoid, input_gate_scratch);
}
// For each batch and cell: update forget gate.
if (use_peephole && !is_cell_state_all_zeros) {
tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
cell_to_forget_weights_scale,
recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
forget_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(forget_gate_scratch,
forget_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
forget_layer_norm_coefficients_ptr, n_cell, forget_gate_scratch,
n_batch, forget_gate_scratch);
tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
forget_gate_scratch);
}
ApplyActivationsToVector(forget_gate_scratch, n_cell * n_batch,
kTfLiteActSigmoid, forget_gate_scratch);
// For each batch and cell: update the cell.
tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
n_batch * n_cell, cell_state_ptr);
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
n_batch, kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
cell_layer_norm_coefficients_ptr, n_cell, cell_scratch, n_batch,
cell_scratch);
tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
cell_scratch);
}
ApplyActivationsToVector(cell_scratch, n_batch * n_cell, params->activation,
cell_scratch);
if (use_cifg) {
tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
} else {
tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
}
if (params->cell_clip > 0.0) {
tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
params->cell_clip, cell_state_ptr);
}
is_cell_state_all_zeros =
tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
// For each batch and cell: update the output gate.
if (use_peephole && !is_cell_state_all_zeros) {
tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
cell_to_output_weights_scale,
recovered_cell_weights);
tensor_utils::VectorBatchVectorCwiseProductAccumulate(
recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
output_gate_scratch);
}
if (is_layer_norm_lstm) {
tensor_utils::MeanStddevNormalization(output_gate_scratch,
output_gate_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tensor_utils::VectorBatchVectorCwiseProduct(
output_layer_norm_coefficients_ptr, n_cell, output_gate_scratch,
n_batch, output_gate_scratch);
tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
output_gate_scratch);
}
ApplyActivationsToVector(output_gate_scratch, n_batch * n_cell,
kTfLiteActSigmoid, output_gate_scratch);
ApplyActivationsToVector(cell_state_ptr, n_batch * n_cell, params->activation,
cell_scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
n_batch * n_cell, output_gate_scratch);
const bool use_projection_weight = (projection_weights_ptr != nullptr);
const bool use_projection_bias = (projection_bias_ptr != nullptr);
// For each batch: update the projection and output_state. Note that since
// the output batch rows may not be contiguous (output_batch_leading_dim !=
// n_output), we unroll the batched operations where this is the case.
if (output_batch_leading_dim == n_output) {
if (use_projection_weight) {
if (use_projection_bias) {
tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
n_batch, output_ptr_batch);
} else {
std::fill_n(output_ptr_batch, n_batch * n_output, 0.0f);
}
if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
// Save quantization and matmul computation for all zero input.
float unused_min, unused_max;
for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_cell;
tensor_utils::SymmetricQuantizeFloats(
output_gate_scratch + offset, n_cell,
quantized_cell_state_ptr + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * projection_weights_scale;
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
product_scaling_factors, n_batch, output_ptr_batch,
/*result_stride=*/1);
}
if (params->proj_clip > 0.0) {
tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
params->proj_clip, output_ptr_batch);
}
} else {
std::copy_n(output_gate_scratch, n_batch * n_output, output_ptr_batch);
}
std::copy_n(output_ptr_batch, n_batch * n_output, output_state_ptr);
} else {
if (use_projection_weight) {
if (use_projection_bias) {
for (int k = 0; k < n_batch; k++) {
std::copy_n(projection_bias_ptr, n_output,
output_ptr_batch + k * output_batch_leading_dim);
}
} else {
for (int k = 0; k < n_batch; k++) {
std::fill_n(output_ptr_batch + k * output_batch_leading_dim, n_output,
0.0f);
}
}
if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
// Save quantization and matmul computation for all zero input.
float unused_min, unused_max;
for (int b = 0; b < n_batch; ++b) {
const int offset = b * n_cell;
tensor_utils::SymmetricQuantizeFloats(
output_gate_scratch + offset, n_cell,
quantized_cell_state_ptr + offset, &unused_min, &unused_max,
&scaling_factors[b]);
}
for (int b = 0; b < n_batch; ++b) {
product_scaling_factors[b] =
scaling_factors[b] * projection_weights_scale;
}
for (int k = 0; k < n_batch; k++) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell,
quantized_cell_state_ptr + k * n_cell,
&product_scaling_factors[k],
/*n_batch=*/1, output_ptr_batch + k * output_batch_leading_dim,
/*result_stride=*/1);
}
}
if (params->proj_clip > 0.0) {
for (int k = 0; k < n_batch; k++) {
tensor_utils::ClipVector(
output_ptr_batch + k * output_batch_leading_dim, n_output,
params->proj_clip,
output_ptr_batch + k * output_batch_leading_dim);
}
}
} else {
for (int k = 0; k < n_batch; k++) {
std::copy_n(output_gate_scratch + k * n_output, n_output,
output_ptr_batch + k * output_batch_leading_dim);
}
}
for (int k = 0; k < n_batch; k++) {
std::copy_n(output_ptr_batch + k * output_batch_leading_dim, n_output,
output_state_ptr + k * n_output);
}
}
}
inline void LstmStepQuantized(
const int8_t* input_ptr, int32_t input_zp,
const int8_t* input_to_input_weight_ptr,
int32_t effective_input_to_input_scale_a,
int32_t effective_input_to_input_scale_b,
const int8_t* input_to_forget_weight_ptr,
int32_t effective_input_to_forget_scale_a,
int32_t effective_input_to_forget_scale_b,
const int8_t* input_to_cell_weight_ptr,
int32_t effective_input_to_cell_scale_a,
int32_t effective_input_to_cell_scale_b,
const int8_t* input_to_output_weight_ptr,
int32_t effective_input_to_output_scale_a,
int32_t effective_input_to_output_scale_b,
const int8_t* recurrent_to_input_weight_ptr,
int32_t effective_recurrent_to_input_scale_a,
int32_t effective_recurrent_to_input_scale_b,
const int8_t* recurrent_to_forget_weight_ptr,
int32_t effective_recurrent_to_forget_scale_a,
int32_t effective_recurrent_to_forget_scale_b,
const int8_t* recurrent_to_cell_weight_ptr,
int32_t effective_recurrent_to_cell_scale_a,
int32_t effective_recurrent_to_cell_scale_b,
const int8_t* recurrent_to_output_weight_ptr,
int32_t effective_recurrent_to_output_scale_a,
int32_t effective_recurrent_to_output_scale_b,
const int8_t* cell_to_input_weight_ptr,
int32_t effective_cell_to_input_scale_a,
int32_t effective_cell_to_input_scale_b,
const int8_t* cell_to_forget_weight_ptr,
int32_t effective_cell_to_forget_scale_a,
int32_t effective_cell_to_forget_scale_b,
const int8_t* cell_to_output_weight_ptr,
int32_t effective_cell_to_output_scale_a,
int32_t effective_cell_to_output_scale_b, const int8_t* proj_weight_ptr,
int32_t effective_proj_scale_a, int32_t effective_proj_scale_b,
const int16_t* layer_norm_input_weight_ptr,
int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
const int16_t* layer_norm_forget_weight_ptr,
int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
int32_t layer_norm_cell_scale_b,
const int16_t* layer_norm_output_weight_ptr,
int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
const int32_t* input_bias_ptr, const int32_t* forget_bias_ptr,
const int32_t* cell_bias_ptr, const int32_t* output_bias_ptr,
const int32_t* proj_bias_ptr, int32 quantized_cell_clip,
int32 quantized_proj_clip, const int32_t* inv_large_value, int32 n_batch,
int32 n_cell, int32 n_input, int32 n_output, int32 output_batch_leading_dim,
int8_t* activation_ptr, int32_t activation_zp, int16_t* cell_ptr,
int8_t* output_ptr, int16_t* scratch_0_ptr, int16_t* scratch_1_ptr,
int16_t* scratch_2_ptr, int16_t* scratch_3_ptr, int8_t* scratch_4_ptr) {
// Set scratch to 0.
memset(scratch_0_ptr, 0, n_batch * n_cell * sizeof(int16_t));
memset(scratch_1_ptr, 0, n_batch * n_cell * sizeof(int16_t));
memset(scratch_2_ptr, 0, n_batch * n_cell * sizeof(int16_t));
memset(scratch_3_ptr, 0, n_batch * n_cell * sizeof(int16_t));
// Forget gate.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_zp, input_to_forget_weight_ptr,
effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
nullptr, n_batch, n_input, n_cell, 0, scratch_1_ptr);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
activation_ptr, activation_zp, recurrent_to_forget_weight_ptr,
effective_recurrent_to_forget_scale_a,
effective_recurrent_to_forget_scale_b, nullptr, n_batch, n_output, n_cell,
0, scratch_1_ptr);
tensor_utils::ApplyLayerNorm(scratch_1_ptr, layer_norm_forget_weight_ptr,
forget_bias_ptr, layer_norm_forget_scale_a,
layer_norm_forget_scale_b, inv_large_value[1],
n_batch, n_cell, scratch_1_ptr);
tensor_utils::ApplySigmoid(scratch_1_ptr, n_batch, n_cell, scratch_1_ptr);
// Modulation gate.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_zp, input_to_cell_weight_ptr,
effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, nullptr,
n_batch, n_input, n_cell, 0, scratch_2_ptr);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
activation_ptr, activation_zp, recurrent_to_cell_weight_ptr,
effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
nullptr, n_batch, n_output, n_cell, 0, scratch_2_ptr);
tensor_utils::ApplyLayerNorm(scratch_2_ptr, layer_norm_cell_weight_ptr,
cell_bias_ptr, layer_norm_cell_scale_a,
layer_norm_cell_scale_b, inv_large_value[2],
n_batch, n_cell, scratch_2_ptr);
tensor_utils::ApplyTanh3(scratch_2_ptr, n_batch, n_cell, scratch_2_ptr);
// Ouptut gate.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_zp, input_to_output_weight_ptr,
effective_input_to_output_scale_a, effective_input_to_output_scale_b,
nullptr, n_batch, n_input, n_cell, 0, scratch_3_ptr);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
activation_ptr, activation_zp, recurrent_to_output_weight_ptr,
effective_recurrent_to_output_scale_a,
effective_recurrent_to_output_scale_b, nullptr, n_batch, n_output, n_cell,
0, scratch_3_ptr);
tensor_utils::ApplyLayerNorm(scratch_3_ptr, layer_norm_output_weight_ptr,
output_bias_ptr, layer_norm_output_scale_a,
layer_norm_output_scale_b, inv_large_value[3],
n_batch, n_cell, scratch_3_ptr);
tensor_utils::ApplySigmoid(scratch_3_ptr, n_batch, n_cell, scratch_3_ptr);
// Input gate.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_ptr, input_zp, input_to_input_weight_ptr,
effective_input_to_input_scale_a, effective_input_to_input_scale_b,
nullptr, n_batch, n_input, n_cell, 0, scratch_0_ptr);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
activation_ptr, activation_zp, recurrent_to_input_weight_ptr,
effective_recurrent_to_input_scale_a,
effective_recurrent_to_input_scale_b, nullptr, n_batch, n_output, n_cell,
0, scratch_0_ptr);
tensor_utils::ApplyLayerNorm(scratch_0_ptr, layer_norm_input_weight_ptr,
input_bias_ptr, layer_norm_input_scale_a,
layer_norm_input_scale_b, inv_large_value[0],
n_batch, n_cell, scratch_0_ptr);
tensor_utils::ApplySigmoid(scratch_0_ptr, n_batch, n_cell, scratch_0_ptr);
// Cell and hidden.
tensor_utils::CwiseMul(scratch_1_ptr, cell_ptr, n_batch, n_cell, 15,
scratch_1_ptr);
tensor_utils::CwiseMul(scratch_0_ptr, scratch_2_ptr, n_batch, n_cell, 19,
scratch_2_ptr);
tensor_utils::CwiseAdd(scratch_1_ptr, scratch_2_ptr, n_batch, n_cell,
cell_ptr);
if (quantized_cell_clip > 0) {
tensor_utils::CwiseClipping(cell_ptr, quantized_cell_clip, n_batch, n_cell);
}
tensor_utils::ApplyTanh4(cell_ptr, n_batch, n_cell, scratch_0_ptr);
tensor_utils::CwiseMul(scratch_3_ptr, scratch_0_ptr, n_batch, n_cell, 23,
scratch_4_ptr);
// Projection.
memset(output_ptr, 0, n_batch * n_output * sizeof(int8_t));
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
scratch_4_ptr, 0, proj_weight_ptr, effective_proj_scale_a,
effective_proj_scale_b, proj_bias_ptr, n_batch, n_cell, n_output,
activation_zp, output_ptr);
if (quantized_proj_clip > 0) {
tensor_utils::CwiseClipping(output_ptr, quantized_proj_clip, n_batch,
n_output);
}
memcpy(activation_ptr, output_ptr, n_batch * n_output * sizeof(int8_t));
}
} // namespace
TfLiteStatus EvalFloat(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer,
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
int max_time, n_batch;
if (input->dims->size == 3) {
max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
} else {
max_time = 1;
n_batch = input->dims->data[0];
}
const int n_input = input->dims->data[input->dims->size - 1];
const int aux_input_size =
(aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
// n_cell and n_output will be the same size when there is no projection.
const int n_cell = input_to_output_weights->dims->data[0];
const int n_output = recurrent_to_output_weights->dims->data[1];
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to the get the condition.
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
// Index the scratch buffers pointers to the global scratch buffer.
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
float* output_gate_scratch = nullptr;
if (use_cifg) {
cell_scratch = scratch_buffer->data.f;
forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
} else {
input_gate_scratch = scratch_buffer->data.f;
cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
}
// Check optional tensors, the respective pointers can be null.
const float* input_to_input_weights_ptr =
(use_cifg) ? nullptr : input_to_input_weights->data.f;
const float* recurrent_to_input_weights_ptr =
(use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
const float* input_gate_bias_ptr =
(use_cifg) ? nullptr : input_gate_bias->data.f;
const float* cell_to_input_weights_ptr =
(use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
const float* cell_to_forget_weights_ptr =
(use_peephole) ? cell_to_forget_weights->data.f : nullptr;
const float* cell_to_output_weights_ptr =
(use_peephole) ? cell_to_output_weights->data.f : nullptr;
const float* input_layer_norm_coefficients_ptr =
(is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f
: nullptr;
const float* forget_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr;
const float* cell_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr;
const float* output_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr;
const float* projection_weights_ptr =
(projection_weights == nullptr) ? nullptr : projection_weights->data.f;
const float* projection_bias_ptr =
(projection_bias == nullptr) ? nullptr : projection_bias->data.f;
float* aux_input_ptr = nullptr;
float* aux_input_to_input_weights_ptr = nullptr;
float* aux_input_to_forget_weights_ptr = nullptr;
float* aux_input_to_cell_weights_ptr = nullptr;
float* aux_input_to_output_weights_ptr = nullptr;
if (aux_input_size > 0) {
if (!use_cifg) {
aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
}
aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
}
const int output_batch_leading_dim =
output->dims->data[output->dims->size - 1];
if (time_major) {
// Loop through the sequence.
const int input_step = n_batch * n_input;
const int output_step = n_batch * output_batch_leading_dim;
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step
// backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr_batch = input->data.f + t_rel * input_step;
if (aux_input) {
aux_input_ptr = aux_input->data.f + t_rel * input_step;
}
float* output_ptr_time =
output->data.f + t_rel * output_step + output_offset;
LstmStepWithAuxInput(
input_ptr_batch, input_to_input_weights_ptr,
input_to_forget_weights->data.f, input_to_cell_weights->data.f,
input_to_output_weights->data.f, aux_input_ptr,
aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
recurrent_to_cell_weights->data.f,
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr,
cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr,
input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
params, n_batch, n_cell, n_input, aux_input_size, n_output,
output_batch_leading_dim, activation_state->data.f,
cell_state->data.f, input_gate_scratch, forget_gate_scratch,
cell_scratch, output_gate_scratch, output_ptr_time);
}
} else {
for (int b = 0; b < n_batch; b++) {
const int input_step = n_input;
const int output_step = output_batch_leading_dim;
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step
// backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const int time_offset = b * max_time + t_rel;
const float* input_ptr = input->data.f + time_offset * input_step;
if (aux_input) {
aux_input_ptr = aux_input->data.f + time_offset * input_step;
}
float* output_ptr =
output->data.f + time_offset * output_step + output_offset;
// Offset the {activation,cell}_state pointers to the right batch.
float* activation_state_ptr =
activation_state->data.f + b * output_batch_leading_dim;
float* cell_state_ptr = cell_state->data.f + b * n_cell;
// Offset the scratch pointers to the right batch.
float* input_gate_scratch_ptr =
input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
float* cell_scratch_ptr = cell_scratch + b * n_cell;
float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
LstmStepWithAuxInput(
input_ptr, input_to_input_weights_ptr,
input_to_forget_weights->data.f, input_to_cell_weights->data.f,
input_to_output_weights->data.f, aux_input_ptr,
aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
recurrent_to_cell_weights->data.f,
recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
input_layer_norm_coefficients_ptr,
forget_layer_norm_coefficients_ptr,
cell_layer_norm_coefficients_ptr,
output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
forget_gate_bias->data.f, cell_bias->data.f,
output_gate_bias->data.f, projection_weights_ptr,
projection_bias_ptr, params, /*n_batch=*/1, n_cell, n_input,
aux_input_size, n_output, output_batch_leading_dim,
activation_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
output_ptr);
}
}
}
return kTfLiteOk;
}
TfLiteStatus EvalHybrid(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* aux_input,
const TfLiteTensor* aux_input_to_input_weights,
const TfLiteTensor* aux_input_to_forget_weights,
const TfLiteTensor* aux_input_to_cell_weights,
const TfLiteTensor* aux_input_to_output_weights,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
int output_offset, TfLiteTensor* scratch_buffer,
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
TfLiteTensor* aux_input_quantized, TfLiteTensor* output_state_quantized,
TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state,
TfLiteTensor* cell_state, TfLiteTensor* output) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int n_input = input->dims->data[input->dims->size - 1];
int max_time, n_batch;
if (input->dims->size == 2) {
max_time = 1;
n_batch = input->dims->data[0];
} else {
max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
}
const int aux_input_size =
(aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
// n_cell and n_output will be the same size when there is no projection.
const int n_cell = input_to_output_weights->dims->data[0];
const int n_output = recurrent_to_output_weights->dims->data[1];
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
float* output_gate_scratch = nullptr;
if (use_cifg) {
cell_scratch = scratch_buffer->data.f;
forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
} else {
input_gate_scratch = scratch_buffer->data.f;
cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
}
// Check optional tensors, the respective pointers can be null.
const int8_t* input_to_input_weights_ptr = nullptr;
float input_to_input_weights_scale = 1.0f;
const int8_t* recurrent_to_input_weights_ptr = nullptr;
float recurrent_to_input_weights_scale = 1.0f;
float* input_gate_bias_ptr = nullptr;
if (!use_cifg) {
input_to_input_weights_ptr = GetTensorData<int8_t>(input_to_input_weights);
recurrent_to_input_weights_ptr =
GetTensorData<int8_t>(recurrent_to_input_weights);
input_gate_bias_ptr = input_gate_bias->data.f;
input_to_input_weights_scale = input_to_input_weights->params.scale;
recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
}
const int8_t* cell_to_input_weights_ptr = nullptr;
const int8_t* cell_to_forget_weights_ptr = nullptr;
const int8_t* cell_to_output_weights_ptr = nullptr;
float cell_to_input_weights_scale = 1.0f;
float cell_to_forget_weights_scale = 1.0f;
float cell_to_output_weights_scale = 1.0f;
if (use_peephole) {
if (!use_cifg) {
cell_to_input_weights_ptr = GetTensorData<int8_t>(cell_to_input_weights);
cell_to_input_weights_scale = cell_to_input_weights->params.scale;
}
cell_to_forget_weights_ptr = GetTensorData<int8_t>(cell_to_forget_weights);
cell_to_output_weights_ptr = GetTensorData<int8_t>(cell_to_output_weights);
cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
cell_to_output_weights_scale = cell_to_output_weights->params.scale;
}
const float* input_layer_norm_coefficients_ptr =
(is_layer_norm_lstm && !use_cifg) ? input_layer_norm_coefficients->data.f
: nullptr;
const float* forget_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? forget_layer_norm_coefficients->data.f : nullptr;
const float* cell_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? cell_layer_norm_coefficients->data.f : nullptr;
const float* output_layer_norm_coefficients_ptr =
is_layer_norm_lstm ? output_layer_norm_coefficients->data.f : nullptr;
const int8_t* projection_weights_ptr =
(projection_weights == nullptr)
? nullptr
: GetTensorData<int8_t>(projection_weights);
const float projection_weights_scale =
(projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
const float* projection_bias_ptr =
(projection_bias == nullptr) ? nullptr : projection_bias->data.f;
// Required tensors, pointers are non-null.
const int8_t* input_to_forget_weights_ptr =
GetTensorData<int8_t>(input_to_forget_weights);
const float input_to_forget_weights_scale =
input_to_forget_weights->params.scale;
const int8_t* input_to_cell_weights_ptr =
GetTensorData<int8_t>(input_to_cell_weights);
const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
const int8_t* input_to_output_weights_ptr =
GetTensorData<int8_t>(input_to_output_weights);
const float input_to_output_weights_scale =
input_to_output_weights->params.scale;
const int8_t* recurrent_to_forget_weights_ptr =
GetTensorData<int8_t>(recurrent_to_forget_weights);
const float recurrent_to_forget_weights_scale =
recurrent_to_forget_weights->params.scale;
const int8_t* recurrent_to_cell_weights_ptr =
GetTensorData<int8_t>(recurrent_to_cell_weights);
const float recurrent_to_cell_weights_scale =
recurrent_to_cell_weights->params.scale;
const int8_t* recurrent_to_output_weights_ptr =
GetTensorData<int8_t>(recurrent_to_output_weights);
const float recurrent_to_output_weights_scale =
recurrent_to_output_weights->params.scale;
const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
// Temporary storage for quantized values and scaling factors.
int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_quantized);
int8_t* quantized_aux_input_ptr =
(aux_input_quantized == nullptr)
? nullptr
: GetTensorData<int8_t>(aux_input_quantized);
int8_t* quantized_output_state_ptr =
GetTensorData<int8_t>(output_state_quantized);
int8_t* quantized_cell_state_ptr =
GetTensorData<int8_t>(cell_state_quantized);
float* scaling_factors_ptr = scaling_factors->data.f;
float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
// Auxiliary input and weights.
float* aux_input_ptr = nullptr;
const int8_t* aux_input_to_input_weights_ptr = nullptr;
const int8_t* aux_input_to_forget_weights_ptr = nullptr;
const int8_t* aux_input_to_cell_weights_ptr = nullptr;
const int8_t* aux_input_to_output_weights_ptr = nullptr;
float aux_input_to_input_weights_scale = 0.0f;
float aux_input_to_forget_weights_scale = 0.0f;
float aux_input_to_cell_weights_scale = 0.0f;
float aux_input_to_output_weights_scale = 0.0f;
if (aux_input_size > 0) {
if (!use_cifg) {
aux_input_to_input_weights_ptr =
GetTensorData<int8_t>(aux_input_to_input_weights);
}
aux_input_to_forget_weights_ptr =
GetTensorData<int8_t>(aux_input_to_forget_weights);
aux_input_to_cell_weights_ptr =
GetTensorData<int8_t>(aux_input_to_cell_weights);
aux_input_to_output_weights_ptr =
GetTensorData<int8_t>(aux_input_to_output_weights);
if (!use_cifg) {
aux_input_to_input_weights_scale =
aux_input_to_input_weights->params.scale;
}
aux_input_to_forget_weights_scale =
aux_input_to_forget_weights->params.scale;
aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
aux_input_to_output_weights_scale =
aux_input_to_output_weights->params.scale;
}
const int output_batch_leading_dim =
output->dims->data[output->dims->size - 1];
if (time_major) {
// Feed the sequence into the LSTM step-by-step.
const int input_step = n_batch * n_input;
const int output_step = n_batch * output_batch_leading_dim;
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step
// backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr_batch = input->data.f + t_rel * input_step;
if (aux_input) {
aux_input_ptr = aux_input->data.f + t_rel * input_step;
}
float* output_ptr_batch =
output->data.f + t_rel * output_step + output_offset;
LstmStepWithAuxInput(
input_ptr_batch, input_to_input_weights_ptr,
input_to_input_weights_scale, input_to_forget_weights_ptr,
input_to_forget_weights_scale, input_to_cell_weights_ptr,
input_to_cell_weights_scale, input_to_output_weights_ptr,
input_to_output_weights_scale, aux_input_ptr,
aux_input_to_input_weights_ptr, aux_input_to_input_weights_scale,
aux_input_to_forget_weights_ptr, aux_input_to_forget_weights_scale,
aux_input_to_cell_weights_ptr, aux_input_to_cell_weights_scale,
aux_input_to_output_weights_ptr, aux_input_to_output_weights_scale,
recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
cell_to_input_weights_ptr, cell_to_input_weights_scale,
cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
cell_to_output_weights_ptr, cell_to_output_weights_scale,
input_layer_norm_coefficients_ptr, forget_layer_norm_coefficients_ptr,
cell_layer_norm_coefficients_ptr, output_layer_norm_coefficients_ptr,
input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
output_gate_bias_ptr, projection_weights_ptr,
projection_weights_scale, projection_bias_ptr, params, n_batch,
n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
input_gate_scratch, forget_gate_scratch, cell_scratch,
output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
recovered_cell_weights_ptr, quantized_input_ptr,
quantized_aux_input_ptr, quantized_output_state_ptr,
quantized_cell_state_ptr, output_state->data.f, cell_state->data.f,
output_ptr_batch);
}
} else {
for (int b = 0; b < n_batch; b++) {
const int input_step = n_input;
const int output_step = output_batch_leading_dim;
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step
// backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const int time_offset = b * max_time + t_rel;
const float* input_ptr = input->data.f + time_offset * input_step;
if (aux_input) {
aux_input_ptr = aux_input->data.f + time_offset * input_step;
}
float* output_ptr =
output->data.f + time_offset * output_step + output_offset;
// Offset the {output,cell}_state pointers to the right batch.
float* output_state_ptr =
output_state->data.f + b * output_batch_leading_dim;
float* cell_state_ptr = cell_state->data.f + b * n_cell;
// Offset the scratch pointers to the right batch.
float* input_gate_scratch_ptr =
input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
float* cell_scratch_ptr = cell_scratch + b * n_cell;
float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
LstmStepWithAuxInput(
input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
input_to_forget_weights_ptr, input_to_forget_weights_scale,
input_to_cell_weights_ptr, input_to_cell_weights_scale,
input_to_output_weights_ptr, input_to_output_weights_scale,
aux_input_ptr, aux_input_to_input_weights_ptr,
aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
cell_to_input_weights_scale, cell_to_forget_weights_ptr,
cell_to_forget_weights_scale, cell_to_output_weights_ptr,
cell_to_output_weights_scale, input_layer_norm_coefficients_ptr,
forget_layer_norm_coefficients_ptr,
cell_layer_norm_coefficients_ptr,
output_layer_norm_coefficients_ptr, input_gate_bias_ptr,
forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
projection_weights_ptr, projection_weights_scale,
projection_bias_ptr, params,
/*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
output_batch_leading_dim, input_gate_scratch_ptr,
forget_gate_scratch_ptr, cell_scratch_ptr, output_gate_scratch_ptr,
scaling_factors_ptr, prod_scaling_factors_ptr,
recovered_cell_weights_ptr, quantized_input_ptr,
quantized_aux_input_ptr, quantized_output_state_ptr,
quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
output_ptr);
}
}
}
return kTfLiteOk;
}
TfLiteStatus EvalQuantized(
const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
const TfLiteTensor* input_to_forget_weights,
const TfLiteTensor* input_to_cell_weights,
const TfLiteTensor* input_to_output_weights,
const TfLiteTensor* recurrent_to_input_weights,
const TfLiteTensor* recurrent_to_forget_weights,
const TfLiteTensor* recurrent_to_cell_weights,
const TfLiteTensor* recurrent_to_output_weights,
const TfLiteTensor* cell_to_input_weights,
const TfLiteTensor* cell_to_forget_weights,
const TfLiteTensor* cell_to_output_weights,
const TfLiteTensor* input_layer_norm_coefficients,
const TfLiteTensor* forget_layer_norm_coefficients,
const TfLiteTensor* cell_layer_norm_coefficients,
const TfLiteTensor* output_layer_norm_coefficients,
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params,
const lstm_eval::QuantizedLstmParameter* quantized_lstm_param,
TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output, TfLiteTensor* scratch0, TfLiteTensor* scratch1,
TfLiteTensor* scratch2, TfLiteTensor* scratch3, TfLiteTensor* scratch4) {
TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
const int n_input = input->dims->data[input->dims->size - 1];
int max_time, n_batch;
if (input->dims->size == 2) {
max_time = 1;
n_batch = input->dims->data[0];
} else {
max_time = input->dims->data[0];
n_batch = input->dims->data[1];
}
// n_cell and n_output will be the same size when there is no projection.
const int n_cell = input_to_output_weights->dims->data[0];
const int n_output = recurrent_to_output_weights->dims->data[1];
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
const bool use_projection = (projection_weights != nullptr);
// Weights and states.
int8_t* input_to_input_weight_ptr = nullptr;
int8_t* recurrent_to_input_weight_ptr = nullptr;
int8_t* cell_to_input_weight_ptr = nullptr;
int8_t* input_to_forget_weight_ptr = nullptr;
int8_t* recurrent_to_forget_weight_ptr = nullptr;
int8_t* cell_to_forget_weight_ptr = nullptr;
int8_t* input_to_cell_weight_ptr = nullptr;
int8_t* recurrent_to_cell_weight_ptr = nullptr;
int8_t* input_to_output_weight_ptr = nullptr;
int8_t* recurrent_to_output_weight_ptr = nullptr;
int8_t* cell_to_output_weight_ptr = nullptr;
int8_t* proj_weight_ptr = nullptr;
int16_t* layer_norm_input_weight_ptr = nullptr;
int16_t* layer_norm_forget_weight_ptr = nullptr;
int16_t* layer_norm_cell_weight_ptr = nullptr;
int16_t* layer_norm_output_weight_ptr = nullptr;
int32_t* input_bias_ptr = nullptr;
int32_t* forget_bias_ptr = nullptr;
int32_t* cell_bias_ptr = nullptr;
int32_t* output_bias_ptr = nullptr;
int32_t* proj_bias_ptr = nullptr;
int16_t* cell_ptr = nullptr;
int8_t* activation_ptr = nullptr;
int8_t* output_ptr = nullptr;
// Zero points
int input_zp = 0;
int activation_zp = 0;
// Populate all the values.
if (!use_cifg) {
input_to_input_weight_ptr = input_to_input_weights->data.int8;
recurrent_to_input_weight_ptr = recurrent_to_input_weights->data.int8;
input_bias_ptr = input_gate_bias->data.i32;
}
if (use_peephole) {
if (!use_cifg) {
cell_to_input_weight_ptr = cell_to_input_weights->data.int8;
}
cell_to_forget_weight_ptr = cell_to_forget_weights->data.int8;
cell_to_output_weight_ptr = cell_to_output_weights->data.int8;
}
if (is_layer_norm_lstm) {
if (!use_cifg) {
layer_norm_input_weight_ptr = input_layer_norm_coefficients->data.i16;
}
layer_norm_forget_weight_ptr = forget_layer_norm_coefficients->data.i16;
layer_norm_cell_weight_ptr = cell_layer_norm_coefficients->data.i16;
layer_norm_output_weight_ptr = output_layer_norm_coefficients->data.i16;
}
if (use_projection) {
proj_weight_ptr = projection_weights->data.int8;
if (projection_bias) {
proj_bias_ptr = projection_bias->data.i32;
}
}
input_to_forget_weight_ptr = input_to_forget_weights->data.int8;
input_to_cell_weight_ptr = input_to_cell_weights->data.int8;
input_to_output_weight_ptr = input_to_output_weights->data.int8;
recurrent_to_forget_weight_ptr = recurrent_to_forget_weights->data.int8;
recurrent_to_cell_weight_ptr = recurrent_to_cell_weights->data.int8;
recurrent_to_output_weight_ptr = recurrent_to_output_weights->data.int8;
forget_bias_ptr = forget_gate_bias->data.i32;
cell_bias_ptr = cell_bias->data.i32;
output_bias_ptr = output_gate_bias->data.i32;
activation_ptr = activation_state->data.int8;
cell_ptr = cell_state->data.i16;
input_zp = input->params.zero_point;
activation_zp = activation_state->params.zero_point;
// Get params for time/batch/sequence.
const int output_batch_leading_dim =
output->dims->data[output->dims->size - 1];
const int input_step = n_batch * n_input;
const int output_step = n_batch * output_batch_leading_dim;
for (int t = 0; t < max_time; t++) {
const int t_rel = t;
output_ptr = output->data.int8 + t_rel * output_step;
// Input can be int8 asymmetric or int16 symmetric.
const int8_t* input_ptr = input->data.int8 + t_rel * input_step;
LstmStepQuantized(
input_ptr, input_zp, input_to_input_weight_ptr,
quantized_lstm_param->effective_input_to_input_scale_a,
quantized_lstm_param->effective_input_to_input_scale_b,
input_to_forget_weight_ptr,
quantized_lstm_param->effective_input_to_forget_scale_a,
quantized_lstm_param->effective_input_to_forget_scale_b,
input_to_cell_weight_ptr,
quantized_lstm_param->effective_input_to_cell_scale_a,
quantized_lstm_param->effective_input_to_cell_scale_b,
input_to_output_weight_ptr,
quantized_lstm_param->effective_input_to_output_scale_a,
quantized_lstm_param->effective_input_to_output_scale_b,
recurrent_to_input_weight_ptr,
quantized_lstm_param->effective_recurrent_to_input_scale_a,
quantized_lstm_param->effective_recurrent_to_input_scale_b,
recurrent_to_forget_weight_ptr,
quantized_lstm_param->effective_recurrent_to_forget_scale_a,
quantized_lstm_param->effective_recurrent_to_forget_scale_b,
recurrent_to_cell_weight_ptr,
quantized_lstm_param->effective_recurrent_to_cell_scale_a,
quantized_lstm_param->effective_recurrent_to_cell_scale_b,
recurrent_to_output_weight_ptr,
quantized_lstm_param->effective_recurrent_to_output_scale_a,
quantized_lstm_param->effective_recurrent_to_output_scale_b,
cell_to_input_weight_ptr,
quantized_lstm_param->effective_cell_to_input_scale_a,
quantized_lstm_param->effective_cell_to_input_scale_b,
cell_to_forget_weight_ptr,
quantized_lstm_param->effective_cell_to_forget_scale_a,
quantized_lstm_param->effective_cell_to_forget_scale_b,
cell_to_output_weight_ptr,
quantized_lstm_param->effective_cell_to_output_scale_a,
quantized_lstm_param->effective_cell_to_output_scale_b, proj_weight_ptr,
quantized_lstm_param->effective_proj_scale_a,
quantized_lstm_param->effective_proj_scale_b,
layer_norm_input_weight_ptr,
quantized_lstm_param->layer_norm_input_scale_a,
quantized_lstm_param->layer_norm_input_scale_b,
layer_norm_forget_weight_ptr,
quantized_lstm_param->layer_norm_forget_scale_a,
quantized_lstm_param->layer_norm_forget_scale_b,
layer_norm_cell_weight_ptr,
quantized_lstm_param->layer_norm_cell_scale_a,
quantized_lstm_param->layer_norm_cell_scale_b,
layer_norm_output_weight_ptr,
quantized_lstm_param->layer_norm_output_scale_a,
quantized_lstm_param->layer_norm_output_scale_b, input_bias_ptr,
forget_bias_ptr, cell_bias_ptr, output_bias_ptr, proj_bias_ptr,
quantized_lstm_param->quantized_cell_clip,
quantized_lstm_param->quantized_proj_clip,
quantized_lstm_param->inv_large_value.data(), n_batch, n_cell, n_input,
n_output, output_batch_leading_dim, activation_ptr, activation_zp,
cell_ptr, output_ptr, scratch0->data.i16, scratch1->data.i16,
scratch2->data.i16, scratch3->data.i16, scratch4->data.int8);
}
return kTfLiteOk;
}
} // namespace lstm_eval
} // namespace builtin
} // namespace ops
} // namespace tflite