| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // Analysis for determining the possible set of values for all positions |
| // (instructions and ShapeIndexes) in the HLO module. The analysis is |
| // module-scoped: it tracks values across computation boundaries. |
| |
| #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ |
| #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ |
| |
| #include <functional> |
| #include <iterator> |
| #include <memory> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/xla/service/call_graph.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_module.h" |
| #include "tensorflow/compiler/xla/service/hlo_phi_graph.h" |
| #include "tensorflow/compiler/xla/service/hlo_value.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/status.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/types.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| |
| namespace xla { |
| |
| // Analysis which identifies all HLO values and their uses in an HLO module. |
| class HloDataflowAnalysis { |
| public: |
| // Infrastructure for passing may-alias hints: HLO passes can populate the |
| // may-alias table. If an empty optional is returned, default rules are used. |
| // |
| // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be |
| // overridden using backend-specific overrides. |
| // |
| // The first parameter of the function should be the instruction, the |
| // second parameter should be an operand of the instruction. The third |
| // parameter should be the output index of the instruction. |
| // |
| // NOTE(review): absl::optional is used here although this header does not |
| // directly include "absl/types/optional.h"; presumably it arrives |
| // transitively through one of the XLA headers -- confirm. |
| using CanShareBuffer = std::function<absl::optional<bool>( |
| const HloInstruction* instr, const HloInstruction* operand, |
| const ShapeIndex& user_index)>; |
| |
| // Runs dataflow analysis on the given module. Parameters: |
| // |
| // ssa_form : If true then new values are defined at the merge points of |
| // kWhile instructions. Abusing nomenclature somewhat, we call these "phi |
| // values". The merge is formed by the init value and loop backedge. The |
| // SSA form is minimal in that a new phi value is defined only if the |
| // merge point is reachable by multiple different values. The SSA form is |
| // also in loop-closed form in that no values defined inside of a loop |
| // (while body) is used outside of the loop. Example use of this ssa_form |
| // mode is to reason about live range interference of buffers. |
| // |
| // If ssa_form is false, then merge points do not define new |
| // values. Rather, the HloValueSet for the merge point contains the union |
| // of the merged HloValues. |
| // |
| // bitcast_defines_value : If true then the Bitcast HLO instruction defines |
| // a new HLO value in the analysis. If false then Bitcast forwards the |
| // value of its operand. |
| // |
| // can_share_buffer : Optional may-alias hint callback (see CanShareBuffer |
| // above). Defaults to nullptr, i.e. an empty std::function. |
| // |
| // On success the returned StatusOr holds the newly created analysis. |
| static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run( |
| const HloModule& module, bool ssa_form = false, |
| bool bitcast_defines_value = false, |
| const CanShareBuffer& can_share_buffer = nullptr); |
| |
| // NOTE(review): undocumented in this header. Judging by the name, this |
| // presumably returns true when every transitive use of 'inst' is an |
| // elementwise operation or a tuple -- confirm against the definition. |
| static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst); |
| |
| // Returns true if 'instruction' defines an HLO value at the given shape index |
| // of its output. |
| bool ValueIsDefinedAt(const HloInstruction* instruction, |
| const ShapeIndex& index = {}) const; |
| |
| // Returns the HloValue defined by 'instruction' at the given shape index of |
| // its output. Const and mutable overloads are provided. |
| // |
| // Precondition: ValueIsDefinedAt is true for this instruction and index. |
| const HloValue& GetValueDefinedAt(const HloInstruction* instruction, |
| const ShapeIndex& index = {}) const; |
| HloValue& GetValueDefinedAt(const HloInstruction* instruction, |
| const ShapeIndex& index = {}); |
| |
| // Returns the InstructionValueSet for the given instruction. Const and |
| // mutable overloads are provided. |
| const InstructionValueSet& GetInstructionValueSet( |
| const HloInstruction* instruction) const; |
| InstructionValueSet& GetInstructionValueSet( |
| const HloInstruction* instruction); |
| |
| // Returns all values that are contained in the output of this instruction in |
| // a flattened set. The set is returned by value. |
| HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const; |
| |
| // Returns the HloValueSet for the given instruction at the given index or the |
| // given position. Const and mutable overloads are provided for both forms. |
| const HloValueSet& GetValueSet(const HloInstruction* instruction, |
| const ShapeIndex& index = {}) const; |
| const HloValueSet& GetValueSet(const HloPosition& position) const; |
| HloValueSet& GetValueSet(const HloPosition& position); |
| HloValueSet& GetValueSet(const HloInstruction* instruction, |
| const ShapeIndex& index = {}); |
| |
| // Returns the unique value in the HloValueSet at the given instruction and |
| // shape index. CHECKs if the value set does not contain exactly one value. |
| const HloValue& GetUniqueValueAt(const HloInstruction* instruction, |
| const ShapeIndex& index = {}) const { |
| return GetValueSet(instruction, index).GetUniqueValue(); |
| } |
| // Mutable overload: the unique value is looked up again by its id through |
| // the non-const GetValue() so that a mutable reference can be returned |
| // (presumably because GetUniqueValue() only yields a const reference -- |
| // confirm against HloValueSet). |
| HloValue& GetUniqueValueAt(const HloInstruction* instruction, |
| const ShapeIndex& index = {}) { |
| return GetValue(GetValueSet(instruction, index).GetUniqueValue().id()); |
| } |
| |
| // Returns the HloValue with the given Id. |
| const HloValue& GetValue(HloValue::Id value_id) const; |
| HloValue& GetValue(HloValue::Id value_id); |
| |
| // Returns the total number of HloValues, i.e. the number of entries in the |
| // values_ map. |
| int64_t value_count() const { return values_.size(); } |
| |
| // Returns a vector of all HloValues stably sorted by HloValue::Id. The |
| // returned reference aliases values_vector_. |
| const std::vector<HloValue*>& values() const { return values_vector_; } |
| |
| // Returns the call graph used for computing the dataflow. Dereferences |
| // call_graph_, so it must be non-null when this is called. |
| const CallGraph& call_graph() const { return *call_graph_; } |
| |
| // Returns a string describing this analysis. |
| // NOTE(review): the exact contents/format are defined in the implementation |
| // file and are not visible from this header. |
| std::string ToString() const; |
| |
| // Returns true if 'user' cannot possibly use the buffer at 'index' in |
| // 'operand'. Returns false otherwise. |
| // |
| // 'operand' does not have to be an operand of 'user'. This can be the |
| // case with indirect uses. |
| bool DoesNotUseOperandBuffer(const HloInstruction* operand, |
| const ShapeIndex& index, |
| const HloInstruction* user) const; |
| |
| // Returns true if 'user' (at 'user_index') can share a buffer with its |
| // operand 'operand' (at 'operand_index'). Returns false otherwise. |
| // |
| // REQUIRES: 'operand' is an operand of 'user'. |
| bool CanShareOperandBufferWithUser(HloInstruction* operand, |
| const ShapeIndex& operand_index, |
| HloInstruction* user, |
| const ShapeIndex& user_index) const; |
| |
| // Returns the HloModule reference held by this analysis (module_). |
| const HloModule& module() const { return module_; } |
| |
| // Returns true if the operation is an in-place operation and its operand 0 |
| // must alias with the output. |
| static bool IsInPlaceOperation(HloOpcode opcode); |
| |
| // Returns true if the operation is the start/done of an asynchronous |
| // operation, where the buffer used/produced by the op needs to stay alive |
| // until the asynchronous operation completes. |
| static bool IsAsynchronousOperationStart(HloOpcode opcode); |
| static bool IsAsynchronousOperationDone(HloOpcode opcode); |
| |
| // Returns a vector consisting of the HloUse (operand number and shape index) |
| // and output shape index of the in-place operations within this HLO. |
| static std::vector<std::pair<HloUse, ShapeIndex>> GetInPlaceInputOutputPairs( |
| HloInstruction* instruction); |
| |
| protected: |
| // The constructor is protected; the public entry point declared in this |
| // class that produces an instance is the static Run() factory. The |
| // parameters have the same names as those of Run(); see its documentation. |
| HloDataflowAnalysis(const HloModule& module, bool ssa_form, |
| bool bitcast_defines_value = false, |
| const CanShareBuffer& can_share_buffer = nullptr); |
| |
| // 1. During value propagation (Propagate function), always create phi |
| // values once it sees multiple inputs merging at the same point. It then |
| // records those phi values as well as their inputs in a phi graph. |
| // |
| // 2. Post value propagation, dataflow analysis can then do certain |
| // optimization (OptimizePhiValues) on the phi graph to prune unnecessary |
| // phi nodes. |
| // |
| // Note that this applies in SSA form, and both of the functions are |
| // guaranteed to exit. |
| // |
| void OptimizePhiValues(); |
| |
| // Returns a new HloValue defined at the given instruction and shape index. |
| // 'is_phi' indicates whether the new value is a phi value. |
| HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, |
| bool is_phi); |
| |
| // Marks the HloValue with the given ID for deletion. |
| void MarkValueForDeletion(HloValue::Id value_id); |
| |
| // Deletes all HloValues marked for deletion. Should be called after |
| // propagation is complete. |
| void DeleteMarkedValues(); |
| |
| // Constructs and initializes the InstructionValueSets of all instructions to |
| // contain exactly the HloValues defined by each instruction. These values can |
| // then be propagated throughout the HLO graph by calling Propagate. |
| Status InitializeInstructionValueSets(); |
| |
| // Updates the value set of the given instruction based on the values flowing |
| // into the instruction (operands and cross-computation dataflow). |
| // NOTE(review): the bool result presumably reports whether the value set |
| // changed, matching the per-opcode helpers below -- confirm. |
| bool UpdateInstructionValueSet(HloInstruction* instruction); |
| |
| // Updates the value set for a particular instruction type. Returns whether |
| // the instruction value set changed. |
| bool UpdateBitcastValueSet(HloInstruction* bitcast); |
| bool UpdateCallValueSet(HloInstruction* call); |
| bool UpdateConditionalValueSet(HloInstruction* conditional); |
| bool UpdateCopyValueSet(HloInstruction* copy); |
| bool UpdateCustomCallValueSet(HloInstruction* custom_call); |
| bool UpdateDomainValueSet(HloInstruction* domain); |
| bool UpdateGetTupleElementValueSet(HloInstruction* gte); |
| bool UpdateParameterValueSet(HloInstruction* parameter); |
| bool UpdateCopyStartValueSet(HloInstruction* copy_start); |
| bool UpdateCopyDoneValueSet(HloInstruction* copy_done); |
| bool UpdateOptimizationBarrierValueSet(HloInstruction* barrier); |
| bool UpdateRecvDoneValueSet(HloInstruction* recv_done); |
| bool UpdateTupleSelectValueSet(HloInstruction* select); |
| bool UpdateSendValueSet(HloInstruction* send); |
| bool UpdateSetDimensionSizeValueSet(HloInstruction* set_dimension_size); |
| bool UpdateTupleValueSet(HloInstruction* tuple); |
| bool UpdateWhileValueSet(HloInstruction* xla_while); |
| bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); |
| bool UpdateAllGatherStartValueSet(HloInstruction* all_gather_start); |
| bool UpdateAllGatherDoneValueSet(HloInstruction* all_gather_done); |
| bool UpdateAllReduceDoneValueSet(HloInstruction* all_reduce_done); |
| bool UpdateCollectivePermuteStartValueSet( |
| HloInstruction* collective_permute_start); |
| bool UpdateCollectivePermuteDoneValueSet( |
| HloInstruction* collective_permute_done); |
| |
| // Propagates the dataflow through the module. In particular, it propagates |
| // the HloValueSet from its defining instruction to the users of the |
| // instructions. |
| void Propagate(); |
| |
| // Applies the SSA Phi function to the given inputs at the given |
| // instruction. |
| // NOTE(review): the declared return type is bool, not a value set; it |
| // presumably reports whether the instruction's value set changed -- confirm |
| // against the definition. |
| bool Phi(HloInstruction* instruction, |
| absl::Span<const InstructionValueSet* const> inputs); |
| |
| // Updates the positions of the HloValues in the output of the given |
| // instruction. This should be called after the instruction value set of |
| // 'instruction' has been changed. 'prev_value_set' must point to the previous |
| // state of the value set prior to the change. 'prev_value_set' may be null if |
| // this is the first time positions are being computed. The previous state is |
| // necessary to efficiently remove positions which have been eliminated due to |
| // changes in the instructions' InstructionValueSet. |
| void UpdatePositionsOfValuesAt( |
| HloInstruction* instruction, const InstructionValueSet& new_value_set, |
| const InstructionValueSet* prev_value_set = nullptr); |
| |
| // Verifies various invariants of the dataflow analysis. |
| Status Verify() const; |
| |
| // The module under analysis and the analysis mode flags. These carry the |
| // same names as the constructor parameters and are presumably initialized |
| // from them. module_ is held by reference, so the module must outlive this |
| // object. |
| const HloModule& module_; |
| const bool ssa_form_; |
| const bool bitcast_defines_value_; |
| |
| // The call graph exposed through call_graph(). |
| std::unique_ptr<CallGraph> call_graph_; |
| |
| // The map of all HloValues in the module. We pass around pointers to the |
| // mapped HloValues, so the underlying container must keep them valid despite |
| // mutations touching other map entries. Holding each HloValue behind a |
| // std::unique_ptr gives it an address that is independent of the hash map's |
| // own storage. |
| absl::flat_hash_map<HloValue::Id, std::unique_ptr<HloValue>> values_; |
| |
| // A map from instruction to InstructionValueSet. |
| absl::flat_hash_map<const HloInstruction*, |
| std::unique_ptr<InstructionValueSet>> |
| value_sets_; |
| |
| // Values marked for deletion during construction. We don't delete them |
| // immediately because references to them may remain in ValueSets temporarily |
| // during propagation. After construction, these values are deleted. |
| std::vector<HloValue::Id> value_ids_to_delete_; |
| |
| // A vector containing all HloValues sorted by HloValue::Id. Returned by |
| // values(). |
| std::vector<HloValue*> values_vector_; |
| |
| // The Id to use for the next HloValue. |
| HloValue::Id next_value_id_ = 0; |
| |
| // An explicit graph holding phi values and edges. |
| PhiGraph phi_graph_; |
| |
| // Backend specific function that decides whether an instruction can share |
| // a buffer with its operand. Defaults to an empty std::function. |
| CanShareBuffer can_share_buffer_ = nullptr; |
| }; |
| |
| } // namespace xla |
| |
| #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ |