| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/None.h" |
| #include "llvm/ADT/Optional.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringMap.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Location.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/SymbolTable.h" // from @llvm-project |
| #include "mlir/IR/TypeUtilities.h" // from @llvm-project |
| #include "mlir/IR/Types.h" // from @llvm-project |
| #include "mlir/IR/Value.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/framework/tensor_shape.pb.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/platform/types.h" |
| |
| namespace mlir { |
| namespace { |
| |
| namespace cutil = TF::collection_ops_util; |
| |
| using std::string; |
| |
| // A pass that converts tensor array operations to tensor operations and |
| // read/assign ops on local variables. A later resource lifting pass can further |
| // remove the local variables. |
| // |
| // This pass requires that the full shape of the tensor array can be inferred: |
| // 1) the size needs to be a constant, 2) it specifies the full element shape, |
| // or that can be inferred from a later write, and 3) all elements have the same |
| // shape. |
| // |
| struct TensorArrayOpsDecompositionPass |
| : public TF::TensorArrayOpsDecompositionPassBase< |
| TensorArrayOpsDecompositionPass> { |
| void runOnOperation() override; |
| }; |
| |
| // Infers the element type and count for a TensorArraySplitV3Op. Requires |
| // constant lengths and static shape on the input value. |
| LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split, |
| RankedTensorType* elem_type, |
| int64_t* count) { |
| auto lengths_const = |
| llvm::dyn_cast_or_null<TF::ConstOp>(split.lengths().getDefiningOp()); |
| if (!lengths_const) return split.emitOpError("non-constant split lengths"); |
| *count = lengths_const.value().getNumElements(); |
| if (*count <= 0) return split.emitOpError("non-positive split count"); |
| auto buffer_type = split.value().getType().dyn_cast<RankedTensorType>(); |
| if (!buffer_type || !buffer_type.hasStaticShape() || |
| buffer_type.getRank() < 1) { |
| return split.emitOpError("unknown or invalid split tensor shape"); |
| } |
| int64_t length = buffer_type.getDimSize(0) / *count; |
| for (const auto& len : lengths_const.value().getValues<APInt>()) { |
| if (length == len.getSExtValue()) continue; |
| return split.emitOpError("different split lengths are not supported"); |
| } |
| llvm::SmallVector<int64_t, 8> elem_shape; |
| elem_shape.push_back(length); |
| for (int64_t dim : buffer_type.getShape().drop_front()) { |
| elem_shape.push_back(dim); |
| } |
| *elem_type = RankedTensorType::get(elem_shape, buffer_type.getElementType()); |
| return success(); |
| } |
| |
| // Tries to infer the tensor array element shape. |
| llvm::Optional<llvm::SmallVector<int64_t, 8>> GetTensorArrayElementShape( |
| TF::TensorArrayV3Op ta, ModuleOp module) { |
| auto element_shape = ta.element_shapeAttr().cast<mlir::TF::ShapeAttr>(); |
| if (element_shape.hasStaticShape()) { |
| auto shape = element_shape.getShape(); |
| // Convert int64 to int64_t. |
| llvm::SmallVector<int64_t, 8> dims(shape.begin(), shape.end()); |
| return dims; |
| } |
| |
| bool has_failure = false; |
| auto elem_type = cutil::GetElementTypeFromAccess( |
| ta.handle(), module, [&](Operation* user) -> llvm::Optional<Type> { |
| if (has_failure) return llvm::None; |
| if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(user)) { |
| return write.value().getType(); |
| } else if (auto split = |
| llvm::dyn_cast<TF::TensorArraySplitV3Op>(user)) { |
| if (!split.lengths().getDefiningOp() || |
| !llvm::isa<TF::ConstOp>(split.lengths().getDefiningOp())) { |
| return llvm::None; |
| } |
| RankedTensorType t; |
| int64_t count; |
| if (failed(GetSplitElementTypeAndCount(split, &t, &count))) { |
| has_failure = true; |
| return llvm::None; |
| } |
| return t; |
| } else if (auto scatter = |
| llvm::dyn_cast<TF::TensorArrayScatterV3Op>(user)) { |
| // TensorArrayScatter writes vector of tensors to TensorArray. We can |
| // deduce the shape of TensorArray by dropping the 0th dim of |
| // TensorArrayScatter `value`. |
| auto t = scatter.value().getType().dyn_cast<RankedTensorType>(); |
| if (!t || t.getShape().empty()) return llvm::None; |
| return RankedTensorType::get(t.getShape().drop_front(), |
| t.getElementType()); |
| } else if (auto gather = |
| llvm::dyn_cast<TF::TensorArrayGatherV3Op>(user)) { |
| // Try to infer from result type of gather. |
| auto t = gather.value().getType().dyn_cast<RankedTensorType>(); |
| if (t && !t.getShape().empty()) |
| return RankedTensorType::get(t.getShape().drop_front(), |
| t.getElementType()); |
| // Try to infer from `element_shape` attribute of gather. |
| auto element_shape = gather.element_shapeAttr() |
| .dyn_cast_or_null<mlir::TF::ShapeAttr>(); |
| if (element_shape && element_shape.hasStaticShape()) { |
| return RankedTensorType::get(element_shape.getShape(), |
| gather.dtype()); |
| } |
| } |
| return llvm::None; |
| }); |
| if (!elem_type) return llvm::None; |
| return llvm::to_vector<8>(elem_type->getShape()); |
| } |
| |
| void ReplaceAllUsesWithCast(Value old_val, Value new_val) { |
| if (old_val.use_empty()) return; |
| auto cast_op = |
| OpBuilder(old_val.getDefiningOp()) |
| .create<tensor::CastOp>(old_val.getLoc(), old_val.getType(), new_val); |
| old_val.replaceAllUsesWith(cast_op); |
| } |
| |
| void ReplaceAllUsesExceptTerminator(Value old_val, Value new_val) { |
| if (old_val.getType() == new_val.getType()) { |
| old_val.replaceAllUsesWith(new_val); |
| return; |
| } |
| Operation* old_op = old_val.getDefiningOp(); |
| Operation* terminator_op = |
| old_op->getParentOfType<func::FuncOp>().front().getTerminator(); |
| llvm::SmallPtrSet<mlir::Operation*, 1> exceptions = {terminator_op}; |
| old_val.replaceAllUsesExcept(new_val, exceptions); |
| } |
| |
| struct TensorArrayStats { |
| // Whether a write op should accumulate with the old value. Set to true if |
| // this is a gradient. |
| bool accumulate_on_write; |
| // Maps from a gradient source string to the local variable to the gradient. |
| llvm::StringMap<Value> grads; |
| }; |
| |
| LogicalResult HandleTensorArrayV3Op( |
| TF::TensorArrayV3Op ta, ModuleOp module, |
| llvm::SmallDenseMap<Value, TensorArrayStats>* stats) { |
| auto elem_shape = GetTensorArrayElementShape(ta, module); |
| if (!elem_shape) return ta.emitOpError("unknown element shape"); |
| if (ta.dynamic_size()) { |
| return ta.emitOpError("dynamic tensor array size is unsupported"); |
| } |
| Value buffer; |
| OpBuilder builder(ta); |
| if (failed(cutil::CreateInitBufferValue(*elem_shape, ta.size(), ta, |
| ta.dtype(), builder, &buffer))) { |
| return failure(); |
| } |
| auto var_type = RankedTensorType::get( |
| {}, TF::ResourceType::get( |
| ArrayRef<TensorType>{buffer.getType().cast<TensorType>()}, |
| ta.getContext())); |
| auto local_var = builder.create<TF::MlirLocalVarOp>( |
| ta.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{}); |
| cutil::WriteLocalVariable(local_var, buffer, builder, ta.getLoc()); |
| ta.handle().replaceAllUsesWith(local_var); |
| // The flow output is just a way for the front end to enforce ordering among |
| // tensor array ops, but in the MLIR TF dialect they have sequential ordering. |
| // Just create a constant to replace its uses. |
| tensorflow::Tensor scalar_tensor(tensorflow::DT_FLOAT, {}); |
| scalar_tensor.scalar<float>()() = 0.0f; |
| auto flow = builder.create<TF::ConstOp>( |
| ta.getLoc(), |
| tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie()); |
| ta.flow().replaceAllUsesWith(flow); |
| ta.erase(); |
| (*stats)[local_var].accumulate_on_write = false; |
| return success(); |
| } |
| |
| LogicalResult HandleTensorArrayReadV3Op( |
| TF::TensorArrayReadV3Op read, |
| const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) { |
| auto local_var = read.handle(); |
| if (stats.count(local_var) == 0) { |
| return read.emitOpError("unknown tensor array"); |
| } |
| OpBuilder builder(read); |
| auto buffer = cutil::ReadLocalVariable(local_var, builder, read.getLoc()); |
| auto index_reshape = |
| cutil::ReshapeScalarToSizeType(builder, read.index(), read.getLoc()); |
| auto elem = cutil::GetElement(index_reshape, buffer, builder, read.getLoc()); |
| ReplaceAllUsesExceptTerminator(read.value(), elem); |
| ReplaceAllUsesWithCast(read.value(), elem); |
| read.erase(); |
| // The clear_after_read attribute does not mean setting the tensor to 0 after |
| // read; instead it does not allow a second read before the next write. We |
| // follow the old bridge's implementation not to do anything here. |
| return success(); |
| } |
| |
| LogicalResult HandleTensorArrayWriteV3Op( |
| TF::TensorArrayWriteV3Op write, |
| const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) { |
| auto local_var = write.handle(); |
| auto stat_it = stats.find(local_var); |
| if (stat_it == stats.end()) return write.emitOpError("unknown tensor array"); |
| OpBuilder builder(write); |
| auto buffer = cutil::ReadLocalVariable(local_var, builder, write.getLoc()); |
| auto index_reshape = |
| cutil::ReshapeScalarToSizeType(builder, write.index(), write.getLoc()); |
| auto elem = write.value(); |
| if (stat_it->getSecond().accumulate_on_write) { |
| // Get the old slice, and accumulate with it. We set keep_slice_shape |
| // (keeping the leading size-1 dimension) because it avoids reshape back and |
| // forth. |
| auto original_elem = |
| cutil::GetElement(index_reshape, buffer, builder, write.getLoc(), |
| /*keep_slice_shape=*/true); |
| // Add a size-1 leading dimension to elem. |
| auto slice_type = original_elem.getType().cast<RankedTensorType>(); |
| elem = builder.create<TF::ReshapeOp>( |
| write.getLoc(), ArrayRef<Type>{slice_type}, |
| ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder, |
| write.getLoc())}); |
| elem = |
| cutil::AccumulateBuffers(elem, original_elem, builder, write.getLoc()); |
| } |
| buffer = |
| cutil::SetElement(index_reshape, buffer, elem, builder, write.getLoc()); |
| cutil::WriteLocalVariable(local_var, buffer, builder, write.getLoc()); |
| write.flow_out().replaceAllUsesWith(write.flow_in()); |
| write.erase(); |
| return success(); |
| } |
| |
| LogicalResult HandleTensorArrayConcatV3Op( |
| TF::TensorArrayConcatV3Op concat, |
| const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) { |
| auto local_var = concat.handle(); |
| if (stats.count(local_var) == 0) { |
| return concat.emitOpError("unknown tensor array"); |
| } |
| OpBuilder builder(concat); |
| auto buffer = cutil::ReadLocalVariable(local_var, builder, concat.getLoc()); |
| auto buffer_type = buffer.getType().cast<RankedTensorType>(); |
| if (buffer_type.getShape().size() <= 1) { |
| return concat.emitOpError("cannot concat on scalar-element tensor array"); |
| } |
| // Merget he first two dimensions. |
| auto shape = llvm::to_vector<8>(buffer_type.getShape().drop_front()); |
| shape[0] *= buffer_type.getDimSize(0); |
| buffer = builder.create<TF::ReshapeOp>( |
| concat.getLoc(), |
| ArrayRef<Type>{ |
| RankedTensorType::get(shape, buffer_type.getElementType())}, |
| ArrayRef<Value>{buffer, |
| cutil::GetR1Const(shape, builder, concat.getLoc())}); |
| ReplaceAllUsesExceptTerminator(concat.value(), buffer); |
| ReplaceAllUsesWithCast(concat.value(), buffer); |
| |
| // Create the lengths as a list of the same value (element size). |
| tensorflow::Tensor lengths_tensor(tensorflow::DT_INT64, |
| {buffer_type.getDimSize(0)}); |
| for (int64_t i = 0; i < buffer_type.getDimSize(0); ++i) { |
| lengths_tensor.vec<int64_t>()(i) = buffer_type.getDimSize(1); |
| } |
| concat.lengths().replaceAllUsesWith(builder.create<TF::ConstOp>( |
| concat.getLoc(), |
| tensorflow::ConvertTensor(lengths_tensor, &builder).ValueOrDie())); |
| concat.erase(); |
| return success(); |
| } |
| |
| LogicalResult HandleTensorArraySplitV3Op( |
| TF::TensorArraySplitV3Op split, |
| const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) { |
| auto local_var = split.handle(); |
| if (stats.count(local_var) == 0) { |
| return split.emitOpError("unknown tensor array"); |
| } |
| OpBuilder builder(split); |
| int64_t count; |
| RankedTensorType elem_type; |
| if (failed(GetSplitElementTypeAndCount(split, &elem_type, &count))) { |
| return failure(); |
| } |
| llvm::SmallVector<int64_t, 8> buffer_shape; |
| buffer_shape.push_back(count); |
| for (int64_t dim : elem_type.getShape()) buffer_shape.push_back(dim); |
| // Reshape the input to match the buffer of the tensor array. |
| auto buffer = builder |
| .create<TF::ReshapeOp>( |
| split.getLoc(), |
| ArrayRef<Type>{RankedTensorType::get( |
| buffer_shape, elem_type.getElementType())}, |
| ArrayRef<Value>{split.value(), |
| cutil::GetR1Const(buffer_shape, builder, |
| split.getLoc())}) |
| .output(); |
| // Accumulate with the old buffer. |
| auto old_buffer = |
| cutil::ReadLocalVariable(local_var, builder, split.getLoc()); |
| buffer = |
| cutil::AccumulateBuffers(old_buffer, buffer, builder, split.getLoc()); |
| cutil::WriteLocalVariable(local_var, buffer, builder, split.getLoc()); |
| split.flow_out().replaceAllUsesWith(split.flow_in()); |
| split.erase(); |
| return success(); |
| } |
| |
| LogicalResult HandleTensorArraySizeV3Op( |
| TF::TensorArraySizeV3Op size, |
| const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) { |
| auto local_var = size.handle(); |
| if (stats.count(local_var) == 0) { |
| return size.emitOpError("unknown tensor array"); |
| } |
| auto buffer_type = getElementTypeOrSelf(local_var.getType()) |
| .cast<TF::ResourceType>() |
| .getSubtypes()[0] |
| .cast<RankedTensorType>(); |
| OpBuilder builder(size); |
| auto result = cutil::CreateScalarConst(buffer_type.getDimSize(0), builder, |
| size.getLoc()); |
| size.size().replaceAllUsesWith(result); |
| size.erase(); |
| return success(); |
| } |
| |
| LogicalResult CreateAndInitializeGradVariable(Type local_var_type, |
| Operation* op, Value* var) { |
| OpBuilder builder(op); |
| *var = builder.create<TF::MlirLocalVarOp>( |
| op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{}); |
| Value buffer; |
| auto buffer_type = getElementTypeOrSelf(local_var_type) |
| .cast<TF::ResourceType>() |
| .getSubtypes()[0] |
| .cast<RankedTensorType>(); |
| if (failed(cutil::CreateInitBufferValue( |
| buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op, |
| buffer_type.getElementType(), builder, &buffer))) { |
| return failure(); |
| } |
| cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc()); |
| return success(); |
| } |
| |
| LogicalResult HandleTensorArrayGradV3Op( |
| TF::TensorArrayGradV3Op grad, |
| llvm::SmallDenseMap<Value, TensorArrayStats>* stats) { |
| auto local_var = grad.handle(); |
| OpBuilder builder(grad); |
| Value grad_var; |
| auto sit = stats->find(local_var); |
| if (sit == stats->end()) return grad.emitOpError("unknown tensor array"); |
| auto emplace_res = |
| sit->getSecond().grads.try_emplace(grad.source().str(), Value()); |
| if (!emplace_res.second) { |
| // If the source has been assigned a grad, use it. |
| grad_var = emplace_res.first->second; |
| } else { |
| if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad, |
| &grad_var))) { |
| return failure(); |
| } |
| emplace_res.first->second = grad_var; |
| // Write to a grad accumulates with previous writes. |
| (*stats)[grad_var].accumulate_on_write = true; |
| } |
| grad.flow_out().replaceAllUsesWith(grad.flow_in()); |
| grad.grad_handle().replaceAllUsesWith(grad_var); |
| grad.erase(); |
| return success(); |
| } |
| |
| LogicalResult HandleTensorArrayGatherV3Op( |
| TF::TensorArrayGatherV3Op gather, |
| const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) { |
| auto local_var = gather.handle(); |
| if (stats.count(local_var) == 0) { |
| return gather.emitOpError("unknown tensor array"); |
| } |
| OpBuilder builder(gather); |
| auto buffer = cutil::ReadLocalVariable(local_var, builder, gather.getLoc()); |
| auto result = |
| cutil::GatherElements(gather.indices(), buffer, builder, gather.getLoc()); |
| ReplaceAllUsesExceptTerminator(gather.value(), result); |
| ReplaceAllUsesWithCast(gather.value(), result); |
| gather.erase(); |
| return success(); |
| } |
| |
| LogicalResult HandleTensorArrayScatterV3Op( |
| TF::TensorArrayScatterV3Op scatter, |
| const llvm::SmallDenseMap<Value, TensorArrayStats>& stats) { |
| auto local_var = scatter.handle(); |
| if (stats.count(local_var) == 0) { |
| return scatter.emitOpError("unknown tensor array"); |
| } |
| OpBuilder builder(scatter); |
| auto buffer = cutil::ReadLocalVariable(local_var, builder, scatter.getLoc()); |
| buffer = cutil::ScatterAccumulateElements(scatter.indices(), scatter.value(), |
| buffer, builder, scatter.getLoc()); |
| cutil::WriteLocalVariable(local_var, buffer, builder, scatter.getLoc()); |
| scatter.flow_out().replaceAllUsesWith(scatter.flow_in()); |
| scatter.erase(); |
| return success(); |
| } |
| |
| // Updates func's type according to its current arguments and return values. |
| void UpdateFuncType(func::FuncOp func) { |
| llvm::SmallVector<Type, 8> arg_types; |
| for (auto arg : func.getArguments()) arg_types.push_back(arg.getType()); |
| func.setType( |
| FunctionType::get(func.getContext(), arg_types, |
| func.front().getTerminator()->getOperandTypes())); |
| } |
| |
| // Finds the accessed gradient sources for each tensor array argument. |
| llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients( |
| ArrayRef<func::FuncOp> funcs, ModuleOp module) { |
| llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result; |
| llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets; |
| auto insert = [&](Value v, const string& source, const Block& func_block) { |
| auto arg = v.dyn_cast<BlockArgument>(); |
| if (!arg || arg.getOwner() != &func_block) return; |
| auto insert_res = result_sets[arg.getArgNumber()].insert(source); |
| if (!insert_res.second) return; |
| result[arg.getArgNumber()].push_back(source); |
| }; |
| for (func::FuncOp func : funcs) { |
| const Block& func_block = func.front(); |
| // Walk all operations and nested regions to find accessed gradient sources |
| // for function arguments. |
| func.walk([&](Operation* op) { |
| if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) { |
| op->replaceAllUsesWith(op->getOperands()); |
| return; |
| } |
| if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(op)) { |
| insert(grad.handle(), grad.source().str(), func_block); |
| } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) { |
| for (const auto& entry : AccessedGradients( |
| {while_op.body_function(), while_op.cond_function()}, module)) |
| for (const string& source : entry.getSecond()) |
| insert(while_op.getOperand(entry.getFirst()), source, func_block); |
| } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(op)) { |
| for (const auto& entry : AccessedGradients( |
| {if_op.then_function(), if_op.else_function()}, module)) |
| for (const string& source : entry.getSecond()) |
| insert(if_op.getOperand(entry.getFirst() + 1), source, func_block); |
| } else if (auto call = llvm::dyn_cast<CallOpInterface>(op)) { |
| auto callee = dyn_cast<func::FuncOp>(call.resolveCallable()); |
| for (const auto& entry : AccessedGradients({callee}, module)) |
| for (const string& source : entry.getSecond()) |
| insert(call.getArgOperands()[entry.getFirst()], source, func_block); |
| } |
| }); |
| } |
| return result; |
| } |
| |
| // Contains cached information for decomposed callee functions for (stateful) |
| // partitioned call ops. |
| struct PartitionedCallTensorArrayOpsInfo { |
| bool signature_change; |
| func::FuncOp decomposed_callee; |
| llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4> |
| arg_grads; |
| llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input; |
| }; |
| |
| // Updates a called function's input signature by adjusting resource types, and |
| // adding required gradient arguments. |
| void ChangeFunctionInputSignature( |
| func::FuncOp func, |
| const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads, |
| llvm::function_ref<Type(int64_t)> ta_arg_buffer_type, |
| llvm::function_ref<bool(int64_t)> ta_accumulate_on_write, |
| llvm::SmallDenseMap<Value, TensorArrayStats>* stats) { |
| int64_t original_args = func.getNumArguments(); |
| for (int64_t argnum = 0; argnum < original_args; ++argnum) { |
| auto arg = func.getArgument(argnum); |
| Type t = ta_arg_buffer_type(argnum); |
| if (!t) continue; |
| arg.setType(t); |
| auto grad_it = grads.find(argnum); |
| if (grad_it == grads.end()) continue; |
| llvm::StringMap<Value> grads_map; |
| for (const string& source : grad_it->getSecond()) { |
| auto g = func.front().addArgument(t, func.getLoc()); |
| (*stats)[g].accumulate_on_write = true; |
| grads_map[source] = g; |
| } |
| auto& stat = (*stats)[arg]; |
| stat.accumulate_on_write = ta_accumulate_on_write(argnum); |
| stat.grads = std::move(grads_map); |
| } |
| UpdateFuncType(func); |
| } |
| |
| LogicalResult DecomposeTensorArrayOps( |
| Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*, |
| llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*); |
| |
| LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, |
| llvm::SmallDenseMap<Value, TensorArrayStats>* stats, |
| llvm::StringMap<PartitionedCallTensorArrayOpsInfo>* |
| decomposed_partitioned_call_callees) { |
| auto body = while_op.body_function(); |
| auto cond = while_op.cond_function(); |
| auto grads = AccessedGradients({body, cond}, module); |
| auto ta_arg_buffer_type = [&](int64_t index) -> Type { |
| auto it = stats->find(while_op.getOperand(index)); |
| if (it == stats->end()) return nullptr; |
| return it->getFirst().getType(); |
| }; |
| auto ta_accumulate_on_write = [&](int64_t index) { |
| auto it = stats->find(while_op.getOperand(index)); |
| if (it == stats->end()) return false; |
| return it->getSecond().accumulate_on_write; |
| }; |
| llvm::SmallDenseMap<Value, TensorArrayStats> body_stats; |
| ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type, |
| ta_accumulate_on_write, &body_stats); |
| llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats; |
| ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type, |
| ta_accumulate_on_write, &cond_stats); |
| if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats, |
| decomposed_partitioned_call_callees)) || |
| failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats, |
| decomposed_partitioned_call_callees))) { |
| return failure(); |
| } |
| if (body_stats.empty() && cond_stats.empty()) return success(); |
| auto old_body_ret = body.front().getTerminator(); |
| auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands()); |
| for (int64_t i = 0; i < while_op.getNumResults(); ++i) { |
| if (!ta_arg_buffer_type(i)) continue; |
| auto retval = old_body_ret->getOperand(i); |
| auto arg = retval.dyn_cast<BlockArgument>(); |
| if (!arg) { |
| return while_op.emitOpError( |
| "output tensor array does not alias input in a while loop"); |
| } |
| for (const string& source : grads[i]) { |
| new_retvals.push_back(body_stats[arg].grads[source]); |
| } |
| } |
| OpBuilder(old_body_ret) |
| .create<func::ReturnOp>(old_body_ret->getLoc(), new_retvals); |
| old_body_ret->erase(); |
| UpdateFuncType(body); |
| // Recreate the while op. |
| auto operands = llvm::to_vector<8>(while_op.getOperands()); |
| for (int64_t i = 0; i < while_op.getNumOperands(); ++i) { |
| auto grad_it = grads.find(i); |
| auto& stat = (*stats)[operands[i]]; |
| if (grad_it == grads.end()) continue; |
| for (const string& source : grad_it->getSecond()) { |
| auto it = stat.grads.find(source); |
| if (it != stat.grads.end()) { |
| operands.push_back(it->second); |
| } else { |
| Value grad_var; |
| if (failed(CreateAndInitializeGradVariable(operands[i].getType(), |
| while_op, &grad_var))) { |
| return failure(); |
| } |
| stat.grads[source] = grad_var; |
| operands.push_back(grad_var); |
| (*stats)[grad_var].accumulate_on_write = true; |
| } |
| } |
| } |
| OpBuilder builder(while_op); |
| auto new_while = builder.create<TF::WhileOp>( |
| while_op.getLoc(), body.getFunctionType().getInputs(), operands, |
| while_op->getAttrs()); |
| for (int64_t i = 0; i < while_op.getNumOperands(); ++i) { |
| if (ta_arg_buffer_type(i)) { |
| while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i)); |
| } else { |
| while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i)); |
| } |
| } |
| while_op.erase(); |
| return success(); |
| } |
| |
| LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, |
| llvm::SmallDenseMap<Value, TensorArrayStats>* stats, |
| llvm::StringMap<PartitionedCallTensorArrayOpsInfo>* |
| decomposed_partitioned_call_callees) { |
| auto then_branch = if_op.then_function(); |
| auto else_branch = if_op.else_function(); |
| auto grads = AccessedGradients({then_branch, else_branch}, module); |
| auto ta_arg_buffer_type = [&](int64_t index) -> Type { |
| auto it = stats->find(if_op.getOperand(index + 1)); |
| if (it == stats->end()) return nullptr; |
| return it->getFirst().getType(); |
| }; |
| auto ta_accumulate_on_write = [&](int64_t index) { |
| auto it = stats->find(if_op.getOperand(index + 1)); |
| if (it == stats->end()) return false; |
| return it->getSecond().accumulate_on_write; |
| }; |
| llvm::SmallDenseMap<Value, TensorArrayStats> then_stats; |
| ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type, |
| ta_accumulate_on_write, &then_stats); |
| llvm::SmallDenseMap<Value, TensorArrayStats> else_stats; |
| ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type, |
| ta_accumulate_on_write, &else_stats); |
| if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats, |
| decomposed_partitioned_call_callees)) || |
| failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats, |
| decomposed_partitioned_call_callees))) { |
| return failure(); |
| } |
| if (then_stats.empty() && else_stats.empty()) return success(); |
| // Recreate the if op. |
| auto operands = llvm::to_vector<8>(if_op.getOperands()); |
| for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) { |
| auto grad_it = grads.find(i); |
| auto& stat = (*stats)[operands[i + 1]]; |
| if (grad_it == grads.end()) continue; |
| for (const string& source : grad_it->getSecond()) { |
| auto it = stat.grads.find(source); |
| if (it != stat.grads.end()) { |
| operands.push_back(it->second); |
| } else { |
| Value grad_var; |
| if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(), |
| if_op, &grad_var))) { |
| return failure(); |
| } |
| stat.grads[source] = grad_var; |
| operands.push_back(grad_var); |
| (*stats)[grad_var].accumulate_on_write = true; |
| } |
| } |
| } |
| OpBuilder builder(if_op); |
| auto new_if = builder.create<TF::IfOp>( |
| if_op.getLoc(), then_branch.getFunctionType().getResults(), operands, |
| if_op->getAttrs()); |
| auto ret_forwards_input = [](func::FuncOp f, int64_t ret_ind) -> int64_t { |
| auto retval = f.front().getTerminator()->getOperand(ret_ind); |
| auto arg = retval.dyn_cast<BlockArgument>(); |
| if (!arg) return -1; |
| return arg.getArgNumber(); |
| }; |
| for (int64_t i = 0; i < if_op.getNumResults(); ++i) { |
| if (!getElementTypeOrSelf(if_op.getResult(i).getType()) |
| .isa<TF::ResourceType>()) { |
| if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i)); |
| continue; |
| } |
| int64_t then_forward_input = ret_forwards_input(then_branch, i); |
| int64_t else_foward_input = ret_forwards_input(else_branch, i); |
| if (then_forward_input != else_foward_input || then_forward_input < 0) { |
| return if_op.emitOpError( |
| "branches do not forward the same input resource"); |
| } |
| if_op.getResult(i).replaceAllUsesWith( |
| if_op.getOperand(then_forward_input + 1)); |
| } |
| if_op.erase(); |
| return success(); |
| } |
| |
| template <typename CallOp> |
| LogicalResult HandlePartitionedCallOp( |
| CallOp call, func::FuncOp callee, ModuleOp module, |
| llvm::SmallDenseMap<Value, TensorArrayStats>* stats, |
| llvm::StringMap<PartitionedCallTensorArrayOpsInfo>* |
| decomposed_partitioned_call_callees) { |
| auto emplace_res = decomposed_partitioned_call_callees->try_emplace( |
| callee.getName(), PartitionedCallTensorArrayOpsInfo()); |
| auto& info = emplace_res.first->second; |
| // Recreates the call op with info. |
| auto recreate_caller = [&]() -> LogicalResult { |
| auto new_operands = llvm::to_vector<8>(call.getOperands()); |
| for (const auto& entry : info.arg_grads) { |
| auto it = stats->find(call.getOperand(entry.first)); |
| if (it == stats->end()) return call.emitOpError("unknown tensor array"); |
| for (const string& source : entry.second) { |
| auto grad_it = it->getSecond().grads.find(source); |
| if (grad_it != it->getSecond().grads.end()) { |
| new_operands.push_back(grad_it->second); |
| } else { |
| Value grad_var; |
| if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(), |
| call, &grad_var))) { |
| return failure(); |
| } |
| it->getSecond().grads[source] = grad_var; |
| new_operands.push_back(grad_var); |
| } |
| } |
| } |
| OpBuilder builder(call); |
| auto new_call = builder.create<CallOp>( |
| call.getLoc(), info.decomposed_callee.getFunctionType().getResults(), |
| new_operands, call->getAttrs()); |
| new_call->setAttr( |
| "f", SymbolRefAttr::get( |
| builder.getContext(), |
| const_cast<func::FuncOp&>(info.decomposed_callee).getName())); |
| for (const auto& entry : info.ret_forward_input) { |
| call.getResult(entry.first) |
| .replaceAllUsesWith(call.getOperand(entry.second)); |
| } |
| call.replaceAllUsesWith(new_call); |
| call.erase(); |
| return success(); |
| }; |
| if (!emplace_res.second) { |
| // This callee was handled before. |
| if (!info.signature_change) return success(); |
| return recreate_caller(); |
| } |
| // Rewrite the callee. |
| info.signature_change = false; |
| auto ta_arg_buffer_type = [&](int64_t index) -> Type { |
| auto it = stats->find(call.getOperand(index)); |
| if (it == stats->end()) return nullptr; |
| info.signature_change = true; |
| return it->getFirst().getType(); |
| }; |
| auto ta_accumulate_on_write = [&](int64_t index) { |
| auto it = stats->find(call.getOperand(index)); |
| if (it == stats->end()) return false; |
| return it->getSecond().accumulate_on_write; |
| }; |
| func::FuncOp lowered_callee = callee; |
| if (!callee.isPrivate()) { |
| // Clone non-private callee in case of signature change. |
| lowered_callee = callee.clone(); |
| lowered_callee.setPrivate(); |
| } |
| auto grads = AccessedGradients({lowered_callee}, module); |
| for (int64_t i = 0; i < lowered_callee.getNumArguments(); ++i) { |
| auto it = grads.find(i); |
| if (it == grads.end()) continue; |
| info.arg_grads.emplace_back(i, it->getSecond()); |
| } |
| llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats; |
| ChangeFunctionInputSignature(lowered_callee, grads, ta_arg_buffer_type, |
| ta_accumulate_on_write, &callee_stats); |
| if (failed(DecomposeTensorArrayOps(&lowered_callee.front(), module, |
| &callee_stats, |
| decomposed_partitioned_call_callees))) { |
| return failure(); |
| } |
| for (int64_t i = 0; i < call.getNumResults(); ++i) { |
| auto ret = lowered_callee.front().getTerminator()->getOperand(i); |
| if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue; |
| auto arg = ret.dyn_cast<BlockArgument>(); |
| if (!arg) continue; |
| info.ret_forward_input.emplace_back(i, arg.getArgNumber()); |
| } |
| |
| info.decomposed_callee = lowered_callee; |
| if (lowered_callee != callee) { |
| if (!info.signature_change) { |
| // Signature is not modified. We do not need to keep two copies. |
| lowered_callee.setName( |
| StringAttr::get(callee->getContext(), callee.getName())); |
| callee.erase(); |
| } else { |
| // Add the clone with a new name. |
| lowered_callee.setName(StringAttr::get( |
| callee->getContext(), |
| llvm::formatv("{0}_tensorarray_decomposed", callee.getName()).str())); |
| } |
| SymbolTable(module).insert(lowered_callee); |
| } |
| if (info.signature_change) return recreate_caller(); |
| return success(); |
| } |
| |
| LogicalResult HandleRegionControlFlowOps( |
| Operation& op, ModuleOp module, |
| llvm::SmallDenseMap<Value, TensorArrayStats>* stats, |
| llvm::StringMap<PartitionedCallTensorArrayOpsInfo>* |
| decomposed_partitioned_call_callees) { |
| for (OpOperand& operand : op.getOpOperands()) { |
| if (getElementTypeOrSelf(operand.get().getType()).isa<TF::ResourceType>()) { |
| return op.emitOpError() |
| << "found unexpected type " << operand.get().getType() |
| << " of operand #" << operand.getOperandNumber() |
| << ", resource type operands are expected to have been " |
| "canonicalized away for region based control flow ops"; |
| } |
| } |
| for (OpResult result : op.getResults()) { |
| if (getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) { |
| return op.emitOpError() |
| << "found unexpected type " << result.getType() << " of result #" |
| << result.getResultNumber() |
| << ", resource type results are expected to have been " |
| "canonicalized away for region based control flow ops"; |
| } |
| } |
| |
| for (Region& region : op.getRegions()) { |
| if (failed(DecomposeTensorArrayOps(®ion.front(), module, stats, |
| decomposed_partitioned_call_callees))) |
| return failure(); |
| } |
| return success(); |
| } |
| |
| LogicalResult DecomposeTensorArrayOps( |
| Block* block, ModuleOp module, |
| llvm::SmallDenseMap<Value, TensorArrayStats>* stats, |
| llvm::StringMap<PartitionedCallTensorArrayOpsInfo>* |
| decomposed_partitioned_call_callees) { |
| for (auto& op : llvm::make_early_inc_range(block->getOperations())) { |
| if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) { |
| op.replaceAllUsesWith(op.getOperands()); |
| op.erase(); |
| } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) { |
| if (failed(HandleTensorArrayV3Op(ta, module, stats))) { |
| return failure(); |
| } |
| } else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) { |
| if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure(); |
| } else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) { |
| if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure(); |
| } else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) { |
| if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure(); |
| } else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) { |
| if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure(); |
| } else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) { |
| if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure(); |
| } else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) { |
| if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure(); |
| } else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) { |
| if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure(); |
| } else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) { |
| if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) { |
| return failure(); |
| } |
| } else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) { |
| close.erase(); |
| } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) { |
| if (failed(HandleWhileOp(while_op, module, stats, |
| decomposed_partitioned_call_callees))) { |
| return failure(); |
| } |
| } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) { |
| if (failed(HandleIfOp(if_op, module, stats, |
| decomposed_partitioned_call_callees))) { |
| return failure(); |
| } |
| } else if (llvm::isa<TF::CaseRegionOp>(op) || |
| llvm::isa<TF::IfRegionOp>(op) || |
| llvm::isa<TF::WhileRegionOp>(op)) { |
| if (failed(HandleRegionControlFlowOps( |
| op, module, stats, decomposed_partitioned_call_callees))) |
| return failure(); |
| } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) { |
| auto callee = pcall.func(); |
| if (!callee) |
| return pcall.emitOpError( |
| "TensorArray decomposition does not support call with nested " |
| "references."); |
| |
| if (failed( |
| HandlePartitionedCallOp(pcall, callee, module, stats, |
| decomposed_partitioned_call_callees))) { |
| return failure(); |
| } |
| } else if (auto spcall = |
| llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) { |
| if (failed( |
| HandlePartitionedCallOp(spcall, spcall.func(), module, stats, |
| decomposed_partitioned_call_callees))) { |
| return failure(); |
| } |
| } |
| } |
| return success(); |
| } |
| |
| void TensorArrayOpsDecompositionPass::runOnOperation() { |
| auto module = getOperation(); |
| auto main = module.lookupSymbol<func::FuncOp>("main"); |
| if (!main) return; |
| llvm::SmallDenseMap<Value, TensorArrayStats> stats; |
| llvm::StringMap<PartitionedCallTensorArrayOpsInfo> |
| decomposed_partitioned_call_callees; |
| if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats, |
| &decomposed_partitioned_call_callees))) { |
| signalPassFailure(); |
| } |
| } |
| |
| } // namespace |
| |
| namespace TF { |
| |
| std::unique_ptr<OperationPass<ModuleOp>> |
| CreateTensorArrayOpsDecompositionPass() { |
| return std::make_unique<TensorArrayOpsDecompositionPass>(); |
| } |
| |
| } // namespace TF |
| } // namespace mlir |