| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" |
| |
| #include <algorithm> |
| #include <iterator> |
| |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/Sequence.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringExtras.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project |
| #include "mlir/Dialect/Traits.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/DialectImplementation.h" // from @llvm-project |
| #include "mlir/IR/Function.h" // from @llvm-project |
| #include "mlir/IR/MLIRContext.h" // from @llvm-project |
| #include "mlir/IR/Matchers.h" // from @llvm-project |
| #include "mlir/IR/OpDefinition.h" // from @llvm-project |
| #include "mlir/IR/OpImplementation.h" // from @llvm-project |
| #include "mlir/IR/PatternMatch.h" // from @llvm-project |
| #include "mlir/IR/StandardTypes.h" // from @llvm-project |
| #include "mlir/IR/Types.h" // from @llvm-project |
| #include "mlir/IR/Value.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "mlir/Transforms/FoldUtils.h" // from @llvm-project |
| #include "mlir/Transforms/InliningUtils.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
| |
| namespace mlir { |
| namespace tf_executor { |
| |
| //===----------------------------------------------------------------------===// |
| // TF Executor Dialect |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| using TF::DropRefType; |
| using TF::DropTypeSubTypes; |
| |
| struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { |
| using DialectInlinerInterface::DialectInlinerInterface; |
| |
| //===--------------------------------------------------------------------===// |
| // Analysis Hooks |
| //===--------------------------------------------------------------------===// |
| |
| // Override the inlining hook to determine if 'src' can be inlined into |
| // 'dest'. |
| bool isLegalToInline(Region *dest, Region *src, |
| BlockAndValueMapping &value_mapping) const final { |
| // Allow inlining into tf.island regions if the incoming region has a single |
| // block. |
| return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) && |
| std::next(src->begin()) == src->end(); |
| } |
| }; |
| |
| struct TensorFlowExecutorOpFolderDialectInterface |
| : public OpFolderDialectInterface { |
| using OpFolderDialectInterface::OpFolderDialectInterface; |
| |
| // Registered hook to check if the given region, which is attached to an |
| // operation that is *not* isolated from above (i.e. no internal regions |
| // reference values defined in an enclosing region), should be used when |
| // materializing constants. |
| // In the executor dialect we materialize inside an island. |
| bool shouldMaterializeInto(Region *region) const final { |
| return isa<tf_executor::IslandOp>(region->getParentOp()); |
| } |
| }; |
| |
| } // namespace |
| |
| TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context) |
| : Dialect(/*name=*/"tf_executor", context) { |
| addOperations< |
| #define GET_OP_LIST |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc" |
| >(); |
| |
| addInterfaces<TensorFlowExecutorInlinerInterface, |
| TensorFlowExecutorOpFolderDialectInterface>(); |
| |
| addTypes<ControlType, TokenType>(); |
| } |
| |
| Type TensorFlowExecutorDialect::parseType(DialectAsmParser &parser) const { |
| StringRef data_type; |
| if (parser.parseKeyword(&data_type)) return Type(); |
| |
| if (data_type == "control") return ControlType::get(getContext()); |
| if (data_type == "token") return TokenType::get(getContext()); |
| parser.emitError(parser.getNameLoc()) |
| << "unknown tf_executor type: " << data_type; |
| return nullptr; |
| } |
| |
| void TensorFlowExecutorDialect::printType(Type type, |
| DialectAsmPrinter &os) const { |
| if (type.isa<ControlType>()) { |
| os << "control"; |
| return; |
| } |
| if (type.isa<TokenType>()) { |
| os << "token"; |
| return; |
| } |
| os << "<unknown tf_executor type>"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Implementation for all the operations defined in ODS (op definition spec). |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| // Verifies that every control operands are at the end of the list. |
| // Used by the constraint `ControlOperandsAfterAllData` in ODS. |
| LogicalResult VerifyControlOperandsAfterAllData(Operation *op) { |
| bool found_control = false; |
| for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) { |
| if (op->getOperand(operand_idx).getType().isa<ControlType>()) { |
| found_control = true; |
| continue; |
| } |
| if (found_control) |
| return op->emitOpError() << "found non-control operand #" << operand_idx |
| << " after control operand"; |
| } |
| return success(); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.graph |
| //===----------------------------------------------------------------------===// |
| |
| FetchOp GraphOp::GetFetch() { return llvm::cast<FetchOp>(GetBody().back()); } |
| |
| namespace { |
| |
| LogicalResult Verify(GraphOp graph) { |
| auto *executorDialect = graph.getDialect(); |
| |
| if (graph.GetBody().empty()) |
| return graph.emitOpError() << "expects a non-empty body"; |
| |
| // Only tf_executor dialect operations are allowed to be immediately nested |
| // in a tf_executor.graph region. |
| for (Operation &op : graph.GetBody()) { |
| if (op.getDialect() != executorDialect) |
| return op.emitOpError() << "unallowed inside a tf_executor.graph region"; |
| if (isa<GraphOp>(op)) |
| return op.emitOpError() |
| << "unallowed directly inside another tf_executor.graph"; |
| } |
| |
| Operation &fetch = graph.GetBody().back(); |
| if (!isa<FetchOp>(fetch)) |
| return fetch.emitOpError() |
| << "invalid tf_executor.graph terminator, fetch expected"; |
| |
| // Ensure that the fetch terminator operands matches the graph result type. |
| // All the non-control operands of the fetch operation must match the graph |
| // returned value. |
| if (fetch.getNumOperands() < graph.getNumResults()) |
| return fetch.emitOpError() << "does not have enough operands to cover the " |
| "graph returned values"; |
| for (int i : llvm::seq<int>(0, fetch.getNumOperands())) { |
| Value operand = fetch.getOperand(i); |
| // Break out of the loop at the first control operand encountered. |
| if (operand.getType().isa<ControlType>()) { |
| if (i != graph.getNumResults()) |
| return fetch.emitOpError() |
| << "operand #" << i |
| << " is a control type, can't be bound to a graph result"; |
| break; |
| } |
| if (i >= graph.getNumResults()) |
| return fetch.emitOpError() |
| << "operand #" << i << " does not have a graph results to bind"; |
| if (graph.getResult(i).getType() != operand.getType()) |
| return fetch.emitOpError() |
| << "operand #" << i << " type mismatch graph results"; |
| } |
| return success(); |
| } |
| |
| void Print(GraphOp graph, OpAsmPrinter &p) { |
| p << graph.getOperationName(); |
| p.printRegion(graph.getOperation()->getRegion(0)); |
| p.printOptionalAttrDict(graph.getAttrs()); |
| } |
| |
| ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) { |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| |
| // Parse the body region. |
| Region &body = *result.addRegion(); |
| if (parser.parseRegion(body, llvm::None, llvm::None)) return failure(); |
| |
| if (body.getBlocks().size() > 1) |
| return parser.emitError(loc) << "expects a single block region"; |
| |
| // Ensure that the region is well formed: it contains at least a block with |
| // a FetchOp terminator. |
| GraphOp::ensureTerminator(body, parser.getBuilder(), result.location); |
| |
| // Get the results type from the terminator type inside the graph. |
| Operation &fetch = body.back().back(); |
| if (!isa<FetchOp>(fetch)) |
| return parser.emitError(loc) << "expects a tf_executor.fetch terminator"; |
| |
| // The return value of the graph operation are the non-control operands of |
| // the fetch operation. |
| result.types.reserve(fetch.getNumOperands()); |
| for (Type type : fetch.getOperandTypes()) { |
| if (type.isa<ControlType>()) break; |
| result.types.push_back(type); |
| } |
| |
| // Parse the optional attribute list. |
| if (parser.parseOptionalAttrDict(result.attributes)) return failure(); |
| |
| return success(); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.fetch |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| void Print(FetchOp fetch, OpAsmPrinter &p) { |
| p << fetch.getOperationName(); |
| if (fetch.getNumOperands() > 0) { |
| p << ' '; |
| p.printOperands(fetch.operand_begin(), fetch.operand_end()); |
| p << " : "; |
| interleaveComma(fetch.getOperandTypes(), p); |
| } |
| p.printOptionalAttrDict(fetch.getAttrs()); |
| } |
| |
| ParseResult ParseFetchOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 2> opInfo; |
| SmallVector<Type, 2> types; |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| return failure(parser.parseOperandList(opInfo) || |
| (!opInfo.empty() && parser.parseColonTypeList(types)) || |
| parser.resolveOperands(opInfo, types, loc, result.operands) || |
| parser.parseOptionalAttrDict(result.attributes) |
| |
| ); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.island |
| //===----------------------------------------------------------------------===// |
| |
| YieldOp IslandOp::GetYield() { return llvm::cast<YieldOp>(GetBody().back()); } |
| |
| // Checks if a tf_executor.island wraps a single operation and the single |
| // operation results are perfectly forwarded to the islands yield. |
| bool IslandOp::WrapsSingleOp() { |
| auto body = GetBody().without_terminator(); |
| if (!hasSingleElement(body)) return false; |
| |
| Operation &wrapped_op = *body.begin(); |
| YieldOp yield = GetYield(); |
| return wrapped_op.getNumResults() == yield.getNumOperands() && |
| std::equal(wrapped_op.getResults().begin(), |
| wrapped_op.getResults().end(), yield.getOperands().begin()); |
| } |
| |
| namespace { |
| |
| LogicalResult Verify(IslandOp island) { |
| if (island.GetBody().empty()) |
| return island.emitOpError() << "expects a non-empty body"; |
| |
| Operation &yield = island.GetBody().back(); |
| if (!isa<YieldOp>(yield)) |
| return yield.emitOpError() |
| << "invalid tf_executor.island terminator, yield expected"; |
| |
| // Ensure that the yield terminator operands matches the island results type. |
| int result_count = island.getNumResults() - 1; // -1 for the control token |
| if (yield.getNumOperands() != result_count) |
| return yield.emitOpError() |
| << "has " << yield.getNumOperands() |
| << " operand, but island returns " << result_count; |
| for (int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) { |
| if (island.getResult(operand_idx).getType() != |
| yield.getOperand(operand_idx).getType()) |
| return yield.emitOpError() |
| << "operand #" << operand_idx << " type mismatch island results"; |
| } |
| |
| // Check that there aren't any control results other than the last one. |
| Type control_type = ControlType::get(island.getContext()); |
| for (int operand_idx : llvm::seq<int>(0, island.getNumResults() - 1)) { |
| if (island.getResult(operand_idx).getType() == control_type) |
| return yield.emitOpError() |
| << "unexpected control type for operand #" << operand_idx; |
| } |
| return success(); |
| } |
| |
| void Print(IslandOp op, OpAsmPrinter &p) { |
| p << op.getOperationName(); |
| if (op.getNumOperands()) { |
| // These are always control operand, no explicit type needed. |
| p << '('; |
| p.printOperands(op.getOperands()); |
| p << ')'; |
| } |
| |
| // Check if we can print the short "wraps" form: that is if the island |
| // contains a single operation and the result of this operation are perfectly |
| // forwarded to the yield. |
| if (op.getAttrs().empty() && op.WrapsSingleOp()) { |
| Operation &wrapped_op = op.GetBody().front(); |
| YieldOp yield_op = op.GetYield(); |
| // The "wraps" syntax only encodes a single location. |
| // In order to correctly round-trip, we can only use this syntax when all |
| // the locations are identical. |
| if (wrapped_op.getLoc() == op.getLoc() && |
| yield_op.getLoc() == op.getLoc()) { |
| p << " wraps "; |
| p.printGenericOp(&wrapped_op); |
| return; |
| } |
| } |
| p.printRegion(op.getOperation()->getRegion(0)); |
| p.printOptionalAttrDict(op.getAttrs()); |
| } |
| |
| ParseResult ParseIslandOp(OpAsmParser &parser, OperationState &result) { |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| Type control_type = ControlType::get(parser.getBuilder().getContext()); |
| |
| // Parse optional argument list (control dependencies only). |
| SmallVector<OpAsmParser::OperandType, 4> op_infos; |
| if (parser.parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen)) |
| return failure(); |
| if (!op_infos.empty()) { |
| SmallVector<Type, 2> types(op_infos.size(), control_type); |
| parser.resolveOperands(op_infos, types, loc, result.operands); |
| } |
| |
| // Parse the body region. |
| Region &body = *result.addRegion(); |
| |
| if (succeeded(parser.parseOptionalKeyword("wraps"))) { |
| // If we parse the short version of the island, we have an operation in the |
| // generic form that follows the "wraps" keyword. Parse it inside the region |
| // and forward all of its results as-is to the yield operation. |
| body.push_back(new Block); |
| Block &block = body.back(); |
| Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); |
| if (!wrapped_op) return failure(); |
| OpBuilder builder(parser.getBuilder().getContext()); |
| builder.setInsertionPointToEnd(&block); |
| builder.create<YieldOp>(wrapped_op->getLoc(), wrapped_op->getResults()); |
| result.location = wrapped_op->getLoc(); |
| } else if (parser.parseRegion(body, llvm::None, llvm::None)) { |
| return failure(); |
| } |
| |
| IslandOp::ensureTerminator(body, parser.getBuilder(), result.location); |
| |
| // Get the results type for the island from the terminator operands. |
| Operation &yield = body.back().back(); |
| result.types.reserve(yield.getNumOperands() + 1); |
| result.types.append(yield.operand_type_begin(), yield.operand_type_end()); |
| result.types.push_back(control_type); |
| |
| // Parse the optional attribute list. |
| if (parser.parseOptionalAttrDict(result.attributes)) return failure(); |
| return success(); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.yield |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| void Print(YieldOp yield, OpAsmPrinter &p) { |
| p << yield.getOperationName(); |
| if (yield.getNumOperands() > 0) { |
| p << ' '; |
| p.printOperands(yield.operand_begin(), yield.operand_end()); |
| p << " : "; |
| interleaveComma(yield.getOperandTypes(), p); |
| } |
| p.printOptionalAttrDict(yield.getAttrs()); |
| } |
| |
| ParseResult ParseYieldOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 2> op_info; |
| SmallVector<Type, 2> types; |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| return failure(parser.parseOperandList(op_info) || |
| (!op_info.empty() && parser.parseColonTypeList(types)) || |
| parser.resolveOperands(op_info, types, loc, result.operands) || |
| parser.parseOptionalAttrDict(result.attributes)); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.Switch |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 2> op_infos; |
| SmallVector<Type, 1> types; |
| if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types)) |
| return failure(); |
| if (types.size() != 1) |
| return parser.emitError(parser.getNameLoc()) |
| << " expects only a single data type"; |
| |
| // Support parsing either a functional type (in which case all the types are |
| // fully qualified) or a short form with a single type (in which case the data |
| // input and the outputs are all using this type and predicate is tensor<i1> |
| // type). |
| if (types.front().isa<FunctionType>()) { |
| FunctionType type = types.front().cast<FunctionType>(); |
| if (type.getNumInputs() < 2) |
| return parser.emitError(parser.getNameLoc()) |
| << " expects a single data type and a predicate"; |
| result.types.assign(type.getResults().begin(), type.getResults().end()); |
| types.assign(type.getInputs().begin(), type.getInputs().end()); |
| } else { |
| if (op_infos.size() < 2) |
| return parser.emitError(parser.getNameLoc()) |
| << " expects a single data type and a predicate"; |
| Type control_type = ControlType::get(parser.getBuilder().getContext()); |
| result.types.append(2, types[0]); |
| result.types.push_back(control_type); |
| Type i1_type = parser.getBuilder().getI1Type(); |
| RankedTensorType predicate_type = RankedTensorType::get({}, i1_type); |
| types.push_back(predicate_type); |
| types.append(op_infos.size() - 2, control_type); |
| } |
| |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| if (parser.resolveOperands(op_infos, types, loc, result.operands)) |
| return failure(); |
| |
| return parser.parseOptionalAttrDict(result.attributes); |
| } |
| |
| void Print(SwitchOp switch_op, OpAsmPrinter &p) { |
| p << switch_op.getOperationName() << ' '; |
| p.printOperands(switch_op.getOperands()); |
| Type data_operand_ty = switch_op.data().getType(); |
| // If the types aren't perfectly matching, print the functional type syntax |
| // else print the shorter single type. |
| p << " : "; |
| if (switch_op.trueOutput().getType() != data_operand_ty || |
| switch_op.falseOutput().getType() != data_operand_ty || |
| switch_op.predicate().getType().isa<UnrankedTensorType>()) { |
| p.printFunctionalType(switch_op.getOperation()); |
| } else { |
| p << switch_op.getType(0); |
| } |
| p.printOptionalAttrDict(switch_op.getAttrs()); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.SwitchN |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| LogicalResult Verify(SwitchNOp switchn) { |
| IntegerAttr num_outs = switchn.getAttrOfType<IntegerAttr>("num_outs"); |
| if (!num_outs) |
| return switchn.emitOpError() << "expects a `num_outs` integer attribute"; |
| |
| // Expect num_outs results + 1 control output. |
| if (switchn.getNumResults() != num_outs.getInt() + 1) |
| return switchn.emitOpError() |
| << "expect `num_outs` (" << num_outs.getInt() << ") results but got " |
| << (switchn.getNumResults() - 1); |
| |
| // Check that operand can be broadcasted to each output type. |
| auto operand0_type = switchn.getOperand(0).getType(); |
| TensorType operand0_tensor_type = operand0_type.dyn_cast<TensorType>(); |
| if (!operand0_tensor_type) { |
| return switchn.emitOpError() |
| << "expects data operand to have tensor type but got " |
| << operand0_type; |
| } |
| for (Type output_type : switchn.getResultTypes()) { |
| if (output_type.isa<ControlType>()) break; |
| |
| TensorType output_tensor_type = output_type.dyn_cast<TensorType>(); |
| if (!output_tensor_type) { |
| return switchn.emitOpError() |
| << "expects outputs to have tensor type but got " << output_type; |
| } |
| |
| // If the output type is a ref type, then the operand type should also be of |
| // the same ref type. However, if the output type is a non-ref type T, then |
| // the operand can be tensor of type T or T_REF. |
| bool is_output_ref = |
| output_tensor_type.getElementType().isa<TF::TensorFlowRefType>(); |
| if (is_output_ref && |
| !operand0_tensor_type.getElementType().isa<TF::TensorFlowRefType>()) { |
| return switchn.emitOpError() |
| << "expects same operand and output element type but got " |
| << operand0_tensor_type << " vs " << output_tensor_type; |
| } |
| Type broadcasted_type = OpTrait::util::getBroadcastedType( |
| DropRefType(DropTypeSubTypes(operand0_tensor_type)), |
| DropRefType(DropTypeSubTypes(output_tensor_type))); |
| if (!broadcasted_type) { |
| return switchn.emitOpError() |
| << "expects data operand to be broadcastable with all output types" |
| << " but got " << operand0_tensor_type << " vs " |
| << output_tensor_type; |
| } |
| } |
| return success(); |
| } |
| |
| void Print(SwitchNOp switchn, OpAsmPrinter &p) { |
| p << switchn.getOperationName() << ' '; |
| auto operands = switchn.getOperands(); |
| // Print the 2 data operands. |
| p.printOperands(operands.begin(), std::next(operands.begin(), 2)); |
| p << " of " << (switchn.getNumResults() - 1); |
| // print control dependencies if any |
| if (!llvm::empty(switchn.controlInputs())) { |
| p << " ("; |
| p.printOperands(switchn.controlInputs()); |
| p << ")"; |
| } |
| p << " : " << switchn.getType(0); |
| p.printOptionalAttrDict(switchn.getAttrs(), {"num_outs"}); |
| } |
| |
| ParseResult ParseSwitchNOp(OpAsmParser &parser, OperationState &result) { |
| // Parsing: |
| // %2:6 = tf_executor.SwitchN %0, %1 of 5 : tensor<??xf32> |
| // Where the first operand is the data to replicate, the second is an i32 |
| // indicating which output to populate, followed by the keyword `of` and the |
| // number of outputs (+1 for the control token). |
| SmallVector<OpAsmParser::OperandType, 2> op_infos; |
| SmallVector<Type, 1> types; |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| IntegerAttr num_outs; |
| Type i64_type = parser.getBuilder().getIntegerType(64); |
| if (parser.parseOperandList(op_infos, 2) || parser.parseKeyword("of") || |
| parser.parseAttribute(num_outs, i64_type, "num_outs", |
| result.attributes) || |
| parser.parseOperandList(op_infos, |
| OpAsmParser::Delimiter::OptionalParen) || |
| parser.parseColonTypeList(types)) |
| return failure(); |
| if (types.size() != 1) |
| return parser.emitError(parser.getNameLoc()) |
| << " expects only a single data type"; |
| |
| if (num_outs.getInt() <= 0) |
| return parser.emitError(parser.getNameLoc()) |
| << " expects a positive number of outputs"; |
| |
| // `types` already contains the type for the data, add an i32 for the |
| // output_index, and then the optional control inputs. |
| auto builder = parser.getBuilder(); |
| types.push_back(RankedTensorType::get({}, builder.getIntegerType(32))); |
| Type control_type = ControlType::get(builder.getContext()); |
| types.append(op_infos.size() - 2, control_type); |
| |
| if (parser.resolveOperands(op_infos, types, loc, result.operands)) |
| return failure(); |
| |
| // Output result types is a replication `num_outs` times the data input type. |
| result.types.append(num_outs.getInt(), types[0]); |
| result.types.push_back(control_type); |
| |
| return parser.parseOptionalAttrDict(result.attributes); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.Merge |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| LogicalResult Verify(MergeOp merge) { |
| if (!merge.getNumOperands()) |
| return merge.emitOpError() << "expects at least one operand"; |
| |
| Type data_type = merge.getOperand(0).getType(); |
| if (data_type.isa<ControlType>()) |
| return merge.emitOpError() << "expects a non-control input"; |
| |
| // Check that each operand can be individually broadcasted to the output type. |
| Type output_type = merge.output().getType(); |
| TensorType output_tensor_ty = output_type.dyn_cast<TensorType>(); |
| if (!output_tensor_ty) { |
| return merge.emitOpError() |
| << "expects output to have tensor type but got " << output_type; |
| } |
| bool is_output_ref = |
| output_tensor_ty.getElementType().isa<TF::TensorFlowRefType>(); |
| for (Type operand_type : merge.getOperandTypes()) { |
| if (operand_type.isa<ControlType>()) break; |
| |
| // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this |
| // constraint. |
| TensorType operand_tensor_ty = operand_type.dyn_cast<TensorType>(); |
| if (!operand_tensor_ty) |
| return merge.emitOpError() |
| << "expects data operands to have tensor type but got " |
| << operand_type; |
| |
| // If output type is a ref type then all operand types should also be of the |
| // same ref type. However, if the output type is a non-ref type T, operands |
| // can be tensor of type T or T_REF. |
| if (is_output_ref && |
| !operand_tensor_ty.getElementType().isa<TF::TensorFlowRefType>()) { |
| return merge.emitOpError() |
| << "expects same operand and output element type but got " |
| << operand_tensor_ty << " vs " << output_tensor_ty; |
| } |
| Type broadcasted_type = OpTrait::util::getBroadcastedType( |
| DropRefType(DropTypeSubTypes(output_tensor_ty)), |
| DropRefType(DropTypeSubTypes(operand_tensor_ty))); |
| if (!broadcasted_type) |
| return merge.emitOpError() |
| << "expects all operands to be broadcastable with output type" |
| << " but got " << operand_tensor_ty << " vs " << output_tensor_ty; |
| } |
| return success(); |
| } |
| |
| void Print(MergeOp merge, OpAsmPrinter &p) { |
| // Use short form only when there are exactly two data operands and their |
| // type matches the output type. Otherwise, use the generic printer. |
| bool use_short_form = true; |
| int num_data_operands = 0; |
| |
| Type output_type = merge.output().getType(); |
| for (Type operand_type : merge.getOperandTypes()) { |
| if (operand_type.isa<ControlType>()) break; |
| num_data_operands++; |
| |
| if (operand_type != output_type) { |
| use_short_form = false; |
| break; |
| } |
| } |
| |
| p << merge.getOperationName() << ' '; |
| p.printOperands(merge.getOperands()); |
| |
| // Print the type signature of the operation. |
| p << " : "; |
| if (!use_short_form || num_data_operands != 2) { |
| p.printFunctionalType(merge.getOperation()); |
| } else { |
| p << output_type; |
| } |
| |
| p.printOptionalAttrDict(merge.getAttrs()); |
| } |
| |
| ParseResult ParseMergeOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 2> op_infos; |
| SmallVector<Type, 1> types; |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types)) |
| return failure(); |
| if (types.size() != 1) |
| return parser.emitError(parser.getNameLoc()) |
| << " expects only a single data type"; |
| |
| // Support parsing either a functional type (in which case all the types are |
| // fully qualified) or a short form with a single type (in which case the data |
| // inputs and the output are all using this type). |
| if (FunctionType type = types.front().dyn_cast<FunctionType>()) { |
| result.types.assign(type.getResults().begin(), type.getResults().end()); |
| types.assign(type.getInputs().begin(), type.getInputs().end()); |
| } else { |
| // In case of the short form, use the parsed type for both the operands and |
| // the remaining operands are expected to be control inputs. |
| types.push_back(types.front()); |
| Type control_type = ControlType::get(parser.getBuilder().getContext()); |
| types.append(op_infos.size() - 2, control_type); |
| |
| RankedTensorType i32_tensor = |
| RankedTensorType::get({}, parser.getBuilder().getIntegerType(32)); |
| result.types = {types.front(), i32_tensor, control_type}; |
| } |
| |
| if (parser.resolveOperands(op_infos, types, loc, result.operands)) |
| return failure(); |
| |
| return parser.parseOptionalAttrDict(result.attributes); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.Enter |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
// Value assumed for the `parallel_iterations` attribute of Enter nodes when
// the textual form does not spell it out. The printer omits the attribute when
// it equals this value.
constexpr int kDefaultParallelIterations{10};
| |
| void Print(EnterOp enter, OpAsmPrinter &p) { |
| p << enter.getOperationName() << ' '; |
| p.printOperands(enter.getOperands()); |
| |
| p << " frame \""; |
| printEscapedString(enter.frame_name(), p.getStream()); |
| p << "\""; |
| if (enter.parallel_iterations() != kDefaultParallelIterations) |
| p << " parallel_iterations " << enter.parallel_iterations(); |
| if (enter.is_constant()) p << " constant "; |
| |
| // If the types aren't perfectly matching, print the functional type syntax |
| // else print the shorter single type. |
| p << " : "; |
| if (enter.data().getType() != enter.output().getType()) { |
| p.printFunctionalType(enter.getOperation()); |
| } else { |
| p << enter.getType(0); |
| } |
| |
| p.printOptionalAttrDict(enter.getAttrs(), |
| {"frame_name", "parallel_iterations", "is_constant"}); |
| } |
| |
| ParseResult ParseEnterOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 2> op_infos; |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| MLIRContext *context = parser.getBuilder().getContext(); |
| if (parser.parseOperandList(op_infos)) return failure(); |
| if (op_infos.empty()) |
| return parser.emitError(loc) << " expects at least one data operand"; |
| |
| Attribute frame; |
| if (parser.parseKeyword("frame") || |
| parser.parseAttribute(frame, NoneType::get(context), "frame_name", |
| result.attributes)) |
| return failure(); |
| |
| Type i64 = parser.getBuilder().getIntegerType(64); |
| if (parser.parseOptionalKeyword("parallel_iterations")) { |
| result.addAttribute("parallel_iterations", |
| IntegerAttr::get(i64, kDefaultParallelIterations)); |
| } else { |
| IntegerAttr parallel_iterations; |
| if (parser.parseAttribute(parallel_iterations, i64, "parallel_iterations", |
| result.attributes)) |
| return failure(); |
| } |
| bool has_constant = succeeded(parser.parseOptionalKeyword("constant")); |
| result.addAttribute("is_constant", BoolAttr::get(has_constant, context)); |
| |
| SmallVector<Type, 1> types; |
| if (parser.parseColonTypeList(types)) return failure(); |
| if (types.size() != 1) |
| return parser.emitError(loc) << " expects only a single data type"; |
| |
| // Support parsing either a functional type (in which case all the types are |
| // fully qualified) or a short form with a single type (in which case the data |
| // input and the outputs are all using this type). |
| if (FunctionType type = types.front().dyn_cast<FunctionType>()) { |
| // One data input, and any number of control inputs. |
| if (type.getNumInputs() >= 1) { |
| result.types.assign(type.getResults().begin(), type.getResults().end()); |
| types.assign(type.getInputs().begin(), type.getInputs().end()); |
| } else { |
| return parser.emitError(parser.getNameLoc()) << " expects a data input"; |
| } |
| } else { |
| Type control_type = ControlType::get(context); |
| types.append(op_infos.size() - 1, control_type); |
| result.addTypes({types.front(), control_type}); |
| } |
| |
| // Extra operands are expected to be control inputs. |
| |
| if (parser.resolveOperands(op_infos, types, loc, result.operands)) |
| return failure(); |
| |
| return parser.parseOptionalAttrDict(result.attributes); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.NextIteration.Source |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| LogicalResult Verify(NextIterationSourceOp source) { |
| Value token = source.token(); |
| if (!token.hasOneUse()) |
| return source.emitOpError() << "expects a single user for produced token"; |
| if (!isa<NextIterationSinkOp>(*token.user_begin())) |
| return source.emitOpError() << "token should be consumed by a sink op"; |
| return success(); |
| } |
| |
| void Print(NextIterationSourceOp next_iteration, OpAsmPrinter &p) { |
| p << next_iteration.getOperationName() << " : " << next_iteration.getType(0); |
| p.printOptionalAttrDict(next_iteration.getAttrs()); |
| } |
| |
| ParseResult ParseNextIterationSourceOp(OpAsmParser &parser, |
| OperationState &result) { |
| SmallVector<Type, 1> types; |
| if (parser.parseColonTypeList(types)) return failure(); |
| |
| MLIRContext *context = parser.getBuilder().getContext(); |
| Type token_type = TokenType::get(context); |
| Type control_type = ControlType::get(context); |
| result.addTypes({types.front(), token_type, control_type}); |
| return parser.parseOptionalAttrDict(result.attributes); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.NextIteration.Sink |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| LogicalResult Verify(NextIterationSinkOp sink) { |
| Value token = sink.token(); |
| Operation *definingOp = token.getDefiningOp(); |
| if (!definingOp) |
| return sink.emitOpError() << "expects a token directly produced by a " |
| "tf_executor.NextIteration.Source op: "; |
| auto source = dyn_cast<NextIterationSourceOp>(definingOp); |
| if (!source) |
| return sink.emitOpError() << "expects a token produced by a " |
| "tf_executor.NextIteration.Source op: "; |
| if (source.output().getType() != sink.input().getType()) |
| return sink.emitOpError() |
| << "input type " << sink.input().getType() |
| << " mismatch the tf_executor.NextIteration.Source output type: " |
| << source.output().getType(); |
| return success(); |
| } |
| |
| void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) { |
| p << next_iteration.getOperationName() << " ["; |
| p.printOperand(next_iteration.getOperand(0)); |
| p << "] "; |
| p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1)); |
| p << " : " << next_iteration.getOperand(1).getType(); |
| p.printOptionalAttrDict(next_iteration.getAttrs()); |
| } |
| |
| ParseResult ParseNextIterationSinkOp(OpAsmParser &parser, |
| OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 2> op_infos; |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| |
| // First type is always the token consumed from the NextIteration.source |
| Type token_type = TokenType::get(parser.getBuilder().getContext()); |
| SmallVector<Type, 1> types = {token_type}; |
| |
| if (parser.parseOperandList(op_infos, 1, OpAsmParser::Delimiter::Square) || |
| parser.parseOperandList(op_infos) || parser.parseColonTypeList(types)) |
| return failure(); |
| |
| Type control_type = ControlType::get(parser.getBuilder().getContext()); |
| types.append(op_infos.size() - 2, control_type); |
| if (parser.resolveOperands(op_infos, types, loc, result.operands)) |
| return failure(); |
| |
| return parser.parseOptionalAttrDict(result.attributes); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.Exit |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| void Print(ExitOp exit, OpAsmPrinter &p) { |
| p << exit.getOperationName() << ' '; |
| p.printOperands(exit.getOperands()); |
| p << " : " << exit.getType(0); |
| p.printOptionalAttrDict(exit.getAttrs()); |
| } |
| |
| ParseResult ParseExitOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 2> op_infos; |
| SmallVector<Type, 1> types; |
| |
| if (parser.parseOperandList(op_infos) || parser.parseColonTypeList(types)) |
| return failure(); |
| |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| Type control_type = ControlType::get(parser.getBuilder().getContext()); |
| types.append(op_infos.size() - 1, control_type); |
| if (parser.resolveOperands(op_infos, types, loc, result.operands)) |
| return failure(); |
| |
| result.addTypes({types.front(), control_type}); |
| return parser.parseOptionalAttrDict(result.attributes); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.ControlTrigger |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| void Print(ControlTriggerOp trigger, OpAsmPrinter &p) { |
| p << trigger.getOperationName() << ' '; |
| p.printOperands(trigger.getOperands()); |
| p.printOptionalAttrDict(trigger.getAttrs()); |
| } |
| |
| ParseResult ParseControlTriggerOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 2> op_infos; |
| SmallVector<Type, 1> types; |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| |
| if (parser.parseOperandList(op_infos)) return failure(); |
| Type control_type = ControlType::get(parser.getBuilder().getContext()); |
| types.append(op_infos.size(), control_type); |
| if (parser.resolveOperands(op_infos, types, loc, result.operands)) |
| return failure(); |
| |
| // Single control as the only output |
| result.types.push_back(control_type); |
| return parser.parseOptionalAttrDict(result.attributes); |
| } |
| |
| } // anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.LoopCond |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| void Print(LoopCondOp loop_cond, OpAsmPrinter &p) { |
| p << loop_cond.getOperationName() << ' '; |
| p.printOperands(loop_cond.getOperands()); |
| |
| // If the types aren't matching (broadcast), print the functional type syntax. |
| if (loop_cond.input().getType() != loop_cond.output().getType()) { |
| p << " : "; |
| p.printFunctionalType(loop_cond.getOperation()); |
| } else { |
| p << " : " << loop_cond.input().getType(); |
| } |
| |
| p.printOptionalAttrDict(loop_cond.getAttrs()); |
| } |
| |
| ParseResult ParseLoopCondOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 2> op_infos; |
| |
| if (parser.parseOperandList(op_infos)) return failure(); |
| if (op_infos.empty()) |
| return parser.emitError(parser.getNameLoc()) |
| << "expects at least one operand"; |
| |
| SmallVector<Type, 1> types; |
| if (parser.parseColonTypeList(types)) return failure(); |
| |
| // Support parsing either a functional type (in which case all the types are |
| // fully qualified) or a short form with a single type (in which case the data |
| // input and the outputs are all using this type). |
| Type control_type = ControlType::get(parser.getBuilder().getContext()); |
| if (FunctionType type = types.front().dyn_cast<FunctionType>()) { |
| if (llvm::count_if(type.getInputs(), |
| [=](Type type) { return type != control_type; }) != 1) |
| return parser.emitError(parser.getNameLoc()) |
| << " expects a single data type"; |
| result.types.assign(type.getResults().begin(), type.getResults().end()); |
| types.assign(type.getInputs().begin(), type.getInputs().end()); |
| } else { |
| if (types.size() != 1) |
| return parser.emitError(parser.getNameLoc()) |
| << " expects a single data type"; |
| types.append(op_infos.size() - 1, control_type); |
| result.addTypes({types.front(), control_type}); |
| } |
| |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| if (parser.resolveOperands(op_infos, types, loc, result.operands)) |
| return failure(); |
| |
| return parser.parseOptionalAttrDict(result.attributes); |
| } |
| |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Canonicalization patterns |
| //===----------------------------------------------------------------------===// |
| |
// TODO(lyandy): Add canonicalization for deduplicating control inputs.
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.graph |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| // Finds in a block if the op of type `InnerOpT` is the first operation and |
| // optionally followed by a terminator. |
| template <typename InnerOpT> |
| bool HasSingleOpInBlock(Block *block) { |
| if (block->empty()) return false; |
| if (!llvm::isa<InnerOpT>(block->front())) return false; |
| // Either InnerOpT is the only instruction in the block, or there is a |
| // possible terminator. |
| return std::next(block->begin()) == block->end() || |
| std::next(block->begin(), 2) == block->end(); |
| } |
| |
| // This pattern matches GraphOps with only one FetchOp (empty) and remaps the |
| // results of the GraphOp to the operands of the FetchOp. |
| struct DropEmptyGraph : public OpRewritePattern<GraphOp> { |
| using OpRewritePattern<GraphOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(GraphOp op, |
| PatternRewriter &rewriter) const override { |
| Block &block = op.GetBody(); |
| // Check if graph only has one fetch. |
| if (&block.front() != &block.back()) return failure(); |
| |
| // Map graph results to fetch operands. |
| rewriter.replaceOp(op, op.GetFetch().fetches()); |
| |
| return success(); |
| } |
| }; |
| |
| // This pattern matches GraphOps with only one island, pulls out all inner ops |
| // of the island to the block containing the GraphOp, and then removes the |
| // GraphOp. |
| struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern<GraphOp> { |
| using OpRewritePattern<GraphOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(GraphOp op, |
| PatternRewriter &rewriter) const override { |
| Block &block = op.GetBody(); |
| // Check if graph only has one island. |
| if (!HasSingleOpInBlock<IslandOp>(&block)) return failure(); |
| |
| FetchOp fetch_op = op.GetFetch(); |
| auto island_op = llvm::cast<IslandOp>(block.front()); |
| YieldOp yield_op = island_op.GetYield(); |
| |
| // Map graph results to inner ops results of single island. |
| llvm::SmallVector<Value, 8> new_rets; |
| for (Value operand : fetch_op.fetches()) { |
| // Control results should not be propagated out. |
| if (operand.getType().isa<ControlType>()) break; |
| |
| if (operand.getDefiningOp() != island_op) { |
| // Operand is not from island, simply propagate it out. |
| new_rets.push_back(operand); |
| } else { |
| // Lookup yield operand in island for inner op result. |
| auto result = operand.cast<OpResult>(); |
| new_rets.push_back(yield_op.getOperand(result.getResultNumber())); |
| } |
| } |
| |
| // Move inner ops from island to block containing graph. |
| auto &island_body = island_op.GetBody().getOperations(); |
| Operation *operation = op.getOperation(); |
| operation->getBlock()->getOperations().splice( |
| operation->getIterator(), island_body, island_body.begin(), |
| std::prev(island_body.end())); |
| rewriter.replaceOp(op, new_rets); |
| |
| return success(); |
| } |
| }; |
| } // anonymous namespace |
| |
| void GraphOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<DropEmptyGraph, HoistInnerOpsSingleIslandGraph>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.island |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| // This pattern matches and removes IslandOps with no inner ops, no control |
| // operands and no data results. Control result users will have their relevant |
| // operands removed. |
| struct DropEmptyIslandNoOperandNoDataResult |
| : public OpRewritePattern<IslandOp> { |
| using OpRewritePattern<IslandOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(IslandOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.getNumOperands() != 0 || op.getNumResults() != 1 || |
| !HasSingleOpInBlock<YieldOp>(&op.GetBody())) |
| return failure(); |
| |
| for (auto &use : llvm::make_early_inc_range(op.control().getUses())) |
| use.getOwner()->eraseOperand(use.getOperandNumber()); |
| |
| rewriter.eraseOp(op); |
| |
| return success(); |
| } |
| }; |
| |
| // This pattern matches and removes IslandOps with no inner ops, no control |
| // operands, one data result and no control result user. The single data result |
| // (from YieldOps first operand) is forwarded to the IslandOp single data result |
| // users. |
| struct DropEmptyIslandNoOperandOneDataResult |
| : public OpRewritePattern<IslandOp> { |
| using OpRewritePattern<IslandOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(IslandOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.getNumOperands() != 0 || op.getNumResults() != 2 || |
| !op.control().use_empty() || |
| !HasSingleOpInBlock<YieldOp>(&op.GetBody())) |
| return failure(); |
| |
| rewriter.replaceOp(op, {op.GetYield().getOperand(0), nullptr}); |
| |
| return success(); |
| } |
| }; |
| |
| // TODO(lyandy): Add canonicalization for empty IslandOps with more than one |
| // control operand and no data results. |
| |
| } // anonymous namespace |
| |
| void IslandOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<DropEmptyIslandNoOperandNoDataResult, |
| DropEmptyIslandNoOperandOneDataResult>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.ControlTrigger |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| // This pattern matches and removes ControlTriggerOps with no control operands. |
| // Control result users will have their relevant operands removed. |
| struct DropEmptyControlTrigger : public OpRewritePattern<ControlTriggerOp> { |
| using OpRewritePattern<ControlTriggerOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ControlTriggerOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.getNumOperands() != 0) return failure(); |
| |
| for (auto &use : llvm::make_early_inc_range(op.control().getUses())) |
| use.getOwner()->eraseOperand(use.getOperandNumber()); |
| |
| rewriter.eraseOp(op); |
| |
| return success(); |
| } |
| }; |
| } // anonymous namespace |
| |
| void ControlTriggerOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<DropEmptyControlTrigger>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Folders |
| //===----------------------------------------------------------------------===// |
| |
| //===----------------------------------------------------------------------===// |
| // tf_executor.island |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult IslandOp::fold(llvm::ArrayRef<Attribute> operands, |
| llvm::SmallVectorImpl<OpFoldResult> &results) { |
| // This folds IslandOps with no inner ops, one control operand and no data |
| // results. The single control operand is forwarded to the IslandOp control |
| // result users. |
| if (getNumOperands() != 1 || getNumResults() != 1 || |
| !HasSingleOpInBlock<YieldOp>(&GetBody())) |
| return failure(); |
| |
| results.emplace_back(getOperand(0)); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TableGen'd op method definitions |
| //===----------------------------------------------------------------------===// |
| |
| #define GET_OP_CLASSES |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc" |
| |
| } // namespace tf_executor |
| } // namespace mlir |