| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" |
| |
| #include <algorithm> |
| #include <cstdint> |
| #include <initializer_list> |
| #include <iterator> |
| #include <queue> |
| |
| #include "llvm/ADT/Hashing.h" |
| #include "llvm/ADT/None.h" |
| #include "llvm/ADT/PointerUnion.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/iterator_range.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Block.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/Diagnostics.h" // from @llvm-project |
| #include "mlir/IR/Location.h" // from @llvm-project |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "mlir/IR/OperationSupport.h" // from @llvm-project |
| #include "mlir/IR/SymbolTable.h" // from @llvm-project |
| #include "mlir/IR/TypeUtilities.h" // from @llvm-project |
| #include "mlir/IR/Value.h" // from @llvm-project |
| #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project |
| #include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project |
| #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Pass/PassRegistry.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "mlir/Transforms/FoldUtils.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" |
| #include "tensorflow/core/framework/shape_inference.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| |
| #define DEBUG_TYPE "tf-shape-inference" |
| |
| #define DCOMMENT(MSG) LLVM_DEBUG(llvm::dbgs() << MSG << "\n") |
| #define DCOMMENT_OP(OP, MSG) \ |
| LLVM_DEBUG(OP->print(llvm::dbgs() << MSG << " "); llvm::dbgs() << "\n") |
| |
| using ::tensorflow::int64; |
| using tensorflow::shape_inference::DimensionHandle; |
| using tensorflow::shape_inference::InferenceContext; |
| using tensorflow::shape_inference::ShapeHandle; |
| |
| namespace mlir { |
| namespace TF { |
| namespace { |
| |
| // Returns whether type can be further refined. |
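// For example (illustrative): tensor<*xf32> and tensor<?x4xf32> can be
// refined, tensor<4xf32> cannot, and tensor<!tf.resource<tensor<?xf32>>> can
// still be refined through its resource subtype even though its own shape is
// static.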
| bool CanBeRefined(Type type) { |
| auto shape_type = type.dyn_cast<ShapedType>(); |
| if (!shape_type) return false; |
| |
| // Returns whether type with subtypes can be further refined. |
| auto can_refine_subtypes = [](TF::TensorFlowTypeWithSubtype tws) { |
| return tws.GetSubtypes().empty() || |
| llvm::any_of(tws.GetSubtypes(), CanBeRefined); |
| }; |
| auto type_with_subtype = |
| shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>(); |
| if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true; |
| |
| return !shape_type.hasStaticShape(); |
| } |
| |
// Returns whether `original_type` can be refined with
// `potential_refined_type`.
| bool CanRefineTypeWith(Type original_type, Type potential_refined_type) { |
| if (original_type == potential_refined_type || !CanBeRefined(original_type)) |
| return false; |
| |
| auto shape_type = potential_refined_type.dyn_cast<ShapedType>(); |
| if (!shape_type) return false; |
| if (shape_type.hasRank()) return true; |
| |
| auto element_type_with_subtype = |
| shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>(); |
| return element_type_with_subtype && |
| !element_type_with_subtype.GetSubtypes().empty(); |
| } |
| |
// Returns whether the shape inference pass supports an op outside the TF
// dialect.
| bool IsSupportedNonTFOp(Operation* op) { |
| return isa<tf_device::ReturnOp, tf_device::ClusterOp, tf_device::LaunchOp, |
| tf_executor::EnterOp, tf_executor::ExitOp, tf_executor::FetchOp, |
| tf_executor::GraphOp, tf_executor::IslandOp, |
| tf_executor::LoopCondOp, tf_executor::MergeOp, |
| tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp, |
| tf_executor::SwitchOp, tf_executor::YieldOp>(op); |
| } |
| |
| // Returns whether a cast back would need to be inserted, e.g., whether the |
| // operation of which use is an operand allows for shape refinement without |
| // a cast. |
| bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) { |
| return use.getOwner()->getDialect() != tf_dialect && |
| !IsSupportedNonTFOp(use.getOwner()); |
| } |
| |
| TensorType CreateTensorType(llvm::Optional<llvm::ArrayRef<int64_t>> shape, |
| Type element_type) { |
| if (shape.hasValue()) |
| return RankedTensorType::get(shape.getValue(), element_type); |
| return UnrankedTensorType::get(element_type); |
| } |
| |
| // Returns true if the op creates a TensorList. |
| bool IsTensorListInitOp(Operation* op) { |
| return isa<TensorListReserveOp>(op) || isa<EmptyTensorListOp>(op) || |
| isa<TensorListFromTensorOp>(op); |
| } |
| |
| // Returns the `element_shape` operand of the ops that create a TensorList. |
| Value GetElementShapeOperand(Operation* op) { |
| if (auto empty_tl = dyn_cast<EmptyTensorListOp>(op)) |
| return empty_tl.element_shape(); |
| if (auto tl_reserve = dyn_cast<TensorListReserveOp>(op)) |
| return tl_reserve.element_shape(); |
| if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op)) |
| return tl_from_tensor.element_shape(); |
| llvm_unreachable("unsupported TensorList op"); |
| } |
| |
| // Utility function to create a ranked tensor type after dropping the first |
| // dimension from the input type. |
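// For example (illustrative), tensor<5x8x9xf32> yields tensor<8x9xf32>, while
// an unranked type such as tensor<*xf32> yields a null type.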
| RankedTensorType DropFirstDimension(Type type) { |
| RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>(); |
| if (!ranked_type) return {}; |
| llvm::ArrayRef<int64_t> dims_except_first = |
| ranked_type.getShape().drop_front(); |
| return RankedTensorType::get(dims_except_first, ranked_type.getElementType()); |
| } |
| |
// Follows the use chain of a TensorList and returns true iff all elements
// written to the TensorList have the same static shape. If so, that shape is
// assigned to `potential_element_type`.
//
// This can handle multiple mutations of a TensorList object and returns true
// only if, across all mutations, the elements written have the same shape.
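//
// As an illustrative sketch, given IR along the lines of
//
//   %tl0 = "tf.EmptyTensorList"(%shape, %max)
//   %tl1 = "tf.TensorListPushBack"(%tl0, %a)
//   %tl2 = "tf.TensorListPushBack"(%tl1, %b)
//
// this returns true with *potential_element_type == tensor<8x9xf32> iff both
// %a and %b have the static type tensor<8x9xf32>.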
| bool CanInferTensorListElementType(Value tensorlist, |
| Value initial_element_shape, |
| RankedTensorType* potential_element_type) { |
  // Verifies that the new element type has a static shape and matches the
  // potential type passed from the caller. Updates `potential_element_type`
  // if it is not defined yet.
| auto verify_and_update_potential_element_type = |
| [&](RankedTensorType new_element_type) -> bool { |
| if (!new_element_type || !new_element_type.hasStaticShape()) return false; |
| if (!*potential_element_type) { |
| *potential_element_type = new_element_type; |
| return true; |
| } |
| return *potential_element_type == new_element_type; |
| }; |
| |
| // TensorLists are semantically immutable. For example, TensorListSetItem |
| // takes a TensorList as input and produces a TensorList as output. So to |
| // traverse modifications to TensorList and verify that all elements written |
| // to it have the same shape, we need to follow use-def chain of ops that |
| // (conceptually) modify it i.e., ops that take an input TensorList and |
| // produce an output TensorList. |
| for (auto& use : tensorlist.getUses()) { |
| if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) { |
| auto element_type = push.tensor().getType().dyn_cast<RankedTensorType>(); |
| if (!verify_and_update_potential_element_type(element_type)) return false; |
| if (!CanInferTensorListElementType(push.output_handle(), |
| initial_element_shape, |
| potential_element_type)) |
| return false; |
| continue; |
| } |
| if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>( |
| use.getOwner())) { |
      // For the scatter op we can get the element shape by dropping the first
      // dimension of the input tensor.
| RankedTensorType element_type = |
| DropFirstDimension(scatter.tensor().getType()); |
| if (!verify_and_update_potential_element_type(element_type)) return false; |
| if (!CanInferTensorListElementType(scatter.output_handle(), |
| initial_element_shape, |
| potential_element_type)) |
| return false; |
| continue; |
| } |
| if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) { |
| auto element_type = |
| set_item.item().getType().dyn_cast<RankedTensorType>(); |
| if (!verify_and_update_potential_element_type(element_type)) return false; |
| if (!CanInferTensorListElementType(set_item.output_handle(), |
| initial_element_shape, |
| potential_element_type)) |
| return false; |
| continue; |
| } |
| if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) { |
| if (!CanInferTensorListElementType(pop.output_handle(), |
| initial_element_shape, |
| potential_element_type)) |
| return false; |
| continue; |
| } |
| if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) { |
| if (!CanInferTensorListElementType(resize.output_handle(), |
| initial_element_shape, |
| potential_element_type)) |
| return false; |
| continue; |
| } |
    // WhileRegionOp can explicitly capture a TensorList value to be used
    // inside its regions. So we check the uses of the corresponding block
    // argument in each region and the use of the TensorList returned via
    // YieldOp.
| if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) { |
| for (auto branch : while_region.getRegions()) { |
| if (!CanInferTensorListElementType( |
| branch->getArgument(use.getOperandNumber()), |
| initial_element_shape, potential_element_type)) |
| return false; |
| } |
| continue; |
| } |
| if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) { |
| Operation* parent = yield->getParentOp(); |
| if (!CanInferTensorListElementType( |
| parent->getResult(use.getOperandNumber()), initial_element_shape, |
| potential_element_type)) |
| return false; |
| continue; |
| } |
    // Refining the tensor list element type might change the output of
    // TensorListElementShape, which is expected to be the shape originally
    // assigned to the TensorList init op. So replace it with the original
    // element shape value.
| if (auto tl_element_shape = |
| dyn_cast<TensorListElementShapeOp>(use.getOwner())) { |
| // If element types match, we can do a direct replacement. |
| if (getElementTypeOrSelf(tl_element_shape.getResult()) == |
| getElementTypeOrSelf(initial_element_shape.getType())) { |
| tl_element_shape.replaceAllUsesWith(initial_element_shape); |
| } else { |
| OpBuilder b(use.getOwner()); |
| auto cast_op = b.create<TF::CastOp>( |
| use.getOwner()->getLoc(), tl_element_shape.getResult().getType(), |
| initial_element_shape, |
| /*truncate=*/b.getBoolAttr(false)); |
| tl_element_shape.replaceAllUsesWith(cast_op.getResult()); |
| } |
| continue; |
| } |
| // Ignore ops that just consume a TensorList and do not output another |
| // TensorList. |
| if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op, |
| TensorListLengthOp, TensorListGetItemOp>(use.getOwner())) |
| continue; |
| |
| // For any other unknown users of the TensorList, we are conservative and |
| // stop element shape inference. |
| return false; |
| } |
| return true; |
| } |
| } // namespace |
| |
| // Combination of value producer and port of value produced (e.g., |
| // <value result output>:<value in output tensor>, |
| // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output |
| // scalar value). |
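// For example (illustrative): given `%res:2 = "tf.Foo"(...)`, a ValuePort
// built from %res#1 has the tf.Foo operation as producer and port {1}, while
// a ValuePort built from a block argument has that argument as producer and
// port {0}.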
| struct ValuePort { |
| PointerUnion<Operation*, BlockArgument> producer; |
| SmallVector<unsigned int, 2> port; |
| |
| bool operator==(const ValuePort& other) const { |
| return producer == other.producer && port == other.port; |
| } |
| |
| // Convert output value to ValuePort. |
| explicit ValuePort(Value v) { |
| OpResult opr = v.dyn_cast<OpResult>(); |
| if (opr) { |
| producer = opr.getOwner(); |
| port = {opr.getResultNumber()}; |
| } else { |
| producer = v.cast<BlockArgument>(); |
| port = {0}; |
| } |
| } |
| ValuePort(PointerUnion<Operation*, BlockArgument> producer, |
| SmallVector<unsigned int, 2> port) |
| : producer(producer), port(port) {} |
| |
| raw_ostream& print(raw_ostream& os) const { |
| if (auto* op = producer.dyn_cast<Operation*>()) |
| os << "op " << op->getName(); |
| if (auto ba = producer.dyn_cast<BlockArgument>()) |
| os << "block_arg " << ba.getArgNumber(); |
| os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end())); |
| return os; |
| } |
| }; |
| |
| struct ValuePortHasher { |
  std::size_t operator()(const ValuePort& value_port) const {
    return hash_combine(
        llvm::hash_value(value_port.producer.getOpaqueValue()),
        hash_value(ArrayRef<unsigned int>(value_port.port)));
| } |
| }; |
| |
| using ValuePortResultMap = |
| std::unordered_map<ValuePort, Attribute, ValuePortHasher>; |
| using ComputedQueryFn = function_ref<bool(ValuePort)>; |
| using ValueQueryFn = function_ref<Attribute(const ValuePort&)>; |
| using ValuePortInputs = SmallVectorImpl<ValuePort>; |
| |
| // TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are |
| // intended to be switched to op interfaces once more refined. |
| LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, |
| ComputedQueryFn has_been_computed, |
| ValuePortInputs* inputs) { |
| auto op = value_port.producer.dyn_cast<Operation*>(); |
| auto& port = value_port.port; |
| if (!op) return failure(); |
| |
| // No inputs required for constants. |
| if (matchPattern(op, m_Constant())) return success(); |
| |
  // Note: this focuses only on the trivial pack op case; this could be
  // generalized.
| if (auto pack_op = dyn_cast<TF::PackOp>(op)) { |
| auto type = pack_op.getType().cast<TensorType>(); |
| if (!type.hasRank() || type.getRank() != 1) return failure(); |
| if (port.size() != 2) return failure(); |
| assert(port[0] == 0); |
| ValuePort req(pack_op.getOperand(port[1])); |
| if (!has_been_computed(req)) inputs->push_back(req); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| |
| // Computes the output produced by ValuePort using the query function of |
| // existing computed values. |
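// For example (illustrative): given
//
//   %s = "tf.Pack"(%a, %b) : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
//
// a query with port {0, 1} is forwarded to the value of %b, so a known
// constant for %b is returned as the second element of %s.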
| Attribute ComputeOutputComponent(const ValuePort& value_port, |
| ValueQueryFn values) { |
| LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n"); |
| if (auto known = values(value_port)) return known; |
| |
| auto op = value_port.producer.dyn_cast<Operation*>(); |
| if (!op) return nullptr; |
| auto& port = value_port.port; |
| |
| if (port.empty()) { |
| LLVM_DEBUG(llvm::dbgs() << "skipping, port outside spec of " << op << "\n"); |
| return nullptr; |
| } |
| |
| ElementsAttr attr; |
| if (matchPattern(op, m_Constant(&attr))) { |
| if (port.size() == 1 && port[0] == 0) return attr; |
| return nullptr; |
| } |
| |
  // Note: this focuses only on the trivial pack op case; this could be
  // generalized.
| if (auto pack_op = dyn_cast<TF::PackOp>(op)) { |
| TensorType type = pack_op.getType().cast<TensorType>(); |
| if (!type.hasRank() || type.getRank() != 1) return nullptr; |
| if (port.size() != 2 || port[0] != 0) return nullptr; |
| ValuePort op_port(op->getOperand(port[1])); |
| return values(op_port); |
| } |
| |
| if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) { |
| if (port.size() == 1) |
| return ComputeOutputComponent( |
| ValuePort(graph.GetFetch().fetches()[port[0]]), values); |
| return nullptr; |
| } |
| |
| if (auto island = dyn_cast<tf_executor::IslandOp>(op)) { |
| if (port.size() == 1) |
| return ComputeOutputComponent( |
| ValuePort(island.GetYield().fetches()[port[0]]), values); |
| return nullptr; |
| } |
| |
| return nullptr; |
| } |
| |
| // Context used during ShapeInference. This class contains common information |
| // that is required by the individual shape inference helper functions (e.g., |
| // TF Graph version, constant values computed, etc.) |
| class ShapeInference { |
| public: |
| ShapeInference(int64_t graph_version, MLIRContext* context, |
| bool propagate_caller_callee_constants); |
| |
| LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, |
| ValuePortInputs* inputs) { |
| return ::mlir::TF::ComputeInputsRequiredForOutput( |
| value_port, |
| [this](const ValuePort& port) { |
| return results_.find(port) != results_.end(); |
| }, |
| inputs); |
| } |
| |
| Attribute ComputeOutputComponent(const ValuePort& value_port) { |
| if (auto known_attr = results_[value_port]) return known_attr; |
| auto attr = ::mlir::TF::ComputeOutputComponent( |
| value_port, [this](const ValuePort& port) { return results_[port]; }); |
| RecordValue(value_port, attr); |
| return attr; |
| } |
| |
| // Returns ShapeHandle if the op result could be computed as shape. |
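  // For example (illustrative): for a result
  //   %shape = "tf.Pack"(%d0, %d1)
  //       : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
  // where %d0 partially evaluates to the constant 3 and %d1 is unknown, the
  // computed shape is [3, ?].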
| ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic); |
| |
| void RecordValue(const ValuePort& value_port, Attribute value) { |
| LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ") |
| << value << "\n"); |
| results_[value_port] = value; |
| } |
| |
  // Infers the shape of tf.While/tf.WhileRegion. If the `shape_invariant`
  // attribute is set, an operand type is used as the result type if the
  // associated body result type matches it (i.e., the type does not change
  // per loop iteration). If the operand and body result types differ, only
  // handle types are propagated to the result types; this avoids incorrectly
  // changing result shapes when the While op has a different result shape.
  // Otherwise operand shapes are propagated to result shapes.
| template <typename WhileOpTy> |
| bool InferShapeForWhile(WhileOpTy op, TypeRange body_result_types); |
| |
  // Performs shape inference on the provided op and returns true if the type
  // of at least one result has been changed.
  // A tf.Cast() is inserted for any use that isn't in the TensorFlow dialect.
  // `graph_version_` indicates the current GraphDef compatibility version
  // (the versions field in graph.proto).
| bool InferShapeForSingleOperation(Operation* op); |
| |
  // Infers shapes in the provided region, including nested regions, iterating
  // until a fixed point is reached or `max_iterations` is exceeded. Returns
  // success if a fixed point is reached within `max_iterations`.
| LogicalResult InferShapeUntilFixPoint(Region* region, int64_t max_iterations); |
| |
  // Updates input types and refines shapes inside the bodies of functions
  // that are attached to ControlFlow ops (If/While) or Calls. These functions
  // include Then/Else branches of IfOp and Cond/Body functions of WhileOp.
  // Functions attached to control flow share the following common properties:
  //   1) They are never reused, i.e., they have a single use in the module.
  //   2) Their input types match those of their parent ops (excluding inputs
  //      like predicate).
  // For calls, functions can be reused across multiple call sites. In this
  // case we propagate the types only when all call sites have the same
  // operand types.
| LogicalResult PropagateShapeToFunctions(ModuleOp module, |
| TypeRange input_types, |
| ArrayRef<FuncOp> functions, |
| int64_t max_iteration); |
| |
| // Propagates shapes to regions given the shapes of the inputs of the regions. |
| // All regions provided in `regions` are assumed to have inputs of type |
| // `input_types`. |
| LogicalResult PropagateShapeToRegions(TypeRange input_types, |
| ArrayRef<Region*> regions, |
| int64_t max_iteration); |
| |
| // Shape propagation for call/control flow ops. |
| LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, |
| int64_t max_iteration); |
| |
| // Shape propagation for region based control flow. |
| LogicalResult PropagateShapeIntoAttachedRegions(Operation* op, |
| int64_t max_iterations); |
| |
| // Propagates any constant operand of call_op to the called function body's |
| // corresponding argument if the callee has only one use. |
| // |
| // TODO(b/154065712): Move this to a more general inter-procedural constant |
| // folding pass. |
| void PropagateConstantToCallee(CallOpInterface call_op, FuncOp func, |
| ModuleOp module); |
| |
| // Propagates any constant return value of the callee function to the call |
| // op's corresponding result. |
| void PropagateConstantFromCallee(CallOpInterface call_op, FuncOp func, |
| ModuleOp module); |
| |
  // Tries to compute the result of folding the op. This doesn't actually
  // perform constant folding, it just computes the equivalent constants.
  // Returns whether it was able to compute constant values.
| LogicalResult TryToFold(Operation* op); |
| |
| // Makes result types match the operand types (the i-th result type will |
| // match the i-th operand type). Returns true if anything is changed. |
| bool RefineTypeForPassThroughOperands(Operation* op, OperandRange operands, |
| ResultRange results); |
| |
| // Makes result type's shape match the corresponding operand's shape. |
| // Returns whether any change was made. |
| bool RefineShapeForPassThroughOps(Operation* op); |
| |
| // Infers shape for necessary ops that are not in the TF dialect. Returns |
| // whether any result type changed. |
| bool InferShapeForNonTFDialectOperation(Operation* op); |
| |
  // Infers the shape of the function's return type.
| void InferShapeForFunctionReturnType(FuncOp func); |
| |
| // Enqueues function for processing. |
| void enqueue(FuncOp fn) { |
| LLVM_DEBUG(llvm::dbgs() |
| << "enqueue " << fn.getName() << " (" |
| << (queue_set_.count(fn) ? "already inserted" : "newly inserted") |
| << ")\n"); |
| if (queue_set_.insert(fn).second) queue_.push(fn); |
| } |
| |
  // Enqueues the callers of a function.
| void EnqueueCallers(FuncOp fn); |
| |
| // Returns the function at the front of the queue. |
| FuncOp front() { return queue_.front(); } |
| |
| // Returns whether work queue is empty. |
| bool EmptyQueue() const { return queue_.empty(); } |
| |
  // Removes and returns the function at the front of the work queue.
| FuncOp pop_front() { |
| FuncOp ret = queue_.front(); |
| queue_.pop(); |
| queue_set_.erase(ret); |
| return ret; |
| } |
| |
| // Returns the current size of the queue. |
| std::queue<FuncOp>::size_type QueueSize() const { return queue_.size(); } |
| |
| Dialect* const tf_dialect_; |
| |
| private: |
  // Updates the result of an operation to a new inferred type. Also inserts a
  // tf.Cast operation for uses that are incompatible with the new type.
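  //
  // As an illustrative sketch, refining %0 below from tensor<*xf32> to
  // tensor<4xf32>, where foreign.use stands for any unsupported non-TF op,
  //
  //   %0 = "tf.X"() : () -> tensor<*xf32>
  //   "foreign.use"(%0) : (tensor<*xf32>) -> ()
  //
  // becomes
  //
  //   %0 = "tf.X"() : () -> tensor<4xf32>
  //   %1 = "tf.Cast"(%0) : (tensor<4xf32>) -> tensor<*xf32>
  //   "foreign.use"(%1) : (tensor<*xf32>) -> ()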
| void UpdateTypeAndInsertIncompatibleUseCasts(Type new_type, Value result); |
| |
  // Refines the type of `result` of `op` using the type
  // `potential_refined_type`. Returns true if the type was changed.
| bool RefineResultType(Operation* op, Value result, |
| Type potential_refined_type); |
| |
  // Infers the shape from a (Stateful)PartitionedCall operation by looking up
  // the called function and propagating the return type.
| bool InferShapeForCall(CallOpInterface call_op); |
| |
| bool InferShapeForCast(CastOp op); |
| |
| // Infers the shape IfOp outputs based on the shapes of the then and else |
| // function result types. |
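  //
  // For example (illustrative): if both branch functions return tensor<4xf32>
  // while the op result type is tensor<*xf32>, the result is refined to
  // tensor<4xf32>; results whose then/else types disagree are left unchanged.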
| bool InferShapeForIf(IfOp op); |
| |
| // Infers the shape IfRegion outputs based on the shapes of the then and else |
| // yields. |
| bool InferShapeForIfRegion(IfRegionOp op); |
| |
| // Infers the shape of ops that create TensorList. Specifically, |
| // TensorListReserveOp, EmptyTensorListOp and TensorListFromTensor ops. It |
| // refines the element shape if all tensors written to the list across all |
| // mutations have identical static shape. |
| bool InferShapeForTensorListInitOps(Operation* op); |
| |
| bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti); |
| |
  // Returns all the callers of a function.
  // Note: usage of the returned ArrayRef must not be interleaved with
  // insertions into the callers map. This could occur if GetCallers is called
  // on two separate functions, the second call incurs a map resize, and then
  // the callers stored for both functions are used.
| ArrayRef<Operation*> GetCallers(FuncOp fn); |
| |
  // Mapping from a ValuePort (which corresponds to an OpResult or smaller,
  // e.g., the first element of an OpResult produced) to an Attribute if the
  // ValuePort corresponds to a constant value.
| ValuePortResultMap results_; |
| |
| // Map from a function to the callers of that function. |
| llvm::DenseMap<FuncOp, SmallVector<Operation*, 4>> callers_of_func_; |
| |
| // Queue of functions being processed. |
| llvm::DenseSet<FuncOp> queue_set_; |
| std::queue<FuncOp> queue_; |
| |
| int64_t graph_version_; |
| |
| // TODO(b/154065712): Remove propagate_caller_callee_constants once using |
| // SCCP pass instead. |
| bool propagate_caller_callee_constants_; |
| }; |
| |
| ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context, |
| bool propagate_caller_callee_constants) |
| : tf_dialect_(context->getLoadedDialect<TensorFlowDialect>()), |
| graph_version_(graph_version), |
| propagate_caller_callee_constants_(propagate_caller_callee_constants) {} |
| |
| ArrayRef<Operation*> ShapeInference::GetCallers(FuncOp fn) { |
| auto pair = callers_of_func_.try_emplace(fn); |
| if (pair.second) { |
| ModuleOp module = fn->getParentOfType<ModuleOp>(); |
| auto uses = mlir::SymbolTable::getSymbolUses(fn.getOperation(), module); |
| if (uses) { |
      pair.first->second.reserve(std::distance(uses->begin(), uses->end()));
| for (auto use : *uses) { |
| pair.first->second.push_back(use.getUser()); |
| } |
| } |
| } |
| return pair.first->second; |
| } |
| |
| void ShapeInference::EnqueueCallers(FuncOp fn) { |
| for (auto user : GetCallers(fn)) enqueue(user->getParentOfType<FuncOp>()); |
| } |
| |
| void ShapeInference::UpdateTypeAndInsertIncompatibleUseCasts(Type new_type, |
| Value result) { |
  // A tf.Cast operation is lazily created on the first use that requires a
  // cast.
| TF::CastOp cast_op; |
| auto get_cast_op = [&]() { |
| if (!cast_op) { |
| Operation* op = result.getDefiningOp(); |
| OpBuilder b(op); |
| b.setInsertionPointAfter(op); |
| cast_op = b.create<TF::CastOp>(op->getLoc(), result.getType(), result, |
| /*truncate=*/b.getBoolAttr(false)); |
| } |
| return Value(cast_op); |
| }; |
  // First insert casts back for the uses that need them, and then update the
  // type.
| bool enqueue_callers = false; |
| for (OpOperand& use : make_early_inc_range(result.getUses())) { |
| if (isa<ReturnOp>(use.getOwner())) |
| enqueue_callers = true; |
| else if (NeedsCastBack(use, tf_dialect_)) |
| use.set(get_cast_op()); |
| } |
| |
| result.setType(new_type); |
| if (enqueue_callers) |
| EnqueueCallers(result.getDefiningOp()->getParentOfType<FuncOp>()); |
| } |
| |
| bool ShapeInference::RefineResultType(Operation* op, Value result, |
| Type potential_refined_type) { |
| if (!CanRefineTypeWith(result.getType(), potential_refined_type)) |
| return false; |
| |
| UpdateTypeAndInsertIncompatibleUseCasts(potential_refined_type, result); |
| return true; |
| } |
| |
// Infers the shape from a (Stateful)PartitionedCall operation by looking up
// the called function and propagating the return type.
| bool ShapeInference::InferShapeForCall(CallOpInterface call_op) { |
| FuncOp func = dyn_cast<FuncOp>(call_op.resolveCallable()); |
| if (!func) return false; |
| |
| LLVM_DEBUG(llvm::dbgs() << "Infer shape for call " << func.getName()); |
| Operation* op = call_op.getOperation(); |
| bool changed = false; |
| // Map each of the results of the call to the returned type of the |
| // function. |
| for (auto result : zip(op->getResults(), func.getType().getResults())) { |
| changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) || |
| changed; |
| } |
| LLVM_DEBUG(llvm::dbgs() << " changed ? " << changed << "\n"); |
| |
| return changed; |
| } |
| |
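// For example (illustrative): a tf.Cast such as
//
//   %1 = "tf.Cast"(%0) : (tensor<4x2xf32>) -> tensor<*xi32>
//
// with at least one user that needs no cast back is refined to return
// tensor<4x2xi32>: the operand's shape is adopted while the cast's element
// type is kept.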
| bool ShapeInference::InferShapeForCast(CastOp op) { |
| DCOMMENT_OP(op.getOperation(), "Inferring shape for "); |
| Value result = op.getResult(); |
| if (!CanBeRefined(result.getType())) return false; |
| |
| Type operand_type = op.getOperand().getType(); |
| auto ranked_op_type = operand_type.dyn_cast<RankedTensorType>(); |
| if (!ranked_op_type) return false; |
| auto ranked_res_type = result.getType().dyn_cast<RankedTensorType>(); |
| if (ranked_res_type && |
| ranked_op_type.getShape() == ranked_res_type.getShape()) |
| return false; |
| |
  // Avoid inserting a cast where no user's type could be refined (e.g., where
  // there would need to be a cast inserted for every user again).
| if (llvm::all_of(result.getUses(), [this](OpOperand& use) { |
| return NeedsCastBack(use, tf_dialect_); |
| })) |
| return false; |
| |
| auto new_type = RankedTensorType::get( |
| ranked_op_type.getShape(), |
| result.getType().cast<ShapedType>().getElementType()); |
| |
| UpdateTypeAndInsertIncompatibleUseCasts(new_type, op.getResult()); |
| return true; |
| } |
| |
| bool ShapeInference::InferShapeForIf(IfOp op) { |
| DCOMMENT_OP(op.getOperation(), "Infer shape for if "); |
| bool changed = false; |
| auto then_results = op.then_function().getType().getResults(); |
| auto else_results = op.else_function().getType().getResults(); |
| for (auto it : llvm::zip(op.getResults(), then_results, else_results)) { |
| // If then and else types do not match, skip refinement for that result. |
| if (std::get<1>(it) != std::get<2>(it)) continue; |
| changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed; |
| } |
| return changed; |
| } |
| |
| bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) { |
| bool changed = false; |
| |
| Operation* then_yield = op.then_branch().front().getTerminator(); |
| Operation* else_yield = op.else_branch().front().getTerminator(); |
| for (auto result : zip(op.getResults(), then_yield->getOperandTypes(), |
| else_yield->getOperandTypes())) { |
| // If then and else types do not match, skip refinement for that result. |
| if (std::get<1>(result) != std::get<2>(result)) continue; |
| changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) || |
| changed; |
| } |
| return changed; |
| } |
| |
| bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) { |
| DCOMMENT_OP(op, "Inferring shape for TensorList "); |
| Value handle = op->getResult(0); |
| Value initial_element_shape = GetElementShapeOperand(op); |
| RankedTensorType element_type; |
| if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op)) { |
    // For the TensorListFromTensor op we can infer the element shape by
    // dropping the first dimension of the input tensor.
| element_type = DropFirstDimension(tl_from_tensor.tensor().getType()); |
| if (!element_type || !element_type.hasStaticShape()) return false; |
| } |
| if (!CanInferTensorListElementType(handle, initial_element_shape, |
| &element_type)) |
| return false; |
| if (!element_type || !element_type.hasStaticShape()) return false; |
| auto variant_type = VariantType::get(element_type, op->getContext()); |
| auto tensor_type = RankedTensorType::get({}, variant_type); |
| bool changed = RefineResultType(op, handle, tensor_type); |
| if (changed) DCOMMENT_OP(op, "Modified after shape inference:"); |
| return changed; |
| } |
| |
| bool ShapeInference::RefineWithInferTypeOpInterface( |
| InferTypeOpInterface infer_ti) { |
| Operation* op = infer_ti.getOperation(); |
| SmallVector<Type, 4> inferred; |
| LogicalResult res = infer_ti.inferReturnTypes( |
| op->getContext(), op->getLoc(), op->getOperands(), |
| op->getAttrDictionary(), op->getRegions(), inferred); |
| if (failed(res)) { |
| op->emitOpError("failed to refine type as inference failed"); |
| return false; |
| } |
| |
| if (inferred == op->getResultTypes()) return false; |
| |
  // Map each result of the op to its inferred type.
| bool changed = false; |
| for (auto result : zip(op->getResults(), inferred)) { |
| if (std::get<0>(result).getType() == std::get<1>(result)) continue; |
| |
| UpdateTypeAndInsertIncompatibleUseCasts(std::get<1>(result), |
| std::get<0>(result)); |
| changed = true; |
| } |
| return changed; |
| } |
| |
| ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, |
| InferenceContext* ic) { |
| LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially ")); |
| auto rt = result.getType().dyn_cast<RankedTensorType>(); |
| if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {}; |
| int dim_size = rt.getDimSize(0); |
| |
| // Worklist to direct partial evaluation. |
| SmallVector<ValuePort, 4> worklist; |
| |
  // Simple evaluator that attempts to partially evaluate the input value even
  // if unable to evaluate the complete output. Below follows a simple
  // stack-based evaluation that queries which operands (or parts of operands)
  // need to be evaluated and attempts to partially evaluate them. It does so
  // by pushing the required operands onto the worklist before enqueuing the
  // operation requiring those values.
| std::vector<DimensionHandle> dims(dim_size, ic->UnknownDim()); |
| for (unsigned int i = 0, e = dims.size(); i != e; ++i) { |
| LLVM_DEBUG(llvm::dbgs() << "\nConsidering output dim " << i << "\n"); |
| |
| worklist.push_back( |
| ValuePort{result.getOwner(), {result.getResultNumber(), i}}); |
| while (!worklist.empty()) { |
| auto front = worklist.pop_back_val(); |
| LLVM_DEBUG(front.print(llvm::dbgs() << "\nWorklist front ")); |
| |
| SmallVector<ValuePort, 4> inputs; |
| auto res = ComputeInputsRequiredForOutput(front, &inputs); |
| if (failed(res)) { |
| // Abort if unable to find which required inputs need to be computed. |
| worklist.clear(); |
| break; |
| } |
| |
| if (!inputs.empty()) { |
| // Enqueue required computation followed by its required operands in |
| // stack. |
| worklist.push_back(std::move(front)); |
| for (auto& it : inputs) worklist.push_back(std::move(it)); |
| continue; |
| } |
| |
| auto ret = ComputeOutputComponent(front); |
| if (!ret) continue; |
| |
| LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = ")); |
| |
| // If worklist is empty, then this is the root query op. |
| if (worklist.empty()) { |
| LLVM_DEBUG(llvm::dbgs() << "[root node]\n"); |
| if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) { |
| if (dea.getNumElements() != 1) { |
| LLVM_DEBUG(llvm::dbgs() << "Unexpected number of elements\n"); |
| return {}; |
| } |
| int64_t val = (*dea.getIntValues().begin()).getSExtValue(); |
| dims[i] = ic->MakeDim(val); |
| } |
| } |
| } |
| } |
| return ic->MakeShape(dims); |
| } |
| |
| bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op, |
| OperandRange operands, |
| ResultRange results) { |
| bool changed = false; |
| for (auto entry : zip(operands, results)) { |
| Type operand_type = std::get<0>(entry).getType(); |
| Value result = std::get<1>(entry); |
| TensorType result_type = result.getType().cast<TensorType>(); |
| if (operand_type == result_type) continue; |
    // Pass-through nodes may remove ref types; don't consider that as
    // refinement.
| // TODO(jpienaar): There could be refinement in addition to this, so |
| // refine this. |
| if (operand_type.cast<TensorType>() |
| .getElementType() |
| .isa<TF::TensorFlowRefType>() && |
| !result_type.cast<TensorType>() |
| .getElementType() |
| .isa<TF::TensorFlowRefType>()) |
| continue; |
| |
| UpdateTypeAndInsertIncompatibleUseCasts(operand_type, result); |
| changed = true; |
| } |
| return changed; |
| } |
| |
| bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) { |
| DCOMMENT_OP(op, "Pass through op"); |
| auto is_allowed_dtype = [](Type t) { |
| // Skip if element type is not in standard or TF dialect. |
    // TODO(jpienaar): The tf.Cast op, which is uniformly inserted at the
    // moment, cannot handle arbitrary types (e.g., it can't handle quantized
    // types). This restriction can be relaxed if not only tf.Cast is used.
| return t.getDialect().getNamespace().empty() || |
| isa<TensorFlowDialect>(t.getDialect()); |
| }; |
| |
| bool changed = false; |
| for (auto entry : zip(op->getOperands(), op->getResults())) { |
| TensorType operand_type = std::get<0>(entry).getType().cast<TensorType>(); |
| Value result = std::get<1>(entry); |
| TensorType result_type = result.getType().cast<TensorType>(); |
| if (operand_type == result_type) continue; |
| if (!operand_type.hasRank()) continue; |
| if (result_type.hasRank() && |
| result_type.getShape() == operand_type.getShape()) |
| continue; |
| if (!is_allowed_dtype(operand_type.getElementType()) || |
| !is_allowed_dtype(result_type.getElementType())) |
| continue; |
| |
| auto new_type = RankedTensorType::get(operand_type.getShape(), |
| result_type.getElementType()); |
| UpdateTypeAndInsertIncompatibleUseCasts(new_type, result); |
| changed = true; |
| } |
| return changed; |
| } |
| |
| bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) { |
| if (auto graph_op = dyn_cast<tf_executor::GraphOp>(op)) { |
| return RefineTypeForPassThroughOperands( |
| graph_op.GetFetch(), graph_op.GetFetch().fetches(), op->getResults()); |
| } |
| if (auto island_op = dyn_cast<tf_executor::IslandOp>(op)) { |
| return RefineTypeForPassThroughOperands( |
| island_op.GetYield(), island_op.GetYield().fetches(), op->getResults()); |
| } |
| if (auto iter_sink = dyn_cast<tf_executor::NextIterationSinkOp>(op)) { |
| auto iter_source = cast<tf_executor::NextIterationSourceOp>( |
| iter_sink.token().getDefiningOp()); |
| return RefineTypeForPassThroughOperands( |
| op, iter_sink.getOperands().drop_front().take_front(), |
| iter_source.getResults()); |
| } |
| if (auto launch_op = dyn_cast<tf_device::LaunchOp>(op)) { |
| auto terminator = launch_op.GetBody().getTerminator(); |
| return RefineTypeForPassThroughOperands(op, terminator->getOperands(), |
| op->getResults()); |
| } |
| if (auto cluster_op = dyn_cast<tf_device::ClusterOp>(op)) { |
| auto terminator = cluster_op.GetBody().getTerminator(); |
| return RefineTypeForPassThroughOperands(op, terminator->getOperands(), |
| op->getResults()); |
| } |
| if (op->hasTrait<OpTrait::SameOperandsAndResultShape>() || |
| isa<tensor::CastOp>(op)) { |
| return RefineShapeForPassThroughOps(op); |
| } |
| if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call); |
| return false; |
| } |
| |
// Finds the element type to be used for the result based on the operand, with
// special handling for handle types.
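// For example (illustrative): with an operand element type of
// !tf.resource<tensor<4xf32>> and a result element type of plain
// !tf.resource, the operand's handle type is returned so the subtype is not
// dropped; if the result already carries subtypes, the result's own element
// type is kept.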
| Type GetElementTypeFromOperand(TensorType operand_type, |
| TensorType result_type) { |
| auto operand_handle_type = |
| operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>(); |
| if (!operand_handle_type) return result_type.getElementType(); |
| auto result_handle_type = |
| result_type.getElementType().cast<TensorFlowTypeWithSubtype>(); |
| if (operand_handle_type.GetSubtypes().empty() || |
| !result_handle_type.GetSubtypes().empty()) |
| return result_type.getElementType(); |
| return operand_handle_type; |
| } |
| |
| // Checks if one tensor type can refine another type for tf.While/ |
| // tf.WhileRegion. If rank differs or static dimensions can be lost, the other |
| // type cannot be used for refinement. |
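// For example (illustrative): tensor<?x4xf32> can be refined with
// tensor<2x4xf32>, but tensor<3x4xf32> can be refined with neither
// tensor<2x4xf32> (static dimension mismatch) nor tensor<?x4xf32> (a static
// dimension would be lost).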
| bool CanWhileTypeBeRefinedWith(TensorType current_type, |
| TensorType potential_refined_type) { |
| if (!current_type.hasRank()) return true; |
| if (!potential_refined_type.hasRank()) return false; |
| if (current_type.getRank() != potential_refined_type.getRank()) return false; |
| for (auto dim : |
| llvm::zip(current_type.getShape(), potential_refined_type.getShape())) { |
| int64_t current_dim = std::get<0>(dim); |
| int64_t potential_refined_dim = std::get<1>(dim); |
| if (current_dim != potential_refined_dim && |
| current_dim != ShapedType::kDynamicSize) |
| return false; |
| } |
| return true; |
| } |
| |
| template <typename WhileOpTy> |
| bool ShapeInference::InferShapeForWhile(WhileOpTy op, |
| TypeRange body_result_types) { |
| if (!op.shape_invariant()) |
| return RefineTypeForPassThroughOperands(op, op.input(), op.output()); |
| |
| bool changed = false; |
| for (auto entry : |
| zip(op.input().getTypes(), op.output(), body_result_types)) { |
| Value result = std::get<1>(entry); |
| TensorType body_result_type = |
| std::get<2>(entry).template cast<TensorType>(); |
| auto result_type = result.getType().cast<TensorType>(); |
| |
| Type potential_refined_type; |
| if (CanWhileTypeBeRefinedWith(result_type, body_result_type)) { |
| Type element_type = |
| GetElementTypeFromOperand(body_result_type, result_type); |
| potential_refined_type = CreateTensorType( |
| body_result_type.hasRank() ? body_result_type.getShape() |
| : llvm::Optional<ArrayRef<int64_t>>(), |
| element_type); |
| } else { |
| TensorType operand_type = std::get<0>(entry).template cast<TensorType>(); |
| Type element_type = GetElementTypeFromOperand(operand_type, result_type); |
| potential_refined_type = CreateTensorType( |
| result_type.hasRank() ? result_type.getShape() |
| : llvm::Optional<ArrayRef<int64_t>>(), |
| element_type); |
| } |
| changed |= RefineResultType(op, result, potential_refined_type); |
| } |
| return changed; |
| } |
| |
| bool ShapeInference::InferShapeForSingleOperation(Operation* op) { |
| LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for "); |
| llvm::dbgs() << "\n"); |
| assert(tf_dialect_ == op->getDialect()); |
| // The shape function of these ops sometimes does not propagate subtypes |
| // (handle shapes) for resource and variant types. We use a simple passthrough |
| // to make sure they are preserved in the output. |
| if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp>(op)) { |
| return RefineTypeForPassThroughOperands(op, op->getOperands(), |
| op->getResults()); |
| } |
| |
| // If no result for this op needs shape inference, we have a fast-path return. |
| // But if the type is a resource/variant, we do not skip it because we might |
| // not have the handle shapes. |
| if (none_of(op->getResultTypes(), CanBeRefined)) { |
| LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '" |
| << op->getName() << "'.\n"); |
| return false; |
| } |
| |
| // Handle call operations by looking up callee and inferring return shape as |
| // needed. |
| if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call); |
| |
  // tf.Cast ops are only inferred if they have at least one user in the TF
  // dialect or feed into the function return. This is necessary to avoid
  // inserting casts which cannot be refined.
| if (auto cast_op = dyn_cast<CastOp>(op)) return InferShapeForCast(cast_op); |
| |
| // Handle IfOp here by inferring the shape from the else/then function |
| // results. Since `output_shapes` is a derived attribute, avoid going down the |
| // TF InferenceContext path as IfOp shape inference is implemented as just |
| // a lookup of the output_shapes attribute. |
| if (auto if_op = dyn_cast<IfOp>(op)) return InferShapeForIf(if_op); |
| |
| // Handle IfRegion operations by inferring return shape from the then and else |
| // branches. |
| if (auto if_region = dyn_cast<IfRegionOp>(op)) |
| return InferShapeForIfRegion(if_region); |
| |
| if (auto while_op = dyn_cast<WhileOp>(op)) |
| return InferShapeForWhile(while_op, |
| while_op.body_function().getType().getResults()); |
| |
| if (auto while_region = dyn_cast<WhileRegionOp>(op)) |
| return InferShapeForWhile( |
| while_region, |
| while_region.body().front().getTerminator()->getOperandTypes()); |
| |
| // Handle TensorList init operations by inferring shape from TensorList write |
| // operations. If we are unable to refine element shape here, proceed to use |
| // the InferenceContext below to get more precise shapes. |
| if (IsTensorListInitOp(op) && InferShapeForTensorListInitOps(op)) return true; |
| |
| // Return operand as a constant attribute. |
| auto operand_as_constant_fn = [&](Value operand) { |
| ValuePort vp(operand); |
| Attribute attr = ComputeOutputComponent(vp); |
| if (!attr && matchPattern(operand, m_Constant(&attr))) |
| RecordValue(vp, attr); |
| return attr; |
| }; |
| |
| // Return op result as a shape. |
| auto op_result_as_shape_fn = [&](InferenceContext& context, |
| OpResult op_result) { |
| return ComputeOutputAsShape(op_result, &context); |
| }; |
| |
| // Return result element type at `index`. |
| auto result_element_type_fn = [&](int index) { |
| return op->getResult(index).getType().cast<TensorType>().getElementType(); |
| }; |
| |
| llvm::SmallVector<ShapedTypeComponents, 4> inferred_return_shapes; |
| if (failed(InferReturnTypeComponentsForTFOp( |
| /*location=*/None, op, graph_version_, operand_as_constant_fn, |
| op_result_as_shape_fn, result_element_type_fn, |
| inferred_return_shapes))) |
| return false; |
| |
| // Update the shape for each of the operation result if the InferenceContext |
| // has more precise shapes recorded. |
| bool changed = false; |
| for (auto result : llvm::zip(op->getResults(), inferred_return_shapes)) { |
| Value op_result = std::get<0>(result); |
| if (!CanBeRefined(op_result.getType())) continue; |
| |
| ShapedTypeComponents inferred = std::get<1>(result); |
| TensorType inferred_type; |
| if (inferred.hasRank()) |
| inferred_type = |
| RankedTensorType::get(inferred.getDims(), inferred.getElementType()); |
| else |
| inferred_type = UnrankedTensorType::get(inferred.getElementType()); |
| |
| if (op_result.getType() == inferred_type) continue; |
| UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, op_result); |
| changed = true; |
| } |
| |
| if (changed) DCOMMENT_OP(op, "Modified after shape inference:"); |
| return changed; |
| } |
| |
| LogicalResult ShapeInference::PropagateShapeToFunctions( |
| ModuleOp module, TypeRange input_types, ArrayRef<FuncOp> functions, |
| int64_t max_iteration) { |
| bool all_succeeded = true; |
| // If shape propagation fails for one function, return failure, but do not |
| // early exit and attempt to propagate shapes for all provided functions to |
| // have a best-effort propagation. |
| for (FuncOp func : functions) { |
| DCOMMENT("Propating shape to" << func.getName()); |
| ArrayRef<Operation*> callers = GetCallers(func); |
| if (!llvm::hasSingleElement(callers) && |
| !llvm::all_of(callers.drop_front(), [&](Operation* caller) { |
| /// TODO(aminim): this is overly conservative as some operations |
| /// (like TPUPartitionedCallOp) may have extra operands that aren't |
| /// propagated to the callee. |
| return isa<CallOpInterface>(caller) && |
| std::equal(caller->getOperandTypes().begin(), |
| caller->getOperandTypes().end(), |
| callers.front()->getOperandTypes().begin()); |
| })) { |
| all_succeeded = false; |
| if (llvm::any_of(callers, [](Operation* op) { |
| return isa<IfOp, WhileOp, CaseOp>(op); |
| })) |
| func.emitWarning(formatv( |
| "expected control flow function @{0} to have exactly 1 use, " |
| "found {1}.", |
| func.getName(), callers.size())); |
| |
| continue; |
| } |
| FunctionType func_type = func.getType(); |
| func.setType(FunctionType::get(func.getContext(), input_types, |
| func_type.getResults())); |
| |
| auto res = |
| PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration); |
| if (failed(res)) { |
| all_succeeded = false; |
| continue; |
| } |
| |
| InferShapeForFunctionReturnType(func); |
| } |
| return success(all_succeeded); |
| } |
| |
| LogicalResult ShapeInference::PropagateShapeToRegions(TypeRange input_types, |
| ArrayRef<Region*> regions, |
| int64_t max_iteration) { |
| DCOMMENT("\tPropagating shapes to regions"); |
| bool all_succeeded = true; |
| // If shape propagation fails for one region, return failure, but do not |
| // early exit and attempt to propagate shapes for all provided regions to |
| // have a best-effort propagation. |
| for (auto region : regions) { |
| // Refine region arguments. |
| Block& entry = region->front(); |
| assert(llvm::size(input_types) == entry.getNumArguments()); |
| for (auto it : llvm::zip(entry.getArguments(), input_types)) { |
| BlockArgument arg = std::get<0>(it); |
| Type type = std::get<1>(it); |
| arg.setType(type); |
| } |
| |
| // Propagate shapes into the region. |
| all_succeeded = succeeded(InferShapeUntilFixPoint(region, max_iteration)) && |
| all_succeeded; |
| } |
| return success(all_succeeded); |
| } |
| |
| void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op, |
| FuncOp func, ModuleOp module) { |
| auto callers = GetCallers(func); |
| if (!llvm::hasSingleElement(callers)) return; |
| |
| OpBuilder builder(&func.front().front()); |
| Operation* op = call_op.getOperation(); |
| // If this is the only caller, and an operand is a constant, propagate |
| // the constant value inside the function. |
| for (auto arg : func.getArguments()) { |
| auto operand = op->getOperand(arg.getArgNumber()); |
| if (propagate_caller_callee_constants_) { |
| if (isa_and_nonnull<TF::ConstOp>(operand.getDefiningOp())) { |
| arg.replaceAllUsesWith( |
| builder.clone(*operand.getDefiningOp())->getResult(0)); |
| } |
| continue; |
| } |
| |
| auto known_constant = ComputeOutputComponent(ValuePort(operand)); |
| if (!known_constant) continue; |
| LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to calee: "); |
| known_constant.print(llvm::dbgs() << " constant "); |
| llvm::dbgs() << "\n"); |
| RecordValue(ValuePort(arg), known_constant); |
| } |
| } |
| |
| void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op, |
| FuncOp func, ModuleOp module) { |
| // If the return value is a constant, use the constant as the value of |
| // the call return. |
| Operation* op = call_op.getOperation(); |
| OpBuilder builder(op); |
| builder.setInsertionPointAfter(op); |
| for (auto retval : |
| llvm::enumerate(func.front().getTerminator()->getOperands())) { |
| if (propagate_caller_callee_constants_) { |
| auto retval_op = retval.value().getDefiningOp(); |
| if (isa_and_nonnull<TF::ConstOp>(retval_op)) { |
| op->getResult(retval.index()) |
| .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0)); |
| } |
| continue; |
| } |
| |
| ValuePort vp(retval.value()); |
| if (auto known_constant = ComputeOutputComponent(vp)) { |
| LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant "); |
| call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n"); |
| RecordValue(ValuePort(op->getResult(retval.index())), known_constant); |
| } |
| } |
| } |
| |
| bool RankedAndSameRank(TensorType lhs, TensorType rhs) { |
| return lhs.hasRank() && rhs.hasRank() && lhs.getRank() == rhs.getRank(); |
| } |
| |
| // Creates a compatible RankedTensorType where mismatched dimensions are |
| // replaced with dynamic sizes. |
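// For example (illustrative), tensor<2x3xf32> and tensor<2x5xf32> yield
// tensor<2x?xf32>.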
| RankedTensorType GetCompatibleRankedTensorType(RankedTensorType lhs, |
| RankedTensorType rhs) { |
| assert(lhs.getRank() == rhs.getRank()); |
| llvm::SmallVector<int64_t, 4> dims; |
| dims.reserve(lhs.getRank()); |
| for (auto dim : llvm::zip(lhs.getShape(), rhs.getShape())) { |
| int64_t lhs_dim = std::get<0>(dim); |
| if (lhs_dim == std::get<1>(dim)) { |
| dims.push_back(lhs_dim); |
| } else { |
| dims.push_back(ShapedType::kDynamicSize); |
| } |
| } |
| return RankedTensorType::get(dims, GetElementTypeFromOperand(lhs, rhs)); |
| } |
| |
// Finds compatible types to propagate into the functions/regions of a shape
// invariant tf.While/tf.WhileRegion. If the operand and result types are the
// same, that type is returned. If they have the same rank, a compatible type
// with matching dimensions is used. Otherwise the function/region argument
// type is returned, but with the handle type taken from the operand type.
| llvm::SmallVector<Type, 4> GetWhileCompatibleTypes( |
| TypeRange operand_types, TypeRange result_types, |
| TypeRange region_argument_types) { |
| llvm::SmallVector<Type, 4> types; |
| types.reserve(operand_types.size()); |
| for (auto entry : |
| llvm::zip(operand_types, result_types, region_argument_types)) { |
| auto operand_type = std::get<0>(entry).cast<TensorType>(); |
| auto result_type = std::get<1>(entry).cast<TensorType>(); |
| if (operand_type == result_type) { |
| types.push_back(operand_type); |
| } else if (RankedAndSameRank(operand_type, result_type)) { |
| auto potential_refined_type = |
| GetCompatibleRankedTensorType(operand_type.cast<RankedTensorType>(), |
| result_type.cast<RankedTensorType>()); |
| types.push_back(potential_refined_type); |
| } else { |
| auto region_argument_type = std::get<2>(entry).cast<TensorType>(); |
| Type element_type = GetElementTypeFromOperand( |
| operand_type.cast<TensorType>(), region_argument_type); |
| Type potential_refined_type = CreateTensorType( |
| region_argument_type.hasRank() ? region_argument_type.getShape() |
| : llvm::Optional<ArrayRef<int64_t>>(), |
| element_type); |
| types.push_back(potential_refined_type); |
| } |
| } |
| return types; |
| } |
| |
| LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( |
| Operation* op, int64_t max_iteration) { |
| ModuleOp module = op->getParentOfType<ModuleOp>(); |
| if (auto if_op = dyn_cast<TF::IfOp>(op)) { |
| DCOMMENT("Propagating shapes into If"); |
| return PropagateShapeToFunctions( |
| module, if_op.input().getTypes(), |
| {if_op.then_function(), if_op.else_function()}, max_iteration); |
| } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) { |
| SmallVector<FuncOp, 4> branches; |
| case_op.get_branch_functions(branches); |
| return PropagateShapeToFunctions(module, case_op.input().getTypes(), |
| branches, max_iteration); |
| } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) { |
    // If `shape_invariant` is set, operand shapes cannot be simply propagated
    // to result shapes, as the op may have intermediate shapes that differ
    // (such a While op can have result shapes that differ from its operand
    // shapes). Compatible shapes must be determined before propagating them.
| if (while_op.shape_invariant()) { |
| auto compatible_types = GetWhileCompatibleTypes( |
| while_op.input().getTypes(), while_op.output().getTypes(), |
| while_op.body_function().getType().getInputs()); |
| return PropagateShapeToFunctions( |
| module, compatible_types, |
| {while_op.cond_function(), while_op.body_function()}, max_iteration); |
| } |
| return PropagateShapeToFunctions( |
| module, while_op.input().getTypes(), |
| {while_op.cond_function(), while_op.body_function()}, max_iteration); |
| } else if (auto call_op = dyn_cast<CallOpInterface>(op)) { |
| if (auto func = dyn_cast<FuncOp>(call_op.resolveCallable())) { |
| PropagateConstantToCallee(call_op, func, module); |
| if (failed(PropagateShapeToFunctions(module, |
| call_op.getArgOperands().getTypes(), |
| {func}, max_iteration))) { |
| return failure(); |
| } |
| PropagateConstantFromCallee(call_op, func, module); |
| return success(); |
| } |
| } |
| |
| // TODO(ycao): Implement support for Call op, including function reuse. |
| |
| return success(); |
| } |
| |
| LogicalResult ShapeInference::PropagateShapeIntoAttachedRegions( |
| Operation* op, int64_t max_iteration) { |
| if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) { |
    // If `shape_invariant` is set, operand shapes cannot be simply propagated
    // to result shapes, as the op may have intermediate shapes that differ
    // (such a While op can have result shapes that differ from its operand
    // shapes). Compatible shapes must be determined before propagating them.
| if (while_op.shape_invariant()) { |
| auto compatible_types = GetWhileCompatibleTypes( |
| while_op.input().getTypes(), while_op.output().getTypes(), |
| while_op.body().getArgumentTypes()); |
| return PropagateShapeToRegions(compatible_types, |
| {&while_op.cond(), &while_op.body()}, |
| max_iteration); |
| } |
| return PropagateShapeToRegions(while_op.input().getTypes(), |
| {&while_op.cond(), &while_op.body()}, |
| max_iteration); |
| } |
| return success(); |
| } |
| |
| LogicalResult ShapeInference::TryToFold(Operation* op) { |
| LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n"); |
| // If the op's first result is already known, then the op has probably
| // been computed before.
| if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))]) |
| return success(); |
| |
| SmallVector<Attribute, 8> constant_operands(op->getNumOperands()); |
| SmallVector<OpFoldResult, 8> fold_results; |
| |
| // Check whether any of the operation's operands are constant and whether
| // the operation knows how to constant fold itself.
| bool some_unknown = false; |
| for (int i = 0, e = op->getNumOperands(); i != e; ++i) { |
| if (!(constant_operands[i] = |
| ComputeOutputComponent(ValuePort(op->getOperand(i))))) |
| some_unknown = true; |
| } |
| |
| // Attempt to constant fold the operation. |
| auto* abstract_op = op->getAbstractOperation(); |
| LogicalResult folded = failure(); |
| if (abstract_op) { |
| folded = abstract_op->foldHook(op, constant_operands, fold_results); |
| } |
| // Attempt dialect fallback if op's fold hook failed. |
| if (failed(folded)) { |
| Dialect* dialect = op->getDialect(); |
| if (!dialect) return failure(); |
| // Only attempt TF dialect fallback if there are no unknown operands. |
| if (some_unknown && dialect == tf_dialect_) return failure(); |
| auto* interface = dialect->getRegisteredInterface<DialectFoldInterface>(); |
| if (!interface) return failure(); |
| |
| if (failed(interface->fold(op, constant_operands, fold_results))) |
| return failure(); |
| } |
| |
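| // Record each folded result: a constant attribute is recorded directly;
| // a returned Value either reuses a previously computed constant for that
| // value or, failing that, is only used to refine the result's type.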
| for (auto result : zip(op->getResults(), fold_results)) { |
| auto fold_result = std::get<1>(result); |
| Attribute attr = nullptr; |
| if ((attr = fold_result.dyn_cast<Attribute>())) { |
| RecordValue(ValuePort(std::get<0>(result)), attr); |
| } else { |
| auto value = fold_result.get<Value>(); |
| if ((attr = ComputeOutputComponent(ValuePort(value)))) { |
| DCOMMENT("\t\tValue Result mapped to " << attr); |
| RecordValue(ValuePort(std::get<0>(result)), attr); |
| } else { |
| DCOMMENT("\t\tValue result unmapped, consider value type:" << value); |
| RefineResultType(op, std::get<0>(result), value.getType()); |
| } |
| } |
| |
| if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) { |
| if (std::get<0>(result).getType() == eattr.getType()) continue; |
| |
| UpdateTypeAndInsertIncompatibleUseCasts(eattr.getType(), |
| std::get<0>(result)); |
| } |
| } |
| |
| return success(); |
| } |
| |
| void ShapeInference::InferShapeForFunctionReturnType(FuncOp func) { |
| LLVM_DEBUG(llvm::dbgs() << "Inferring return type for: " << func.getName() |
| << "\n"); |
| |
| // Find any return ops. |
| SmallVector<ReturnOp, 4> return_ops; |
| for (Block& block : func) { |
| if (auto return_op = dyn_cast<ReturnOp>(block.getTerminator())) { |
| return_ops.push_back(return_op); |
| } |
| } |
| |
| // Right now we only handle the case of a single return op. |
| // To handle multiple return ops, we would need to look at all their shapes |
| // and come up with a common shape and insert appropriate casts. |
| if (return_ops.size() != 1) return; |
| |
| // Find the return type. |
| auto return_op = return_ops.front(); |
| |
| // Manually fold tf.Cast ops that precede the return instruction and differ
| // only in shape refinement level.
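| // For example, a cast from tensor<1x2xf32> to tensor<*xf32> feeding the
| // return can be bypassed here so the function result becomes the more
| // refined tensor<1x2xf32>.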
| bool changed = false; |
| for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) { |
| Operation* arg_defining_op = arg_op.get().getDefiningOp(); |
| if (auto cast_op = dyn_cast_or_null<CastOp>(arg_defining_op)) { |
| Value input = cast_op.x(); |
| Value result = cast_op.y(); |
| if (!CanRefineTypeWith(result.getType(), input.getType())) continue; |
| |
| LLVM_DEBUG({ |
| llvm::errs() << "\tfolding & updating return type "; |
| cast_op.getResult().getType().print(llvm::errs()); |
| cast_op.getOperand().getType().print(llvm::errs() << " to "); |
| llvm::errs() << "\n"; |
| }); |
| |
| // Shape inference should not change the element type. |
| if (HasCompatibleElementTypes(input.getType(), result.getType())) { |
| arg_op.set(input); |
| } else { |
| OpBuilder b(return_op.getOperation()); |
| TensorType type; |
| if (input.getType().cast<TensorType>().hasRank()) { |
| type = RankedTensorType::get( |
| input.getType().cast<TensorType>().getShape(), |
| result.getType().cast<TensorType>().getElementType()); |
| } else { |
| type = UnrankedTensorType::get( |
| result.getType().cast<TensorType>().getElementType()); |
| } |
| auto new_cast_op = |
| b.create<TF::CastOp>(return_op.getLoc(), type, input, |
| /*truncate=*/b.getBoolAttr(false)); |
| arg_op.set(new_cast_op); |
| } |
| if (cast_op.y().use_empty()) cast_op.erase(); |
| changed = true; |
| } |
| } |
| |
| DCOMMENT("Updating function type"); |
| func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(), |
| return_op.getOperandTypes())); |
| |
| if (changed) EnqueueCallers(func); |
| } |
| |
| LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, |
| int64_t max_iteration) { |
| bool changed = true; |
| // TODO(b/180630087): This is due to creating needless intermediate types |
| // which can be really expensive given the current approach here. |
| bool failed_due_inefficiency = false; |
| |
| // TODO(aminim): we could have a more efficient traversal by guiding the |
| // traversal with a worklist and reconsider only the nodes for which an |
| // operand type was inferred. This would need to be careful if working on a |
| // region that would not be isolated. |
| for (int iteration = 0; iteration < max_iteration && changed; ++iteration) { |
| changed = false; |
| LLVM_DEBUG(llvm::dbgs() |
| << "Shape inference, iteration " << iteration << "\n"); |
| region->walk([&](Operation* op) { |
| // TODO(b/180630087): Remove post change. |
| if (op->getNumResults() > 50e3) { |
| failed_due_inefficiency = true; |
| return; |
| } |
| |
| DCOMMENT_OP(op, "Inferring for"); |
| if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) { |
| DCOMMENT("\tRefinining with type op interface"); |
| changed |= RefineWithInferTypeOpInterface(infer_ti); |
| return; |
| } |
| |
| if (op->getDialect() != tf_dialect_) { |
| DCOMMENT("\tInfer non-TF dialect"); |
| changed |= InferShapeForNonTFDialectOperation(op); |
| return; |
| } |
| |
| // Before attempting inference, just try to compute the folded |
| // value/shape. |
| if (succeeded(TryToFold(op)) && |
| // Folding can "succeed" and yet not all types be refined. In such |
| // cases we still want to give a try at `InferShapeForSingleOperation` |
| none_of(op->getResultTypes(), CanBeRefined)) |
| return; |
| |
| // Best-effort shape inference in attached functions. Do not return
| // failure even if inference does not reach a fixed point.
| if (failed(PropagateShapeIntoAttachedFunctions(op, max_iteration))) { |
| op->emitWarning() << "unable to refine shape of attached function " |
| "arguments and bodies"; |
| } |
| |
| if (failed(PropagateShapeIntoAttachedRegions(op, max_iteration))) { |
| op->emitWarning() << "unable to refine shape of attached region " |
| "arguments and bodies"; |
| } |
| |
| changed |= InferShapeForSingleOperation(op); |
| }); |
| |
| if (failed_due_inefficiency) { |
| LOG(ERROR) << "skipped inference due to b/180029566"; |
| return failure(); |
| } |
| } |
| |
| if (changed) { |
| return region->getParentOp()->emitWarning() |
| << "shape inference did not reach stable state after " |
| << max_iteration << " iterations"; |
| } |
| return success(); |
| } |
| |
| static LogicalResult InferShapeForFunction(ShapeInference& context, FuncOp func, |
| int64_t max_iterations) { |
| if (failed(context.InferShapeUntilFixPoint(&func.getBody(), max_iterations))) |
| return failure(); |
| // TODO(b/156276510): Verify that it is always fine to refine a function's |
| // return type, as long as we do not change the argument shapes. |
| context.InferShapeForFunctionReturnType(func); |
| |
| return success(); |
| } |
| |
| LogicalResult InferShapeForFunction(FuncOp func, |
| ArrayRef<ArrayRef<int64_t>> arg_shapes, |
| int64_t graph_version, |
| int64_t max_iterations) { |
| ShapeInference context(graph_version, func.getContext(), |
| /*propagate_caller_callee_constants=*/true); |
| if (arg_shapes.empty()) { |
| return InferShapeForFunction(context, func, max_iterations); |
| } |
| |
| FunctionType func_type = func.getType();
| // Guard against reading out of bounds below: a shape must be provided for
| // every function argument.
| if (arg_shapes.size() != func_type.getNumInputs()) return failure();
| bool needs_refinement = false;
| SmallVector<Type, 4> new_arg_types;
| new_arg_types.reserve(func_type.getNumInputs());
| |
| // Update argument types in-place using the provided arg_shapes. |
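| // For example (hypothetically), an argument of type tensor<*xf32> with
| // arg_shapes[i] = {1, 224, 224, 3} would be refined to
| // tensor<1x224x224x3xf32> below.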
| for (size_t i = 0; i < func_type.getNumInputs(); ++i) { |
| ArrayRef<int64_t> shape = arg_shapes[i]; |
| Type element_type; |
| if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) { |
| if (input_ty.getRank() != static_cast<int64_t>(shape.size())) {
| return failure(); |
| } |
| element_type = input_ty.getElementType(); |
| } else { |
| auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>(); |
| if (!unranked_input_ty) { |
| return failure(); |
| } |
| element_type = unranked_input_ty.getElementType(); |
| } |
| |
| auto new_arg_type = RankedTensorType::get(shape, element_type); |
| if (new_arg_type != func_type.getInput(i)) { |
| // If the new type is different, set it and trigger shape inference.
| func.getArgument(i).setType(new_arg_type); |
| needs_refinement = true; |
| } |
| new_arg_types.push_back(new_arg_type); |
| } |
| |
| if (!needs_refinement) return success(); |
| |
| if (failed(context.InferShapeUntilFixPoint(&func.getBody(), max_iterations))) |
| return failure(); |
| |
| context.InferShapeForFunctionReturnType(func); |
| func.setType(FunctionType::get(func.getContext(), new_arg_types, |
| func.getType().getResults())); |
| |
| return success(); |
| } |
| |
| LogicalResult InferModuleShape(ModuleOp module, int64_t max_iterations) { |
| auto producer_or = tensorflow::GetTfGraphProducerVersion(module); |
| if (!producer_or.ok()) { |
| // TODO(jpienaar): Keeping the existing behavior for now but this could |
| // be relaxed. |
| LLVM_DEBUG(llvm::dbgs() |
| << "Skipping inference; " << producer_or.status().ToString()); |
| return success(); |
| } |
| int64_t producer = producer_or.ValueOrDie(); |
| // TODO(jpienaar): Clean up propagate_caller_callee_constants if it is no
| // longer needed.
| ShapeInference context(producer, module.getContext(), |
| /*propagate_caller_callee_constants=*/false); |
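| // Seed the queue with `main` (if present) so the entry point is processed
| // first; enqueueing the remaining functions below is assumed to skip any
| // function already in the queue.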
| if (auto main = module.lookupSymbol<mlir::FuncOp>("main")) |
| context.enqueue(main); |
| for (auto func : module.getOps<FuncOp>()) context.enqueue(func); |
| // Arbitrarily upper bound the maximum number of functions that get processed |
| // just to avoid pathological cases. |
| const auto max_iteration = context.QueueSize() * 4;
| auto remaining_iterations = max_iteration;
| while (!context.EmptyQueue()) { |
| FuncOp func = context.front(); |
| if (failed(InferShapeForFunction(context, func, max_iterations))) |
| return failure(); |
| context.pop_front(); |
| |
| if ((--remaining_iterations) == 0) {
| return emitWarning(UnknownLoc::get(module.getContext()))
| << "shape inference did not reach stable state after "
| << max_iteration << " iterations";
| } |
| } |
| return success(); |
| } |
| |
| } // namespace TF |
| } // namespace mlir |