blob: 006571d0b0a4961c32f28defdae83ada33586922 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"

#include <algorithm>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
namespace {
// Emits an arith.constant whose value is an i32 tensor of `shape` with every
// element set to `val`, and returns its result.
Value CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
                          int32_t val, mlir::Location location) {
  auto i32_tensor_type =
      RankedTensorType::get(shape, builder->getIntegerType(32));
  return builder->create<arith::ConstantOp>(
      location, i32_tensor_type, DenseElementsAttr::get(i32_tensor_type, val));
}
// Emits an arith.constant whose value is an f32 tensor of `shape` with every
// element set to `val`, and returns its result.
Value CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
                          float val, mlir::Location location) {
  auto f32_tensor_type = RankedTensorType::get(shape, builder->getF32Type());
  return builder->create<arith::ConstantOp>(
      location, f32_tensor_type, DenseElementsAttr::get(f32_tensor_type, val));
}
// Emits a tf.Const whose result type is an f32 tensor of `shape` and returns
// its result.
// NOTE(review): the `value` attribute is a one-element tensor<1xf32> holding
// `val`, not a splat of the full `shape`. Presumably this is so `shape` may
// contain dynamic dimensions (a DenseElementsAttr needs a static shape).
// Confirm that consumers tolerate the attribute/result type mismatch.
Value CreatTfF32ConstOp(OpBuilder* builder, ArrayRef<int64_t> shape, float val,
                        mlir::Location location) {
  auto f32 = builder->getF32Type();
  auto single_element_attr =
      DenseElementsAttr::get(RankedTensorType::get({1}, f32), val);
  auto result_type = RankedTensorType::get(shape, f32);
  return builder->create<TF::ConstOp>(location, result_type,
                                      single_element_attr);
}
// Emits an arith.constant holding `values` as a 1-D i64 tensor.
// Only the number of entries in `shape` is used: it becomes the length of
// the 1-D tensor, so it presumably has to equal `values.size()`.
Value CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
                          ArrayRef<int64_t> values, mlir::Location location) {
  const int length = static_cast<int>(shape.size());
  auto i64_vector_type =
      RankedTensorType::get(length, builder->getIntegerType(64));
  auto dense_values = DenseElementsAttr::get(i64_vector_type, values);
  return builder->create<arith::ConstantOp>(location, i64_vector_type,
                                            dense_values);
}
// Emits an arith.constant holding `values` as a 1-D i32 tensor whose length
// is `values.size()`.
Value CreateI32DenseConst(OpBuilder* builder, ArrayRef<int32_t> values,
                          mlir::Location location) {
  const int num_values = static_cast<int>(values.size());
  auto i32_vector_type =
      RankedTensorType::get(num_values, builder->getIntegerType(32));
  return builder->create<arith::ConstantOp>(
      location, i32_vector_type,
      DenseElementsAttr::get(i32_vector_type, values));
}
// Emits a TFL::NoValueOp (NoneType result, unit value attribute) and returns
// its result. It serves as the placeholder for absent optional operands.
Value CreateNoneValue(OpBuilder* builder, mlir::Location location) {
  auto none_type = builder->getNoneType();
  auto unit_value = builder->getUnitAttr();
  return builder->create<TFL::NoValueOp>(location, none_type, unit_value);
}
// Emits a tf.Transpose of `value_to_transpose` with permutation `perm`.
// `original_type` is the input's type. Result dimension i has the size of
// input dimension perm[i], and the element type is unchanged.
Value Transpose(OpBuilder* builder, Value value_to_transpose,
                SmallVector<int32_t, 4> perm, RankedTensorType original_type,
                mlir::Location location) {
  // tf.Transpose takes the permutation as a 1-D i32 constant operand.
  auto perm_const = CreateI32DenseConst(builder, perm, location);

  SmallVector<int64_t, 8> permuted_shape;
  permuted_shape.reserve(perm.size());
  for (int32_t source_dim : perm) {
    permuted_shape.push_back(original_type.getDimSize(source_dim));
  }
  auto result_type =
      RankedTensorType::get(permuted_shape, original_type.getElementType());
  return builder->create<TF::TransposeOp>(location, result_type,
                                          value_to_transpose, perm_const);
}
// Swaps the two axes of a rank-2 value (permutation [1, 0]).
Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
                  RankedTensorType type, mlir::Location location) {
  return Transpose(builder, value_to_transpose,
                   SmallVector<int32_t, 4>{1, 0}, type, location);
}
// Emits a tf.ReverseV2 that reverses `value_to_reverse` along `axis`.
// The axis is passed as a single-element i32 constant. Reversal keeps the
// shape, so `type` is used directly as the result type.
Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis,
              RankedTensorType type, mlir::Location location) {
  Value axes = CreateI32SplatConst(builder, /*shape=*/{1}, axis, location);
  return builder->create<TF::ReverseV2Op>(location, type, value_to_reverse,
                                          axes);
}
// Returns the shape of `value`, whose type must be a RankedTensorType
// (cast<> asserts otherwise).
ArrayRef<int64_t> GetRankedTensorShape(Value value) {
  auto ranked_type = value.getType().cast<RankedTensorType>();
  return ranked_type.getShape();
}
// Slices `input` with a tf.Slice and returns the result.
//
// `begin_values` / `size_values` give the per-dimension start and extent of
// the slice. `begin_shape` / `size_shape` are only used for their number of
// entries, which sets the length of the 1-D i64 begin/size constant operands
// (see CreateI64DenseConst).
//
// If the requested window does not fit inside `input` (a negative begin, or
// begin + size past the end of a dimension), no slice is emitted. Instead an
// f32 zero constant of the requested slice shape (`size_values`) is returned.
Value SliceRankedTensor(OpBuilder* builder, Value input,
                        ArrayRef<int64_t> begin_shape,
                        ArrayRef<int64_t> begin_values,
                        ArrayRef<int64_t> size_shape,
                        ArrayRef<int64_t> size_values,
                        mlir::Location location) {
  // If the size of the tensor to be sliced from the input overflows
  // the input tensor's dimensions, return 0-valued tensor of the requested
  // shape.
  ArrayRef<int64_t> input_shape = GetRankedTensorShape(input);
  for (int i = 0, end = input_shape.size(); i < end; i++) {
    if (begin_values[i] < 0 ||
        (begin_values[i] + size_values[i] > input_shape[i])) {
      // The requested shape is `size_values`. `size_shape` only describes
      // the operand layout (e.g. {1, num_cols} for weight slices) and would
      // give the placeholder the wrong shape.
      return CreateF32SplatConst(builder, size_values, 0, location);
    }
  }
  // Create a dense constant op for slice's begin
  auto slice_i2c_begin =
      CreateI64DenseConst(builder, begin_shape, begin_values, location);
  // Create a dense constant op for slice's size
  auto slice_i2c_size =
      CreateI64DenseConst(builder, size_shape, size_values, location);
  return builder->create<TF::SliceOp>(
      location,
      RankedTensorType::get(
          size_values,
          input.getType().cast<RankedTensorType>().getElementType()),
      input, slice_i2c_begin, slice_i2c_size);
}
// Emits a tf.StridedSlice of `input` producing a tensor of `output_shape`
// with the input's element type.
// `begin`, `end` and `strides` become 1-D i32 constant operands (created in
// that order). The five masks are attached as i64 integer attributes.
Value CreateStridedSliceOp(mlir::Location loc, ArrayRef<int64_t> output_shape,
                           Value input, ArrayRef<int32_t> begin,
                           ArrayRef<int32_t> end, ArrayRef<int32_t> strides,
                           int64_t begin_mask, int64_t end_mask,
                           int64_t ellipsis_mask, int64_t new_axis_mask,
                           int64_t shrink_axis_mask, OpBuilder* builder) {
  auto element_type =
      input.getType().cast<RankedTensorType>().getElementType();
  auto result_type = RankedTensorType::get(output_shape, element_type);

  auto begin_const = CreateI32DenseConst(builder, begin, loc);
  auto end_const = CreateI32DenseConst(builder, end, loc);
  auto strides_const = CreateI32DenseConst(builder, strides, loc);

  auto as_i64_attr = [builder](int64_t mask) {
    return builder->getI64IntegerAttr(mask);
  };
  return builder->create<TF::StridedSliceOp>(
      loc, result_type, input, begin_const, end_const, strides_const,
      as_i64_attr(begin_mask), as_i64_attr(end_mask),
      as_i64_attr(ellipsis_mask), as_i64_attr(new_axis_mask),
      as_i64_attr(shrink_axis_mask));
}
} // namespace
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() {
SmallVector<int64_t, 2> begin_i2c_values = {0, 0};
input2cell_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_i2c_values,
weight_slice_shape_, weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToInputGate() {
SmallVector<int64_t, 2> begin_i2i_values = {n_cell_, 0};
input2input_ = couple_input_forget_gates_
? none_
: SliceRankedTensor(&builder_, weight_transposed_,
weight_slice_shape_, begin_i2i_values,
weight_slice_shape_,
weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToForgetGate() {
int input_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
SmallVector<int64_t, 2> begin_i2f_values = {input_forget_start, 0};
input2forget_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_i2f_values,
weight_slice_shape_, weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToOutputGate() {
int input_output_start =
couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
SmallVector<int64_t, 2> begin_i2o_values = {input_output_start, 0};
input2output_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_i2o_values,
weight_slice_shape_, weight_slice_size_input_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToCellGate() {
SmallVector<int64_t, 2> begin_rec2c_values = {0, n_input_};
rec2cell_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_rec2c_values,
weight_slice_shape_, weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToInputGate() {
SmallVector<int64_t, 2> begin_rec2i_values = {n_cell_, n_input_};
rec2input_ = couple_input_forget_gates_
? none_
: SliceRankedTensor(&builder_, weight_transposed_,
weight_slice_shape_, begin_rec2i_values,
weight_slice_shape_,
weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToForgetGate() {
int rec_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
SmallVector<int64_t, 2> begin_rec2f_values = {rec_forget_start, n_input_};
rec2forget_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_rec2f_values,
weight_slice_shape_, weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToOutputGate() {
int rec_output_start = couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
SmallVector<int64_t, 2> begin_rec2o_values = {rec_output_start, n_input_};
rec2output_ = SliceRankedTensor(
&builder_, weight_transposed_, weight_slice_shape_, begin_rec2o_values,
weight_slice_shape_, weight_slice_size_recurrent_values_,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToCellGate() {
SmallVector<int64_t, 1> begin_bias2c_values = {0};
bias2cell_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2c_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToInputGate() {
SmallVector<int64_t, 1> begin_bias2i_values = {n_cell_};
bias2input_ =
couple_input_forget_gates_
? none_
: SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2i_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToForgetGate() {
int bias_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
SmallVector<int64_t, 1> begin_bias2f_values = {bias_forget_start};
bias2forget_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2f_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToOutputGate() {
int bias_output_start =
couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
SmallVector<int64_t, 1> begin_bias2o_values = {bias_output_start};
bias2output_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
begin_bias2o_values, bias_slice_shape_,
bias_size_values_, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
SmallVector<int64_t, 2> projection_slice_shape = {
1, num_cols_projection_transposed_};
SmallVector<int64_t, 2> projection_slice_size_values = {n_output_, n_cell_};
SmallVector<int64_t, 2> projection_slice_begin_values = {0, 0};
proj_weight_ =
!projection_
? none_
: SliceRankedTensor(
&builder_, projection_transposed_, projection_slice_shape,
projection_slice_begin_values, projection_slice_shape,
projection_slice_size_values, fused_func_op_.getLoc());
}
// Projection bias: an all-zero f32 vector of length n_output_, or none_ when
// projection_type_ is null.
void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
  if (!projection_type_) {
    proj_bias_ = none_;
    return;
  }
  proj_bias_ = CreateF32SplatConst(&builder_, {n_output_}, /*val=*/0,
                                   fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetInputActivationState() {
input_activation_state_ = CreateF32SplatConst(&builder_, {1, n_output_}, 0,
fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetInputCellState() {
input_cell_state_ =
CreateF32SplatConst(&builder_, {1, n_cell_}, 0, fused_func_op_.getLoc());
}
void ConvertLSTMCellSimpleToFusedLSTM::SetCellLayerNormCoefficients() {
cell_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::SetInputLayerNormCoefficients() {
input_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::SetForgetLayerNormCoefficients() {
forget_layer_norm_coefficients_ = none_;
}
void ConvertLSTMCellSimpleToFusedLSTM::SetOutputLayerNormCoefficients() {
output_layer_norm_coefficients_ = none_;
}
// Materializes every operand of the fused TFL LSTM op and stores each one in
// its member field. The transposed weight, the transposed projection and the
// shared none_ value are created first because the Set* helpers called
// afterwards read weight_transposed_, projection_transposed_ and none_.
void ConvertLSTMCellSimpleToFusedLSTM::GenerateFusedOpOperands() {
// Transpose both weight and projection.
weight_transposed_ =
Transpose2D(&builder_, weight_, weight_type_, fused_func_op_.getLoc());
projection_transposed_ = Transpose2D(&builder_, projection_, projection_type_,
fused_func_op_.getLoc());
// Placeholder for absent optional operands (e.g. the input gate when input
// and forget gates are coupled, or the base class's layer-norm operands).
none_ = CreateNoneValue(&builder_, fused_func_op_.getLoc());
// Extract the input-to-gate weights (cell, input, forget, output) by slicing
// the transposed weight tensor.
SetWeightForInputToCellGate();
SetWeightForInputToInputGate();
SetWeightForInputToForgetGate();
SetWeightForInputToOutputGate();
// Extract the recurrent-to-gate weights by slicing the transposed weight
// tensor.
SetWeightForRecurrentToCellGate();
SetWeightForRecurrentToInputGate();
SetWeightForRecurrentToForgetGate();
SetWeightForRecurrentToOutputGate();
// Extract the per-gate biases by slicing the bias tensor.
SetBiasToCellGate();
SetBiasToInputGate();
SetBiasToForgetGate();
SetBiasToOutputGate();
// Extract projection and set an empty projection bias
SetProjection();
SetProjectionBias();
// Set the variable tensors (zero-initialized activation and cell states).
SetInputActivationState();
SetInputCellState();
// Extract the layer norm coefficients (virtual; the base class uses none_).
SetCellLayerNormCoefficients();
SetInputLayerNormCoefficients();
SetForgetLayerNormCoefficients();
SetOutputLayerNormCoefficients();
}
void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
// https://github.com/tensorflow/community/pull/113
SmallVector<int64_t, 2> output_shape{1, -1};
auto input_types = fused_func_op_.getFunctionType().getInputs();
auto output_type = mlir::RankedTensorType::get(
output_shape, input_.getType().cast<RankedTensorType>().getElementType());
fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(),
input_types, output_type));
}
// Replaces the body of the composite function with a single TFL LSTM op.
// Steps:
//  1. Initialize and validate.
//  2. Retype the function.
//  3. Build every operand.
//  4. Create the fused op with a static [1, n_output_] result.
//  5. tensor.cast that result to the function's [1, ?] return type and
//     return it.
// Returns the failure from Initialize() unchanged if it fails.
LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
  const LogicalResult init_status = Initialize();
  if (failed(init_status)) return init_status;

  // The function will ultimately return the fused op's output.
  UpdateFuncSignature();
  // Weights, projection, bias, states and layer-norm coefficients.
  GenerateFusedOpOperands();

  const mlir::Location loc = fused_func_op_.getLoc();
  auto element_type =
      input_.getType().cast<RankedTensorType>().getElementType();

  SmallVector<int64_t, 2> lstm_output_shape = {1, n_output_};
  auto lstm_result_type =
      mlir::RankedTensorType::get(lstm_output_shape, element_type);
  lstm_ = builder_.create<mlir::TFL::LSTMOp>(
      loc, lstm_result_type, input_, input2input_, input2forget_, input2cell_,
      input2output_, rec2input_, rec2forget_, rec2cell_, rec2output_,
      /*cell_to_input_weights*/ none_,
      /*cell_to_forget_weights*/ none_,
      /*cell_to_output_weights*/ none_, bias2input_, bias2forget_, bias2cell_,
      bias2output_, proj_weight_, proj_bias_, input_activation_state_,
      input_cell_state_, input_layer_norm_coefficients_,
      forget_layer_norm_coefficients_, cell_layer_norm_coefficients_,
      output_layer_norm_coefficients_, builder_.getStringAttr("TANH"),
      builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0),
      mlir::TFL::LSTMKernelTypeAttr::get(builder_.getContext(),
                                         mlir::TFL::LSTMKernelType::FULL),
      /*asymmetric_quantize_inputs=*/mlir::BoolAttr(),
      /*input_to_input_intermediate=*/mlir::TypeAttr(),
      /*input_to_forget_intermediate=*/mlir::TypeAttr(),
      /*input_to_cell_intermediate=*/mlir::TypeAttr(),
      /*input_to_output_intermediate=*/mlir::TypeAttr(),
      /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());

  // The function signature uses an unknown second dimension so that these
  // cells can be stacked; bridge the static result to it with a cast.
  SmallVector<int64_t, 2> func_output_shape = {1, -1};
  auto func_result_type =
      mlir::RankedTensorType::get(func_output_shape, element_type);
  auto cast_op = builder_.create<mlir::tensor::CastOp>(loc, func_result_type,
                                                       lstm_.getResult());
  builder_.create<mlir::func::ReturnOp>(loc, cast_op.getResult());
  return success();
}
// Parses the kTFImplements string attribute of the composite function.
//
// The attribute is a comma-separated list. Its first token must equal
// GetCompositeOpName(). Any later token equal to kCoupleInputForgetGates turns
// on coupled input/forget gates (CIFG). Whitespace around tokens is ignored.
//
// Returns failure (with an error on the function) if the attribute is
// missing or names a different composite op.
LogicalResult ConvertLSTMCellSimpleToFusedLSTM::InitializeFromFuncAttributes() {
  auto attr = fused_func_op_->getAttrOfType<StringAttr>(kTFImplements);
  if (!attr) {
    return fused_func_op_.emitError()
           << "Invalid function attribute, expected " << kTFImplements
           << " attribute "
              "not found";
  }
  // TODO(ashwinm, b/144775479): Make these NamedAttribute on TF import
  // once tf.function can support this.
  llvm::SmallVector<llvm::StringRef, 4> attr_tokens;
  attr.getValue().split(attr_tokens, ",");
  // Trim each token. Otherwise "LSTMCellSimple, CoupleInputForgetGates" would
  // leave " CoupleInputForgetGates", and CIFG would be silently ignored.
  for (llvm::StringRef& token : attr_tokens) token = token.trim();
  if (attr_tokens.empty()) {
    return fused_func_op_.emitError()
           << kTFImplements << " attribute should be set";
  }
  // Check if the interface matches.
  if (GetCompositeOpName().str() != attr_tokens[0]) {
    return fused_func_op_.emitError()
           << "Unexpected interface for the composite op. Expected: "
           << GetCompositeOpName() << " Actual: " << attr_tokens[0];
  }
  // Extract other interface attributes, for now cifg.
  couple_input_forget_gates_ =
      std::find(attr_tokens.begin() + 1, attr_tokens.end(),
                kCoupleInputForgetGates) != attr_tokens.end();
  return success();
}
// Reads the composite function's attributes and arguments and derives the
// sizes used to slice its weight, bias and projection tensors.
//
// Arguments used: 0 = input, 1 = weight (rank 2), 2 = bias, 3 = projection.
// Sets the following:
//  - num_gates_: 3 with coupled input/forget gates, otherwise 4.
//  - n_cell_: weight dim 1 / num_gates_.
//  - n_output_: projection dim 1 when the projection is rank 2, otherwise
//    n_cell_.
//  - n_input_: weight dim 0 - n_output_.
//  - The slice shape/size members.
//
// Returns failure (with an error on the function) instead of asserting when
// the attributes, argument count or argument types are not as expected.
LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
  if (failed(InitializeFromFuncAttributes())) {
    return fused_func_op_.emitError()
           << "Expected function attributes were not set on the function "
              "encapsulating the composite op";
  }
  num_gates_ = couple_input_forget_gates_ ? 3 : 4;

  // getArgument() below reads indices 0..3.
  if (fused_func_op_.getNumArguments() < 4) {
    return fused_func_op_.emitError()
           << "Expected at least 4 arguments (input, weight, bias, projection)";
  }

  // input_ and bias_ are later cast<RankedTensorType>() (result types in
  // UpdateFuncSignature/RewriteFunc, bias slicing in SliceRankedTensor).
  // Reject other types here rather than asserting there.
  input_ = fused_func_op_.getArgument(0);
  if (!input_.getType().isa<RankedTensorType>()) {
    return fused_func_op_.emitError()
           << "The input tensor is not a ranked tensor";
  }
  bias_ = fused_func_op_.getArgument(2);
  if (!bias_.getType().isa<RankedTensorType>()) {
    return fused_func_op_.emitError()
           << "The bias tensor is not a ranked tensor";
  }

  weight_ = fused_func_op_.getArgument(1);
  weight_type_ = weight_.getType().dyn_cast<RankedTensorType>();
  if (!weight_type_ || weight_type_.getRank() != 2) {
    return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
  }
  if (weight_type_.getDimSize(1) % num_gates_ != 0) {
    return fused_func_op_.emitError()
           << "Invalid dimension 1 of weight tensor, "
              "should be divisible by the number of gates";
  }
  n_cell_ = weight_type_.getDimSize(1) / num_gates_;

  projection_ = fused_func_op_.getArgument(3);
  projection_type_ = projection_.getType().dyn_cast<RankedTensorType>();
  if (!projection_type_) {
    return fused_func_op_.emitError()
           << "The projection tensor is not a ranked tensor";
  }
  if (projection_type_.getRank() != 2) {
    n_output_ = n_cell_;
  } else {
    n_output_ = projection_type_.getDimSize(1);
  }
  n_input_ = weight_type_.getDimSize(0) - n_output_;
  num_cols_weight_transposed_ = weight_type_.getDimSize(0);
  // getDimSize(0) requires rank >= 1; a rank-0 projection used to assert here.
  // NOTE(review): Transpose2D (permutation {1, 0}) is still applied to the
  // projection in GenerateFusedOpOperands, so only rank 2 is fully handled.
  num_cols_projection_transposed_ =
      projection_type_.getRank() > 0 ? projection_type_.getDimSize(0) : 0;
  bias_slice_shape_ = {n_cell_};
  bias_size_values_ = {n_cell_};
  weight_slice_shape_ = {1, num_cols_weight_transposed_};
  weight_slice_size_input_values_ = {n_cell_, n_input_};
  weight_slice_size_recurrent_values_ = {n_cell_, n_output_};
  return success();
}
// Runs the base initialization, then reads the layer-norm scale from
// argument 4, which must be a rank-1 ranked tensor. It also sets the
// per-gate layer-norm slice shape and size to {n_cell_}.
//
// Returns failure (with an error on the function) if any check fails.
LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
  if (failed(ConvertLSTMCellSimpleToFusedLSTM::Initialize())) {
    return fused_func_op_.emitError()
           << "Specified LayerNormalizedLSTMCellSimple was not of the expected "
              "interface and cannot be converted to the fused LSTM op";
  }
  // getArgument(4) asserts when the function has fewer than 5 arguments.
  if (fused_func_op_.getNumArguments() < 5) {
    return fused_func_op_.emitError()
           << "Expected a layer_norm_scale argument at index 4";
  }
  layer_norm_scale_ = fused_func_op_.getArgument(4);
  layer_norm_scale_type_ =
      layer_norm_scale_.getType().dyn_cast<RankedTensorType>();
  if (!layer_norm_scale_type_ || layer_norm_scale_type_.getRank() != 1) {
    return fused_func_op_.emitError()
           << "The layer_norm_scale tensor was not of rank 1";
  }
  layer_norm_slice_shape_ = {n_cell_};
  layer_norm_size_values_ = {n_cell_};
  return success();
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetCellLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_cell_layer_norm_values = {0};
cell_layer_norm_coefficients_ =
SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_cell_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetInputLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_input_layer_norm_values = {n_cell_};
input_layer_norm_coefficients_ =
couple_input_forget_gates_
? none_
: SliceRankedTensor(
&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_input_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetForgetLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_forget_layer_norm_values = {2 * n_cell_};
forget_layer_norm_coefficients_ =
SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_forget_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetOutputLayerNormCoefficients() {
SmallVector<int64_t, 1> begin_output_layer_norm_values = {3 * n_cell_};
output_layer_norm_coefficients_ =
SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
begin_output_layer_norm_values, layer_norm_slice_shape_,
layer_norm_size_values_, fused_func_op_.getLoc());
}
// Emits a tf.Const holding `value` as a 1-D i32 tensor of length
// value.size().
TF::ConstOp Create1DConstantOp(const std::vector<int>& value, Location loc,
                               OpBuilder* builder) {
  const int64_t length = static_cast<int64_t>(value.size());
  auto i32_vector_type =
      mlir::RankedTensorType::get({length}, builder->getIntegerType(32));
  return builder->create<TF::ConstOp>(
      loc, mlir::DenseIntElementsAttr::get(i32_vector_type, value));
}
// Emits a tf.Const holding the scalar i32 `value`.
TF::ConstOp CreateScalarConstantOp(int value, Location loc,
                                   OpBuilder* builder) {
  auto scalar_attr = builder->getI32IntegerAttr(value);
  return builder->create<TF::ConstOp>(loc, scalar_attr);
}
// Emits a tf.SplitV that cuts `input` along `axis` into `splits` pieces of
// equal size and stores the created op in `*result`.
//
// Returns failure without creating any op in these cases:
//  - `input` is not a ranked tensor.
//  - `axis` does not name one of its dimensions.
//  - `splits` is not positive.
//  - The size of dimension `axis` is not a non-negative multiple of `splits`.
LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits,
                                      Location loc, OpBuilder* builder,
                                      Operation** result) {
  auto input_type = input.getType().dyn_cast<RankedTensorType>();
  if (!input_type) return failure();
  // axis == rank used to pass the old `rank < axis` check and left the split
  // size uninitialized.
  if (axis < 0 || axis >= input_type.getRank()) return failure();
  // Guards the division and modulo below.
  if (splits <= 0) return failure();

  // Keep the full 64-bit size: a dynamic dimension is reported by MLIR as a
  // negative sentinel and cannot be split evenly at compile time.
  const int64_t split_dim = input_type.getDimSize(axis);
  if (split_dim < 0 || split_dim % splits != 0) return failure();
  const int size_of_splits = static_cast<int>(split_dim / splits);

  // Every piece has the input's shape with dimension `axis` shrunk.
  SmallVector<int64_t, 4> output_shape(input_type.getShape().begin(),
                                       input_type.getShape().end());
  output_shape[axis] = size_of_splits;
  SmallVector<mlir::Type, 4> output_types;
  for (int i = 0; i < splits; ++i) {
    output_types.push_back(
        mlir::RankedTensorType::get(output_shape, input_type.getElementType()));
  }

  // size_splits needs exactly one entry per output, i.e. `splits` entries
  // (it used to be hard-coded to four).
  auto size_of_splits_op = Create1DConstantOp(
      std::vector<int>(splits, size_of_splits), loc, builder);
  auto axis_op = CreateScalarConstantOp(axis, loc, builder);
  *result = builder->create<TF::SplitVOp>(loc, output_types, input,
                                          size_of_splits_op.getResult(),
                                          axis_op.getResult());
  return success();
}
// TODO(b/147436982): Consider refactor this to be more general.
// Rewrites the body of `func_op`, a Keras LSTM layer function, into a single
// tfl.unidirectional_sequence_lstm op.
//
// Arguments (see standard_lstm in
// tensorflow/python/keras/layers/recurrent_v2.py for the order):
//   input, initial output (activation) state, initial hidden (cell) state,
//   kernel, recurrent kernel, bias.
// The function must have 5 results and a "tf.time_major" BoolAttr.
// "tf.go_backwards" is optional and reverses the input along the time axis.
//
// All validation happens before any op is built, so a rejected function gets
// no partially built ops. New ops are created at `builder`'s insertion point,
// and the function type is updated to match the returned values:
//   - last output
//   - full output sequence
//   - three zero constants for the remaining results
LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op,
                                    OpBuilder* builder) {
  // For argument order, please check out standard_lstm under
  // tensorflow/python/keras/layers/recurrent_v2.py
  // getArgument(5) below needs at least 6 arguments.
  if (func_op.getNumArguments() < 6) return failure();
  Value input = func_op.getArgument(0);
  Value output_init_state = func_op.getArgument(1);
  Value hidden_init_state = func_op.getArgument(2);
  Value weight_kernel = func_op.getArgument(3);
  Value recurrent_kernel = func_op.getArgument(4);
  Value bias = func_op.getArgument(5);

  // The func op should have 5 outputs.
  if (func_op.getNumResults() != 5) return failure();
  // Outputs 2-4 (states, device) are rebuilt below as zero constants of their
  // declared types, so those types must be ranked tensors.
  for (int i = 2; i < 5; ++i) {
    if (!func_op.getCallableResults()[i].isa<RankedTensorType>())
      return failure();
  }

  // The sequence layout comes from "tf.time_major" and is forwarded to the
  // fused op's time_major attribute.
  auto time_major_attr = func_op->getAttrOfType<BoolAttr>("tf.time_major");
  if (time_major_attr == nullptr) return failure();
  bool time_majored = time_major_attr.getValue();

  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
  if (!input_type) {
    func_op.emitError() << "Input type is not a ranked tensor type";
    return failure();
  }
  // The batch/time lookup, the 3-D output shape and the 3-element
  // strided-slice parameters below all assume a rank-3 input.
  if (input_type.getRank() != 3)
    return func_op.emitError() << "The input should be rank of 3";

  // Both kernels are transposed with Transpose2D and split along dim 0.
  // The bias is split as a ranked tensor.
  auto weight_type = weight_kernel.getType().dyn_cast<RankedTensorType>();
  if (!weight_type || weight_type.getRank() != 2)
    return func_op.emitError() << "The weight should be rank of 2";
  auto recurrent_kernel_type =
      recurrent_kernel.getType().dyn_cast<RankedTensorType>();
  if (!recurrent_kernel_type || recurrent_kernel_type.getRank() != 2)
    return func_op.emitError() << "The recurrent weight should be rank of 2";
  if (!bias.getType().isa<RankedTensorType>())
    return func_op.emitError() << "The bias should be a ranked tensor";

  auto final_inputs = input;
  auto final_input_type = input_type;
  // Handle go_backwards:
  // LSTM in Keras semantic will reverse the input sequence if it's go_backwards
  auto go_backwards_attr = func_op->getAttrOfType<BoolAttr>("tf.go_backwards");
  if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
    int time_dim = time_majored ? 0 : 1;
    final_inputs = Reverse(builder, final_inputs, time_dim, final_input_type,
                           func_op.getLoc());
  }
  // Kept as int64_t so dynamic-size sentinels are not truncated.
  int64_t batch = time_majored ? final_input_type.getDimSize(1)
                               : final_input_type.getDimSize(0);
  int64_t time = time_majored ? final_input_type.getDimSize(0)
                              : final_input_type.getDimSize(1);

  // Setup correct weights.
  Value transposed_weight_kernel =
      Transpose2D(builder, weight_kernel, weight_type, func_op.getLoc());
  const int64_t n_output = recurrent_kernel_type.getDimSize(0);
  Value transpose_recurrent_kernel = Transpose2D(
      builder, recurrent_kernel, recurrent_kernel_type, func_op.getLoc());

  // Splits the weights into 4: i, f, c, o.
  const int splits = 4;
  Operation* weights_array;
  if (failed(CreateEqualSizeSplitVOp(transposed_weight_kernel, 0, splits,
                                     func_op.getLoc(), builder,
                                     &weights_array)))
    return failure();
  // Splits the recurrent_weights into 4:
  Operation* recurrent_weights_array;
  if (failed(CreateEqualSizeSplitVOp(transpose_recurrent_kernel, 0, splits,
                                     func_op.getLoc(), builder,
                                     &recurrent_weights_array)))
    return failure();
  // Splits the bias into 4:
  Operation* bias_array;
  if (failed(CreateEqualSizeSplitVOp(bias, 0, splits, func_op.getLoc(), builder,
                                     &bias_array)))
    return failure();

  // Build the lstm op.
  SmallVector<int64_t, 3> output_shape;
  if (time_majored) {
    output_shape = {time, batch, n_output};
  } else {
    output_shape = {batch, time, n_output};
  }
  auto result_type = mlir::RankedTensorType::get(
      output_shape,
      final_inputs.getType().cast<RankedTensorType>().getElementType());
  Value none = CreateNoneValue(builder, func_op.getLoc());
  auto lstm = builder->create<mlir::TFL::UnidirectionalSequenceLSTMOp>(
      func_op.getLoc(), result_type, /*input=*/final_inputs,
      /*input_to_input_weights=*/weights_array->getResult(0),
      /*input_to_forget_weights=*/weights_array->getResult(1),
      /*input_to_cell_weights=*/weights_array->getResult(2),
      /*input_to_output_weights=*/weights_array->getResult(3),
      /*recurrent_to_input_weights=*/recurrent_weights_array->getResult(0),
      /*recurrent_to_forget_weights=*/recurrent_weights_array->getResult(1),
      /*recurrent_to_cell_weights=*/recurrent_weights_array->getResult(2),
      /*recurrent_to_output_weights=*/recurrent_weights_array->getResult(3),
      /*cell_to_input_weights=*/none,
      /*cell_to_forget_weights=*/none,
      /*cell_to_output_weights=*/none,
      /*input_gate_bias=*/bias_array->getResult(0),
      /*forget_gate_bias=*/bias_array->getResult(1),
      /*cell_bias=*/bias_array->getResult(2),
      /*output_gate_bias=*/bias_array->getResult(3),
      /*projection_weights=*/none,
      /*projection_bias=*/none,
      /*input_activation_state=*/output_init_state,
      /*input_cell_state=*/hidden_init_state,
      /*input_layer_norm_coefficients=*/none,
      /*forget_layer_norm_coefficients=*/none,
      /*cell_layer_norm_coefficients=*/none,
      /*output_layer_norm_coefficients=*/none,
      /*fused_activation_function*/ builder->getStringAttr("TANH"),
      /*cell_clip*/ builder->getF32FloatAttr(10.0),
      /*proj_clip*/ builder->getF32FloatAttr(0.0),
      /*time_major*/ builder->getBoolAttr(time_majored),
      /*asymmetric_quantize_inputs=*/mlir::BoolAttr(),
      /*input_to_input_intermediate=*/mlir::TypeAttr(),
      /*input_to_forget_intermediate=*/mlir::TypeAttr(),
      /*input_to_cell_intermediate=*/mlir::TypeAttr(),
      /*input_to_output_intermediate=*/mlir::TypeAttr(),
      /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());

  auto final_output_full_sequences = lstm.getResult();

  // Populate the last output: last output is sliced from the full sequences.
  // If time_major: last_output = outputs[-1, :, :]
  // else: last_output = outputs[:, -1, :]
  //
  // As we are creating the strided_slice op, we need to populate the following
  // fields:
  // end: should always be (0, 0, 0)
  // strides: should always be (1, 1, 1)
  // begin: should be (0, -1, 0), or (-1, 0, 0) if it's time-majored.
  // new_axis_mask: should always be 0.
  // ellipsis_mask: should always be 0.
  // begin_mask & end_mask: should be 0b101 = 5, or 0b110 = 6 if it's
  // time-majored.
  // shrink_axis_mask: should be 0b010 = 2, or 0b001 = 1 if it's
  // time-majored.
  SmallVector<int64_t, 2> last_output_shape({batch, n_output});

  SmallVector<int32_t, 3> end({0, 0, 0});
  SmallVector<int32_t, 3> strides({1, 1, 1});
  SmallVector<int32_t, 3> begin;

  int64_t new_axis_mask = 0;
  int64_t ellipsis_mask = 0;
  int64_t begin_mask;
  int64_t end_mask;
  int64_t shrink_axis_mask;
  if (time_majored) {
    begin_mask = 6;
    end_mask = 6;
    shrink_axis_mask = 1;
    begin = {-1, 0, 0};
  } else {
    begin_mask = 5;
    end_mask = 5;
    shrink_axis_mask = 2;
    begin = {0, -1, 0};
  }

  auto last_output = CreateStridedSliceOp(
      func_op.getLoc(), last_output_shape, final_output_full_sequences, begin,
      end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
      shrink_axis_mask, builder);

  SmallVector<Value, 5> outputs;
  SmallVector<Type, 5> output_types;

  // Due to the existence of the while loop, the timestamp may be unknown
  // for the signature, for us, since we know the inputs, we can infer the time
  // steps.

  // Last output.
  outputs.push_back(last_output);
  output_types.push_back(last_output.getType());

  // Full sequences.
  outputs.push_back(final_output_full_sequences);
  output_types.push_back(final_output_full_sequences.getType());

  // All the rest: states, device. Their types were checked to be ranked above.
  for (int i = 2; i < 5; ++i) {
    auto state_type =
        func_op.getCallableResults()[i].cast<RankedTensorType>();
    outputs.push_back(CreatTfF32ConstOp(builder, state_type.getShape(), 0.0f,
                                        func_op.getLoc()));
    output_types.push_back(state_type);
  }

  // Update function signatures.
  func_op.setType(mlir::FunctionType::get(func_op.getContext(),
                                          func_op.getFunctionType().getInputs(),
                                          output_types));

  builder->create<mlir::func::ReturnOp>(func_op.getLoc(), outputs);
  return success();
}
} // namespace TFL
} // namespace mlir