| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ |
| #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ |
| |
| #include <atomic> |
| #include <list> |
| #include <memory> |
| #include <random> |
| #include <string> |
| #include <unordered_map> |
| #include <vector> |
| |
| #include "absl/strings/string_view.h" |
| #include "absl/types/optional.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/xla/iterator_util.h" |
| #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" |
| #include "tensorflow/compiler/xla/service/hlo.pb.h" |
| #include "tensorflow/compiler/xla/service/hlo_clone_context.h" |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_module_config.h" |
| #include "tensorflow/compiler/xla/service/hlo_schedule.h" |
| #include "tensorflow/compiler/xla/service/name_uniquer.h" |
| #include "tensorflow/compiler/xla/types.h" |
| #include "tensorflow/core/lib/gtl/iterator_range.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/mutex.h" |
| |
| namespace xla { |
| |
| // Describes a compilation unit at the HLO level. |
| // |
| // HloModule is the top-level unit in the HLO IR. It corresponds to a whole |
| // "program". Running a module, from beginning to end, is the only way to run |
| // an XLA program. |
| // |
| // A module contains one "entry computation"; this HloComputation is like main() |
| // in a C program. The result of running the module is the result of running |
| // this computation. |
| // |
| // A module also contains some number of "nested computations". Each nested |
| // computation is attached to an HloInstruction within some other computation. |
| // The meaning of the nested computation depends on the instruction it's |
| // attached to. |
| class HloModule { |
| public: |
| // Constructor without a versioned computation handle. This constructor should |
| // only be used for HloModules used outside of the XLA service (eg |
| // tests). The versioned handle is used by the service in the compilation |
| // cache. A default configuration is created for this module. |
| explicit HloModule(const string& name, HloModuleConfig config); |
| virtual ~HloModule() {} |
| |
| // Adds an entry computation to the module. A module can only have one entry |
| // computation. Returns a pointer to the newly added computation. |
| HloComputation* AddEntryComputation( |
| std::unique_ptr<HloComputation> computation); |
| |
| // Replaces the current entry computation with another computation. |
| // The new entry computation must be a computation that is already in the |
| // module. |
| void ReplaceEntryComputation(HloComputation* entry_computation); |
| |
| // Adds an embedded computation to the module. |
| HloComputation* AddEmbeddedComputation( |
| std::unique_ptr<HloComputation> computation); |
| |
| // Removes an embedded computation. |
| Status RemoveEmbeddedComputation(HloComputation* to_remove); |
| |
| // Removes unused computations. |
| Status RemoveUnusedComputations(); |
| |
| // Replaces all uses of computations that are keys of 'replacements' with |
| // the corresponding values in 'replacements'. Replaces the entry computation, |
| // if applicable. |
| // |
| // This function iterates over all instructions in the module to find |
| // computations to replace. We could speed it up by keeping track of users of |
| // computations. |
| void ReplaceComputations( |
| const std::unordered_map<HloComputation*, HloComputation*>& replacements); |
| |
| const string& name() const { return name_; } |
| void set_name(string name) { name_ = std::move(name); } |
| |
| // Returns a deep copy of this module including all computations. |
| std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const; |
| std::unique_ptr<HloModule> Clone(const HloModuleConfig& config, |
| const string& suffix = "clone") const; |
| |
| // Performs a deep clone of the computation, by recursively cloning all |
| // the called computations as well. If the clone context is specified, it |
| // will be populated with the cloned object mappings. |
| HloComputation* DeepCloneComputation(HloComputation* computation, |
| HloCloneContext* context = nullptr); |
| |
| // Return a pointer to the entry computation of the module. |
| HloComputation* entry_computation() const { |
| CHECK_NE(nullptr, entry_computation_); |
| return entry_computation_; |
| } |
| |
| bool has_entry_computation() const { return entry_computation_ != nullptr; } |
| |
| // Returns the root instruction shape of entry computation. |
| // |
| // Precondition: entry_computation_ is not nullptr. |
| const Shape& result_shape() const { |
| CHECK_NE(nullptr, entry_computation_); |
| return entry_computation()->root_instruction()->shape(); |
| } |
| |
| // Creates the ComputationLayout which describes the current status of the HLO |
| // module entry computation. |
| ComputationLayout compute_computation_layout() const { |
| return ComputationLayout(entry_computation()->ComputeProgramShape(), |
| /*ignore_layouts=*/false); |
| } |
| |
| ComputationLayout* mutable_entry_computation_layout() { |
| return config_.mutable_entry_computation_layout(); |
| } |
| |
| const ComputationLayout& entry_computation_layout() const { |
| return config_.entry_computation_layout(); |
| } |
| |
| // Generates a hash value of an HLO module. Hash considers |
| // information on opcode, shape, operands, and typically a root instruction. |
| // This function returns the same hash value for equivalent HLO modules, |
| // with respect to HloInstruction::Identical() method. |
| uint64 Hash() const; |
| |
| // Gets the computations in this module. |
| // |
| // Returns a view of HloComputation*s, so you can iterate over this in the |
| // natural way: |
| // |
| // for (HloComputation* c : module->computations()) { ... } |
| // |
| tensorflow::gtl::iterator_range<UnwrappingIterator< |
| std::vector<std::unique_ptr<HloComputation>>::const_iterator>> |
| computations() const { |
| return {MakeUnwrappingIterator(computations_.begin()), |
| MakeUnwrappingIterator(computations_.end())}; |
| } |
| tensorflow::gtl::iterator_range<UnwrappingIterator< |
| std::vector<std::unique_ptr<HloComputation>>::iterator>> |
| computations() { |
| return {MakeUnwrappingIterator(computations_.begin()), |
| MakeUnwrappingIterator(computations_.end())}; |
| } |
| |
| // Returns the computation in this module that has the name `name`. Returns |
| // null if there is no such computation. |
| HloComputation* GetComputationWithName(absl::string_view name); |
| |
| // Gets the number of computations in this module. |
| int64 computation_count() const { return computations_.size(); } |
| |
| // Returns the mutable computation for the given index. |
| HloComputation* mutable_computation(int64 idx) { |
| CHECK(idx >= 0 && idx < computations_.size()); |
| return computations_[idx].get(); |
| } |
| |
| // Gets the number of instructions in this module. |
| int64 instruction_count() const; |
| |
| // Compute and return a post order of all computations in the module. The sort |
| // is defined like so: if computation A has an instruction which calls |
| // computation B, then A will appear after B in the sort. |
| std::vector<HloComputation*> MakeComputationPostOrder() const; |
| |
| // Gets the computations in this module which aren't for fusion nodes. |
| // |
| // Postcondition: All computations in the returned list have |
| // !IsFusionComputation(). |
| // |
| // Note: Callers can and do rely on the return value here being a *snapshot* |
| // of the module's non-fusion computations -- that is, it's OK to add or |
| // remove computations from a module while iterating over |
| // MakeNonfusionComputations(). |
| std::vector<HloComputation*> MakeNonfusionComputations() const; |
| |
| // Same as MakeNonfusionComputations() but sorting the computations by names. |
| std::vector<HloComputation*> MakeNonfusionComputationsSorted() const; |
| |
| const HloModuleConfig& config() const { return config_; } |
| void set_config(const HloModuleConfig& config) { config_ = config; } |
| |
| // Return a string representation of the module. |
| // |
| // (We express the default options using an overload rather than a default |
| // param because gdb ignores default params, but does resolve overloads.) |
| string ToString() const { return ToString(HloPrintOptions()); } |
| string ToString(const HloPrintOptions& options) const; |
| |
| // Convert an HloModule to or from a proto. |
| HloModuleProto ToProto() const; |
| static StatusOr<std::unique_ptr<HloModule>> CreateFromProto( |
| const HloModuleProto& proto, const HloModuleConfig& module_config, |
| bool prohibit_empty_literal = true); |
| |
| // Creates and returns an HloModuleConfig with an appropriate program shape |
| // for the HLO module in the given proto. |
| static StatusOr<HloModuleConfig> CreateModuleConfigFromProto( |
| const HloModuleProto& module, const DebugOptions& debug_options, |
| const ExecutionOptions* execution_options = nullptr); |
| |
| // Creates and returns an HloModuleConfig with an appropriate program shape |
| // for the HLO module in the given proto. |
| static StatusOr<HloModuleConfig> CreateModuleConfigFromShape( |
| const ProgramShape& program_shape, const DebugOptions& debug_options, |
| const ExecutionOptions* execution_options = nullptr); |
| |
| // Outlines the given expression from the given computation. |
| // instructions_to_outline contains the instructions that form the expression. |
| // |
| // Precondition: instructions in instructions_to_outline are in topological |
| // order (root of outlined instructions last). TODO(jingyue): takes a set of |
| // instructions and topologically sorts them. |
| HloInstruction* OutlineExpressionFromComputation( |
| absl::Span<HloInstruction* const> instructions_to_outline, |
| const string& outlined_computation_name, HloComputation* computation); |
| |
| // Returns a randomly generated uint64. |
| uint64 RandomNew64() const; |
| |
| // Returns the NameUniquer for uniquing instruction names in this module. |
| NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } |
| |
| // Assign a new unique dense id for an instruction |
| int NewUniqueInstructionId() { |
| int result = next_unique_id_; |
| next_unique_id_++; |
| return result; |
| } |
| |
| // input_output_alias_config indicates the list of aliased buffers that are |
| // expected from the module. |
| HloInputOutputAliasConfig& input_output_alias_config() { |
| return input_output_alias_config_; |
| } |
| const HloInputOutputAliasConfig& input_output_alias_config() const { |
| return input_output_alias_config_; |
| } |
| |
| // DynamicParameterBinding holds the list of bindings that indicates which |
| // parameter dimensions are dynamic and which parameters represent their |
| // runtime value. |
| DynamicParameterBinding& dynamic_parameter_binding() { |
| return dynamic_parameter_binding_; |
| } |
| const DynamicParameterBinding& dynamic_parameter_binding() const { |
| return dynamic_parameter_binding_; |
| } |
| |
| // Returns an id that is unique to this module across all modules created over |
| // the lifetime of this process. |
| int unique_id() const { return unique_id_; } |
| |
| // Sets the schedule of the module to the given schedule. |
| Status set_schedule(HloSchedule schedule); |
| |
| // Clears the schedule of the module. |
| void clear_schedule() { schedule_.reset(); } |
| |
| // Returns true if the module has a schedule set. |
| bool has_schedule() const { return schedule_.has_value(); } |
| |
| // Returns the schedue of the module. CHECK fails if no schedule is set. |
| const HloSchedule& schedule() const { return *schedule_; } |
| HloSchedule& schedule() { return *schedule_; } |
| |
| HloComputation* AddComputationAndUnifyNamesAndIds( |
| std::unique_ptr<HloComputation> computation, bool is_entry) { |
| computation->ClearUniqueIdInternal(); |
| for (auto* instruction : computation->instructions()) { |
| instruction->ClearUniqueIdInternal(); |
| } |
| return AddComputationInternal(std::move(computation), is_entry, |
| /*uniquify_identifiers=*/true); |
| } |
| |
| Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const; |
| |
| // Checks if this config has a list of entry parameters' HLO shardings for |
| // SPMD. |
| bool has_spmd_parameters_shardings() const { |
| return spmd_parameters_shardings_.has_value(); |
| } |
| |
| // Getter and setter for the list of entry parameters' HLO shardings for SPMD. |
| const std::vector<HloSharding>& spmd_parameters_shardings() const { |
| CHECK(spmd_parameters_shardings_.has_value()); |
| return *spmd_parameters_shardings_; |
| } |
| void set_spmd_parameters_shardings( |
| const std::vector<HloSharding>& shardings) { |
| spmd_parameters_shardings_ = shardings; |
| } |
| |
| // Checks if this config has the entry computation output's HLO sharding for |
| // SPMD. |
| bool has_spmd_output_sharding() const { |
| return spmd_output_sharding_.has_value(); |
| } |
| |
| // Getter and setter for the entry computation output's HLO shardings for |
| // SPMD. |
| const HloSharding& spmd_output_sharding() const { |
| CHECK(spmd_output_sharding_.has_value()); |
| return *spmd_output_sharding_; |
| } |
| void set_spmd_output_sharding(const HloSharding& sharding) { |
| spmd_output_sharding_ = sharding; |
| } |
| |
| private: |
| HloComputation* AddComputationInternal( |
| std::unique_ptr<HloComputation> computation, bool is_entry, |
| bool uniquify_identifiers); |
| |
| // Same as MakeComputationPostOrder() but sorting the computations by their |
| // contents. |
| std::vector<HloComputation*> MakeComputationSortedByContent() const; |
| |
| string name_; |
| HloModuleConfig config_; |
| HloComputation* entry_computation_ = nullptr; |
| std::vector<std::unique_ptr<HloComputation>> computations_; |
| |
| // Random number generator engine to use when generating random numbers per |
| // HloModule compilation. |
| // TODO(b/25995601): Replace with better seed setting or dev/random for |
| // where we don't need deterministic execution. |
| mutable std::mt19937_64 rng_{42}; |
| mutable tensorflow::mutex rng_mutex_; |
| |
| // Unique name generator for computation and instruction names, which are |
| // unique per module. |
| NameUniquer computation_name_uniquer_{/*separator=*/"."}; |
| NameUniquer instruction_name_uniquer_{/*separator=*/"."}; |
| int next_unique_id_ = 0; |
| |
| // Used to keep track of the next unique module id that should be assigned. |
| static std::atomic<int> next_unique_module_id_; |
| // A unique id to label modules with. |
| int unique_id_; |
| |
| // The HloSchedule of the module. The schedule if it exists contains a |
| // sequential order of instructions for each non-fusion computation in the |
| // module. |
| absl::optional<HloSchedule> schedule_; |
| |
| // alias_config indicates the alias information of input/output buffers that |
| // are expected from the module. |
| HloInputOutputAliasConfig input_output_alias_config_; |
| |
| // Bindings for dynamic parameter mapping. |
| DynamicParameterBinding dynamic_parameter_binding_; |
| |
| // The HLO shardings of the entry computation's parameters for |
| // SPMD-partitioned programs. |
| absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_; |
| |
| // The HLO sharding of the entry computation's output (root) for |
| // SPMD-partitioned programs. |
| absl::optional<HloSharding> spmd_output_sharding_; |
| }; |
| |
| } // namespace xla |
| |
| #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ |