blob: 4268d484fc06f5d9703f7a13e19d54ee5699fc17 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TF {
namespace collection_ops_util {
// Builds a TF::ConstOp at `loc` holding `value` as a rank-0 (scalar) i32
// tensor and returns it as a Value.
Value CreateScalarConst(int32_t value, OpBuilder builder, Location loc) {
  auto scalar_type = RankedTensorType::get({}, builder.getI32Type());
  auto scalar_attr = DenseIntElementsAttr::get(scalar_type, value);
  return builder.create<TF::ConstOp>(loc, scalar_attr);
}
// Builds a TF::ConstOp at `loc` holding the entries of `r1` as a rank-1 tensor
// whose element type is an integer of `bitwidth` bits. The result has one
// element per entry of `r1`.
Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
                 int bitwidth) {
  const int64_t num_elements = r1.size();
  llvm::SmallVector<APInt, 4> elements;
  elements.reserve(num_elements);
  for (int64_t entry : r1) elements.push_back(APInt(bitwidth, entry));
  auto element_type = IntegerType::get(builder.getContext(), bitwidth);
  auto tensor_type = RankedTensorType::get({num_elements}, element_type);
  return builder.create<TF::ConstOp>(
      loc, DenseElementsAttr::get(tensor_type, elements));
}
// Returns the start-indices operand that addresses element `index` of
// `buffer` along dimension 0. For a rank-1 buffer this is `index` itself;
// otherwise `index` is concatenated with one zero per remaining buffer
// dimension, giving a rank-1 tensor with as many entries as the buffer has
// dimensions. `buffer` must have a ranked tensor type.
Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder,
                           Location loc) {
  const int64_t buffer_rank =
      buffer.getType().cast<RankedTensorType>().getShape().size();
  if (buffer_rank == 1) return index;
  // Pad the index with zeros so the trailing dimensions start at offset 0.
  llvm::SmallVector<int64_t, 8> trailing_zeros(buffer_rank - 1, 0);
  Value zeros_const = GetR1Const(trailing_zeros, builder, loc);
  // The last ConcatV2 operand is the scalar 0 (concatenate along axis 0).
  Value concat_axis = CreateScalarConst(0, builder, loc);
  auto indices_type = RankedTensorType::get(
      {buffer_rank}, getElementTypeOrSelf(index.getType()));
  return builder.create<TF::ConcatV2Op>(
      loc, ArrayRef<Type>{indices_type},
      ArrayRef<Value>{index, zeros_const, concat_axis});
}
// Reads the element at position `index` (along dimension 0) from `buffer`.
// A TF::SliceOp of extent 1 in dimension 0 is emitted first. When
// `keep_slice_shape` is true that slice is returned as is; otherwise it is
// reshaped to drop the leading size-1 dimension. `buffer` must have a ranked
// tensor type.
Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc,
                 bool keep_slice_shape) {
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  ArrayRef<int64_t> buffer_shape = buffer_type.getShape();
  Type dtype = buffer_type.getElementType();

  // The slice spans the full buffer in every dimension except the first.
  llvm::SmallVector<int64_t, 8> slice_shape(buffer_shape.begin(),
                                            buffer_shape.end());
  slice_shape[0] = 1;
  Value slice_size_const = GetR1Const(slice_shape, builder, loc);
  Value slice_begin = GetIndicesForElement(index, buffer, builder, loc);
  auto slice = builder.create<TF::SliceOp>(
      loc, ArrayRef<Type>{RankedTensorType::get(slice_shape, dtype)},
      ArrayRef<Value>{buffer, slice_begin, slice_size_const});
  if (keep_slice_shape) return slice;

  // Strip the leading dimension of size 1 to obtain the element shape.
  auto element_type = RankedTensorType::get(buffer_shape.drop_front(), dtype);
  Value element_shape_const =
      GetR1Const(element_type.getShape(), builder, loc);
  auto reshape = builder.create<TF::ReshapeOp>(
      loc, ArrayRef<Type>{element_type},
      ArrayRef<Value>{slice, element_shape_const});
  return reshape.output();
}
// Writes `element` into `buffer` at position `index` (along dimension 0) and
// returns the updated buffer value. If the type of `element` is not already
// the buffer's shape with the leading dimension replaced by 1, it is first
// reshaped to that type; the write itself is a TF::XlaDynamicUpdateSliceOp.
// `buffer` must have a ranked tensor type.
Value SetElement(Value index, Value buffer, Value element, OpBuilder builder,
                 Location loc) {
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  llvm::SmallVector<int64_t, 8> update_shape =
      llvm::to_vector<8>(buffer_type.getShape());
  update_shape[0] = 1;
  auto update_type =
      RankedTensorType::get(update_shape, buffer_type.getElementType());

  // Add the leading size-1 dimension when the element does not carry it yet.
  Value update = element;
  if (update.getType() != update_type) {
    update = builder.create<TF::ReshapeOp>(
        loc, ArrayRef<Type>{update_type},
        ArrayRef<Value>{element, GetR1Const(update_shape, builder, loc)});
  }

  Value start_indices = GetIndicesForElement(index, buffer, builder, loc);
  auto updated = builder.create<TF::XlaDynamicUpdateSliceOp>(
      loc, ArrayRef<Type>{buffer.getType()},
      ArrayRef<Value>{buffer, update, start_indices});
  return updated.output();
}
// Returns the tensor type used for size values in this file: a rank-1 tensor
// with a single 32-bit integer element.
TensorType GetSizeType(OpBuilder builder) {
  auto i32_type = builder.getI32Type();
  return RankedTensorType::get({1}, i32_type);
}
// Emits a TF::ReshapeOp that converts `scalar` into the type returned by
// GetSizeType() and returns the reshaped value.
Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc) {
  TensorType target_type = GetSizeType(builder);
  Value target_shape = GetR1Const(target_type.getShape(), builder, loc);
  return builder.create<TF::ReshapeOp>(
      loc, ArrayRef<Type>{target_type},
      ArrayRef<Value>{scalar, target_shape});
}
// Overload taking the maximum element count as a Value. `max_size` must be
// produced by a TF::ConstOp; its first element is read as a signed integer
// and forwarded to the int64_t overload. Otherwise an "unknown max element
// count" error is emitted on `op` and failure is returned.
LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
                                    Value max_size, Operation* op,
                                    Type element_dtype, OpBuilder builder,
                                    Value* buffer) {
  // Covers both "no defining op" (e.g. a block argument) and "defining op is
  // not a constant".
  auto size_const =
      llvm::dyn_cast_or_null<TF::ConstOp>(max_size.getDefiningOp());
  if (!size_const) return op->emitOpError("unknown max element count");
  const APInt first_value = *size_const.value().getValues<APInt>().begin();
  int64_t static_max_size = first_value.getSExtValue();
  return CreateInitBufferValue(element_shape, static_max_size, op,
                               element_dtype, builder, buffer);
}
// Creates a zero-initialized buffer of shape [max_size, element_shape...] with
// element type `element_dtype` and stores it in `*buffer`. The buffer is built
// by broadcasting a scalar zero (cast to `element_dtype` when that is not i32)
// to the full buffer shape. All ops are created at `op`'s location. Always
// returns success.
LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
                                    int64_t max_size, Operation* op,
                                    Type element_dtype, OpBuilder builder,
                                    Value* buffer) {
  llvm::SmallVector<int64_t, 8> full_shape;
  full_shape.reserve(element_shape.size() + 1);
  full_shape.push_back(max_size);
  full_shape.append(element_shape.begin(), element_shape.end());

  const Location loc = op->getLoc();
  Value init = CreateScalarConst(0, builder, loc);
  // CreateScalarConst yields i32; convert when the buffer uses another dtype.
  if (getElementTypeOrSelf(init.getType()) != element_dtype) {
    init = builder.create<TF::CastOp>(
        loc, ArrayRef<Type>{RankedTensorType::get({}, element_dtype)},
        ArrayRef<Value>{init});
  }

  auto full_type = RankedTensorType::get(full_shape, element_dtype);
  Value shape_const = GetR1Const(full_shape, builder, loc);
  auto broadcast = builder.create<TF::BroadcastToOp>(
      loc, ArrayRef<Type>{full_type}, ArrayRef<Value>{init, shape_const});
  *buffer = broadcast.output();
  return success();
}
// Tries to infer the element type of `collection` from the operations that use
// it. Uses are visited in order and the first successful inference wins:
//   * tf.While: recurse into the matching argument of the body function.
//   * tf.If: recurse into the matching argument of the then and else
//     functions (operand number minus one; presumably because operand 0 of
//     tf.If is the condition).
//   * Calls (CallOpInterface): recurse into the matching argument of the
//     callee, when it resolves to a func::FuncOp that has a body.
//   * tf.Identity / tf.IdentityN: recurse into the corresponding result.
//   * Anything else: ask `infer_from_op`; the answer is accepted only if it is
//     a RankedTensorType with a static shape.
// `module` is only forwarded to the recursive calls. Returns llvm::None when
// no use yields a type.
llvm::Optional<RankedTensorType> GetElementTypeFromAccess(
    Value collection, ModuleOp module,
    llvm::function_ref<llvm::Optional<Type>(Operation*)> infer_from_op) {
  for (auto& use : collection.getUses()) {
    Operation* user = use.getOwner();
    if (auto while_op = llvm::dyn_cast<TF::WhileOp>(user)) {
      auto body = while_op.body_function();
      assert(body);
      auto type_from_body = GetElementTypeFromAccess(
          body.getArgument(use.getOperandNumber()), module, infer_from_op);
      if (type_from_body.hasValue()) return type_from_body;
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(user)) {
      auto then_branch = if_op.then_function();
      auto else_branch = if_op.else_function();
      assert(then_branch && else_branch);
      auto type_from_then = GetElementTypeFromAccess(
          then_branch.getArgument(use.getOperandNumber() - 1), module,
          infer_from_op);
      if (type_from_then.hasValue()) return type_from_then;
      auto type_from_else = GetElementTypeFromAccess(
          else_branch.getArgument(use.getOperandNumber() - 1), module,
          infer_from_op);
      if (type_from_else.hasValue()) return type_from_else;
    } else if (auto call = llvm::dyn_cast<CallOpInterface>(user)) {
      // The callable may fail to resolve (null), may not be a func::FuncOp, or
      // may be a declaration without a body. In all of those cases there are
      // no block arguments to look at, so skip this use instead of crashing.
      auto callee =
          llvm::dyn_cast_or_null<func::FuncOp>(call.resolveCallable());
      if (!callee || callee.isExternal()) continue;
      auto type_from_callee = GetElementTypeFromAccess(
          callee.getArgument(use.getOperandNumber()), module, infer_from_op);
      if (type_from_callee.hasValue()) return type_from_callee;
    } else if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(user)) {
      auto type_from_alias = GetElementTypeFromAccess(
          user->getResult(use.getOperandNumber()), module, infer_from_op);
      if (type_from_alias.hasValue()) return type_from_alias;
    } else if (auto type = infer_from_op(user)) {
      // `type` holds a value here; tolerate a null Type inside it as well.
      auto elem_type = type->dyn_cast_or_null<RankedTensorType>();
      if (elem_type && elem_type.hasStaticShape()) return elem_type;
    }
  }
  return llvm::None;
}
// Emits a TF::ReadVariableOp on `local_var` and returns the value read. The
// result type is the first subtype of the variable's resource element type, so
// `local_var` must be a resource with at least one subtype.
Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) {
  auto resource_type =
      getElementTypeOrSelf(local_var.getType()).cast<TF::ResourceType>();
  Type value_type = resource_type.getSubtypes()[0];
  auto read = builder.create<TF::ReadVariableOp>(
      loc, ArrayRef<Type>{value_type}, ArrayRef<Value>{local_var});
  return read.value();
}
// Emits a TF::AssignVariableOp that stores `value` into `local_var` and
// returns the created op (it has no results).
TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value,
                                        OpBuilder builder, Location loc) {
  auto assign = builder.create<TF::AssignVariableOp>(
      loc, ArrayRef<Type>{}, ArrayRef<Value>{local_var, value});
  return assign;
}
// Combines buffers `a` and `b` element-wise into a value of `a`'s type:
// logical OR when `a` has i1 elements, addition (TF::AddV2Op) otherwise.
Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc) {
  const Type result_type = a.getType();
  const bool is_boolean =
      getElementTypeOrSelf(result_type) == builder.getI1Type();
  if (!is_boolean) {
    return builder.create<TF::AddV2Op>(loc, ArrayRef<Type>{result_type},
                                       ArrayRef<Value>{a, b});
  }
  return builder.create<TF::LogicalOrOp>(loc, ArrayRef<Type>{result_type},
                                         ArrayRef<Value>{a, b});
}
namespace {
// Returns the first index when `indices` is a ranked-tensor TF::ConstOp whose
// values start at a non-negative number and each increase by exactly one
// (e.g. [3, 4, 5] -> 3). Returns -1 in every other case: `indices` is not
// ranked, is not defined by a constant, is empty, starts with a negative
// value, or is not contiguous.
//
// An explicit flag tracks whether the first value has been seen. Using -1 as
// the "unset" marker would confuse it with a real index value of -1 and
// restart tracking in the middle of a sequence such as [-2, -1, 0].
int64_t GetFirstIfIndicesAreContiguous(Value indices) {
  auto type = indices.getType().dyn_cast<RankedTensorType>();
  if (!type) return -1;
  auto const_op = llvm::dyn_cast_or_null<TF::ConstOp>(indices.getDefiningOp());
  if (!const_op) return -1;
  bool seen_first = false;
  int64_t first_index = -1;
  int64_t expected_next = -1;
  for (const auto& ind : const_op.value().getValues<APInt>()) {
    const int64_t current = ind.getSExtValue();
    if (!seen_first) {
      // Callers treat any negative result as "not contiguous", so a negative
      // start can never be reported as a valid first index.
      if (current < 0) return -1;
      seen_first = true;
      first_index = current;
    } else if (current != expected_next) {
      return -1;
    }
    expected_next = current + 1;
  }
  return first_index;
}
}  // namespace
// Gathers the elements of `buffer` selected by `indices` along dimension 0.
// The result has the buffer's shape with dimension 0 replaced by the size of
// dimension 0 of `indices`. When `indices` is a constant run of contiguous
// indices a single TF::SliceOp is emitted; otherwise a TF::GatherV2Op with a
// scalar 0 as its last operand is used. Both `indices` and `buffer` must have
// ranked tensor types.
Value GatherElements(Value indices, Value buffer, OpBuilder builder,
                     Location loc) {
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  llvm::SmallVector<int64_t, 8> gathered_shape =
      llvm::to_vector<8>(buffer_type.getShape());
  gathered_shape[0] =
      indices.getType().cast<RankedTensorType>().getDimSize(0);
  auto gathered_type =
      RankedTensorType::get(gathered_shape, buffer_type.getElementType());

  const int64_t contiguous_start = GetFirstIfIndicesAreContiguous(indices);
  if (contiguous_start < 0) {
    // General case: indices are not known to be contiguous.
    return builder.create<TF::GatherV2Op>(
        loc, ArrayRef<Type>{gathered_type},
        ArrayRef<Value>{buffer, indices, CreateScalarConst(0, builder, loc)});
  }

  // Contiguous case: slice starting at the first index, offset 0 elsewhere.
  llvm::SmallVector<int64_t, 8> slice_begin(gathered_shape.size(), 0);
  slice_begin[0] = contiguous_start;
  return builder.create<TF::SliceOp>(
      loc, ArrayRef<Type>{gathered_type},
      ArrayRef<Value>{buffer, GetR1Const(slice_begin, builder, loc),
                      GetR1Const(gathered_shape, builder, loc)});
}
// Accumulates the rows of `updates` into `buffer` at the positions given by
// `indices` and returns the resulting buffer value. Accumulation uses
// AccumulateBuffers(). `buffer` and `updates` must have ranked tensor types.
//
// Fast path: when `indices` is a constant contiguous run starting at 0 and
// `updates` has exactly the buffer's type, the two are accumulated in one op.
// Otherwise the update is unrolled: one read/accumulate/write sequence is
// emitted per row of `updates` (size of its dimension 0).
Value ScatterAccumulateElements(Value indices, Value updates, Value buffer,
                                OpBuilder builder, Location loc) {
  auto buffer_type = buffer.getType().cast<RankedTensorType>();
  auto updates_type = updates.getType().cast<RankedTensorType>();
  const int64_t contiguous_start = GetFirstIfIndicesAreContiguous(indices);
  const bool covers_whole_buffer =
      contiguous_start == 0 && buffer_type == updates_type;
  if (covers_whole_buffer) {
    return AccumulateBuffers(buffer, updates, builder, loc);
  }

  // A TensorScatterUpdate is not sufficient here because it overwrites rather
  // than accumulates, and adding the old data manually is fragile since the
  // indices may contain duplicates. Like the old bridge, walk the indices one
  // by one instead.
  llvm::SmallVector<int64_t, 8> row_shape =
      llvm::to_vector<8>(buffer_type.getShape());
  row_shape[0] = 1;
  Value row_sizes = GetR1Const(row_shape, builder, loc);
  llvm::SmallVector<int64_t, 8> update_begin(buffer_type.getRank(), 0);
  const int64_t num_rows = updates_type.getDimSize(0);

  Value result = buffer;
  for (int64_t row = 0; row < num_rows; ++row) {
    // Extract the destination index for this row as a size-1 tensor.
    Value row_index = builder.create<TF::SliceOp>(
        loc, ArrayRef<Type>{GetSizeType(builder)},
        ArrayRef<Value>{indices, GetR1Const({row}, builder, loc),
                        GetR1Const({1}, builder, loc)});
    Value current = GetElement(row_index, result, builder, loc,
                               /*keep_slice_shape=*/true);

    // Extract the matching row of `updates`.
    update_begin[0] = row;
    Value update_begin_const = GetR1Const(update_begin, builder, loc);
    Value update_row =
        builder
            .create<TF::SliceOp>(
                loc, ArrayRef<Type>{current.getType()},
                ArrayRef<Value>{updates, update_begin_const, row_sizes})
            .output();

    Value accumulated = AccumulateBuffers(current, update_row, builder, loc);
    result = SetElement(row_index, result, accumulated, builder, loc);
  }
  return result;
}
} // namespace collection_ops_util
} // namespace TF
} // namespace mlir