/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
| |
| #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ |
| #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ |
| |
| #include <atomic> |
| #include <list> |
| #include <memory> |
| #include <random> |
| #include <string> |
| #include <unordered_map> |
| #include <vector> |
| |
| #include "absl/strings/string_view.h" |
| #include "absl/types/optional.h" |
| #include "absl/types/span.h" |
| #include "tensorflow/compiler/xla/iterator_util.h" |
| #include "tensorflow/compiler/xla/service/hlo.pb.h" |
| #include "tensorflow/compiler/xla/service/hlo_clone_context.h" |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_module_config.h" |
| #include "tensorflow/compiler/xla/service/hlo_schedule.h" |
| #include "tensorflow/compiler/xla/service/name_uniquer.h" |
| #include "tensorflow/compiler/xla/types.h" |
| #include "tensorflow/core/lib/gtl/iterator_range.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/mutex.h" |
| |
| namespace xla { |
| |
// Describes a compilation unit at the HLO level.
//
// HloModule is the top-level unit in the HLO IR. It corresponds to a whole
// "program". Running a module, from beginning to end, is the only way to run
// an XLA program.
//
// A module contains one "entry computation"; this HloComputation is like
// main() in a C program. The result of running the module is the result of
// running this computation.
//
// A module also contains some number of "nested computations". Each nested
// computation is attached to an HloInstruction within some other computation.
// The meaning of the nested computation depends on the instruction it's
// attached to.
class HloModule {
 public:
  // Constructor without a versioned computation handle. This constructor
  // should only be used for HloModules used outside of the XLA service (e.g.
  // tests). The versioned handle is used by the service in the compilation
  // cache. A default configuration is created for this module.
  explicit HloModule(const string& name, const HloModuleConfig& config);
  virtual ~HloModule() {}

  // Adds an entry computation to the module. A module can only have one entry
  // computation. Returns a pointer to the newly added computation.
  HloComputation* AddEntryComputation(
      std::unique_ptr<HloComputation> computation);

  // Adds an embedded (non-entry) computation to the module. Returns a pointer
  // to the newly added computation.
  HloComputation* AddEmbeddedComputation(
      std::unique_ptr<HloComputation> computation);

  // Removes an embedded computation from the module. Returns an error status
  // on failure.
  Status RemoveEmbeddedComputation(HloComputation* to_remove);

  // Replaces all uses of computations that are keys of 'replacements' with
  // the corresponding values in 'replacements'. Replaces the entry
  // computation, if applicable.
  //
  // This function iterates over all instructions in the module to find
  // computations to replace. We could speed it up by keeping track of users of
  // computations.
  void ReplaceComputations(
      const std::unordered_map<HloComputation*, HloComputation*>& replacements);

  // Returns the name of the module.
  const string& name() const { return name_; }
  // Sets the name of the module (sink parameter; moved into place).
  void set_name(string name) { name_ = std::move(name); }

  // Returns a deep copy of this module including all computations.
  std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;

  // Performs a deep clone of the computation, by recursively cloning all
  // the called computations as well. If the clone context is specified, it
  // will be populated with the cloned object mappings.
  HloComputation* DeepCloneComputation(HloComputation* computation,
                                      HloCloneContext* context = nullptr);

  // Returns a pointer to the entry computation of the module. CHECK-fails if
  // no entry computation has been added yet.
  const HloComputation* entry_computation() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation_;
  }
  HloComputation* entry_computation() {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation_;
  }

  // Creates the ComputationLayout which describes the current status of the
  // HLO module entry computation. The layout is derived from the entry
  // computation's current program shape (layouts are respected, not ignored).
  ComputationLayout compute_computation_layout() const {
    return ComputationLayout(entry_computation()->ComputeProgramShape(),
                             /*ignore_layouts=*/false);
  }

  // Returns a mutable pointer to the entry computation layout stored in the
  // module's config.
  ComputationLayout* mutable_entry_computation_layout() {
    return config_.mutable_entry_computation_layout();
  }

  // Returns the entry computation layout stored in the module's config.
  const ComputationLayout& entry_computation_layout() const {
    return config_.entry_computation_layout();
  }

  // Gets the computations in this module.
  //
  // Returns a view of HloComputation*s, so you can iterate over this in the
  // natural way:
  //
  //   for (HloComputation* c : module->computations()) { ... }
  //
  // NOTE(review): the returned range views computations_ directly — adding or
  // removing computations while iterating would invalidate the iterators;
  // use MakeNonfusionComputations() for a snapshot.
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::const_iterator>>
  computations() const {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::iterator>>
  computations() {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }

  // Returns the computation in this module that has the name `name`. Returns
  // null if there is no such computation.
  HloComputation* GetComputationWithName(absl::string_view name);

  // Gets the number of computations in this module.
  int64 computation_count() const { return computations_.size(); }

  // Gets the number of instructions in this module.
  int64 instruction_count() const;

  // Compute and return a post order of all computations in the module. The
  // sort is defined like so: if computation A has an instruction which calls
  // computation B, then A will appear after B in the sort.
  std::vector<HloComputation*> MakeComputationPostOrder() const;

  // Gets the computations in this module which aren't for fusion nodes.
  //
  // Postcondition: All computations in the returned list have
  // !IsFusionComputation().
  //
  // Note: Callers can and do rely on the return value here being a *snapshot*
  // of the module's non-fusion computations -- that is, it's OK to add or
  // remove computations from a module while iterating over
  // MakeNonfusionComputations().
  std::vector<HloComputation*> MakeNonfusionComputations() const;

  // Returns the configuration this module was constructed with.
  const HloModuleConfig& config() const { return config_; }

  // Return a string representation of the module.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Convert an HloModule to or from a proto.
  HloModuleProto ToProto() const;
  static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
      const HloModuleProto& proto, const HloModuleConfig& module_config);

  // Creates and returns an HloModuleConfig with an appropriate program shape
  // for the HLO module in the given proto.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
      const HloModuleProto& module, const DebugOptions& debug_options);

  // Outlines the given expression from the given computation.
  // instructions_to_outline contains the instructions that form the
  // expression.
  //
  // Precondition: instructions in instructions_to_outline are in topological
  // order (root of outlined instructions last). TODO(jingyue): takes a set of
  // instructions and topologically sorts them.
  HloInstruction* OutlineExpressionFromComputation(
      absl::Span<HloInstruction* const> instructions_to_outline,
      const string& outlined_computation_name, HloComputation* computation);

  // Returns a randomly generated uint64. Drawn from the per-module rng_
  // engine; rng_mutex_ suggests this is safe to call concurrently — confirm
  // against the .cc implementation.
  uint64 RandomNew64() const;

  // Returns the NameUniquer for uniquing instruction names in this module.
  NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; }

  // Assigns a new unique dense id for an instruction. Ids are handed out
  // sequentially starting from 0.
  // NOTE(review): next_unique_id_ is a plain int with no synchronization, so
  // this is presumably not thread-safe — confirm callers serialize access.
  int NewUniqueInstructionId() {
    int result = next_unique_id_;
    next_unique_id_++;
    return result;
  }

  // Returns the number of unique instruction ids given out. All ids up to
  // this point are guaranteed to be in the range
  // [0..NumUniqueInstructionIds()).
  int NumUniqueInstructionIds() const { return next_unique_id_; }

  // Returns an id that is unique to this module across all modules created
  // over the lifetime of this process.
  int unique_id() const { return unique_id_; }

  // Returns a non-const version of the passed-in const HloInstruction*. This
  // is safe on the argument that if you have a non-const module, then you can
  // access all instructions in the module as non-const.
  //
  // Returns an error if the passed-in instruction is not from this module,
  // except that it is allowed to pass in a null pointer.
  //
  // TODO(b/78350259): Eliminate const laundering. The argument above is not
  // reliable since at any time someone could add or discover a way for a
  // non-const module to transitively contain a const HloInstruction. The
  // reliable way to do this would be to create a const laundering map from a
  // module, mapping each encountered HloInstruction to its non-const version
  // and then look up each instruction in need of laundering in that map, but
  // this is much more expensive and complicated. This returns a Status
  // instead of doing a CHECK-failure in part to make it strongly apparent
  // that this is something that can fail.
  StatusOr<HloInstruction*> LaunderConstInstructionFromModule(
      const HloInstruction* hlo);

  // Sets the schedule of the module to the given schedule. Returns an error
  // status on failure.
  Status set_schedule(HloSchedule schedule);

  // Clears the schedule of the module.
  void clear_schedule() { schedule_.reset(); }

  // Returns true if the module has a schedule set.
  bool has_schedule() const { return schedule_.has_value(); }

  // Returns the schedule of the module. Dereferences schedule_ without a
  // check, so behavior is undefined (CHECK via absl::optional in debug) if no
  // schedule is set; callers should test has_schedule() first.
  const HloSchedule& schedule() const { return *schedule_; }
  HloSchedule& schedule() { return *schedule_; }

 private:
  // Shared implementation behind AddEntryComputation/AddEmbeddedComputation.
  HloComputation* AddComputationInternal(
      std::unique_ptr<HloComputation> computation, bool is_entry,
      bool uniquify_identifiers);

  string name_;
  HloModuleConfig config_;
  // Non-owning pointer into computations_; null until an entry is added.
  HloComputation* entry_computation_ = nullptr;
  // Owns all computations in the module, including the entry computation.
  std::vector<std::unique_ptr<HloComputation>> computations_;

  // Random number generator engine to use when generating random numbers per
  // HloModule compilation.
  // TODO(b/25995601): Replace with better seed setting or dev/random for
  // where we don't need deterministic execution.
  mutable std::mt19937_64 rng_{42};
  mutable tensorflow::mutex rng_mutex_;

  // Unique name generator for computation and instruction names, which are
  // unique per module.
  NameUniquer computation_name_uniquer_{/*separator=*/"."};
  NameUniquer instruction_name_uniquer_{/*separator=*/"."};
  // Next dense instruction id to hand out via NewUniqueInstructionId().
  int next_unique_id_ = 0;

  // Used to keep track of the next unique module id that should be assigned.
  static std::atomic<int> next_unique_module_id_;
  // A unique id to label modules with.
  int unique_id_;

  // The HloSchedule of the module. The schedule if it exists contains a
  // sequential order of instructions for each non-fusion computation in the
  // module.
  absl::optional<HloSchedule> schedule_;
};
| |
| } // namespace xla |
| |
| #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ |