blob: 40881b4b91c019e28a6bfcf058a81ff2e002ccae [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
namespace xla {
namespace spmd {
// Tunable settings that control how the SPMD partitioner rewrites a module.
struct SpmdPartitionerOptions {
  // When true, every convolution exchanges its halo on the LHS operand. When
  // false, backprop-filter convolutions exchange the halo on the RHS instead.
  bool conv_halo_exchange_always_on_lhs{true};

  // How many instructions to list when reporting the instructions with the
  // highest memory profile.
  int64 report_instruction_count{5};

  // Minimum size, in MiB, an einsum operand must have before the windowed
  // (HLO-loop based) einsum implementation is considered.
  int64 threshold_for_windowed_einsum_mib{256};

  // Whether partitioning may change the entry computation's signature.
  bool allow_module_signature_change{false};
};
// Class to wrap the computation builder to capture information during SPMD
// transformation: for each original (old) instruction being visited, it keeps
// the list of new instructions created on its behalf.
class SpmdBuilder : public HloComputation::Builder {
 public:
  // `name` is forwarded to HloComputation::Builder; `hlo` becomes the initial
  // visiting instruction.
  SpmdBuilder(const std::string& name, HloInstruction* hlo)
      : HloComputation::Builder(name), visiting_hlo_(hlo) {}

  // Adds `instruction` to the computation under construction. Defined out of
  // line; presumably it also records the instruction under visiting_hlo() in
  // instructions_ — confirm against the .cc.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);

  // Returns the new instructions recorded for the original instruction `hlo`.
  // Uses map::at, so `hlo` must already have an entry; a missing entry is a
  // failure rather than a silent default insertion.
  const std::vector<HloInstruction*>& derived_instructions(
      HloInstruction* hlo) const {
    return instructions_.at(hlo);
  }

  // Sets / returns the original instruction currently being partitioned.
  void set_visiting_hlo(HloInstruction* hlo) { visiting_hlo_ = hlo; }
  HloInstruction* visiting_hlo() const { return visiting_hlo_; }

 private:
  // Currently visiting instruction.
  HloInstruction* visiting_hlo_;
  // Map from the currently visiting (old) instruction to new instructions
  // created during SPMD partitioning.
  HloInstructionMap<std::vector<HloInstruction*>> instructions_;
};
// Bundle of callbacks that the partitioner invokes to materialize
// cross-partition collective HLOs. Every callback receives an SpmdBuilder
// (presumably the one the new instruction is emitted into) and returns the
// resulting HloInstruction.
struct SPMDCollectiveOpsCreator {
  // Builds an HLO that yields the partition ID.
  std::function<HloInstruction*(SpmdBuilder*)> create_partition_id;

  // Builds a cross-partition all-reduce of `operand` with the `reduction`
  // computation on channel `channel_id`.
  std::function<HloInstruction*(SpmdBuilder*, HloInstruction* operand,
                                HloComputation* reduction, int64 channel_id)>
      create_cross_partition_all_reduce;

  // Builds a cross-partition collective-permute of `operand` following the
  // partition pairs listed in `src_dst_pairs`.
  // NOTE(review): `src_dst_pairs` is a non-const reference; confirm whether
  // implementations are meant to modify it before tightening it to const.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand,
      std::vector<std::pair<int64, int64>>& src_dst_pairs,
      int64 next_channel_id)>
      create_cross_partition_collective_permute;

  // Builds a cross-partition all-to-all over `operands` within
  // `replica_groups`, with an optional `split_dimension`.
  std::function<HloInstruction*(
      SpmdBuilder*, absl::Span<HloInstruction* const> operands,
      const std::vector<ReplicaGroup>& replica_groups, int64 channel_id,
      absl::optional<int64> split_dimension)>
      create_cross_partition_all_to_all;

  // Builds a cross-partition all-gather of `operand` producing `ag_shape`
  // along `all_gather_dimension` within `partition_subgroups`. This callback
  // is optional: when it is left as nullptr the partitioner falls back to
  // all-reduce.
  std::function<HloInstruction*(
      SpmdBuilder*, HloInstruction* operand, const Shape& ag_shape,
      const std::vector<std::vector<int64>>& partition_subgroups,
      int64 channel_id, int64 all_gather_dimension)>
      create_cross_partition_all_gather;
};
// Returns an SPMDCollectiveOpsCreator populated with the default callbacks for
// a program with `num_partitions` partitions and `num_replicas` replicas.
SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(
    int64 num_partitions, int64 num_replicas);
// Collects per-instruction log entries during SPMD partitioning and renders
// memory-usage reports.
class SpmdLogger {
 public:
  // `report_instruction_count` bounds how many instructions a report lists.
  explicit SpmdLogger(int64 report_instruction_count)
      : report_instruction_count_(report_instruction_count) {}

  // Memory-usage report for `module` before partitioning, listing at most
  // `report_instruction_count` instructions.
  static std::string ReportBeforePartition(const HloModule& module,
                                           int64 report_instruction_count);
  // Memory-usage report for `module` after partitioning, listing at most
  // `report_instruction_count` instructions.
  static std::string ReportAfterPartition(const HloModule& module,
                                          int64 report_instruction_count);

  // Records a log entry for `group`, the instructions created while
  // transforming the original instruction `hlo`.
  void RegisterLogEntry(HloInstruction* hlo,
                        const std::vector<HloInstruction*>& group);

  // Renders the registered entries into a report string.
  std::string MakeReport();

 private:
  // Memory-usage report over the instructions of `module` accepted by
  // `filter` (presumably the shared implementation behind the
  // Report*Partition helpers — confirm in the .cc).
  template <typename F>
  static std::string ReportMemoryUsage(const HloModule& module, const F& filter,
                                       int64 report_instruction_count);

  // One logging message per original HLO instruction; the integer in each
  // pair is the size of the HBM used.
  std::vector<std::pair<int64, std::string>> entries_;
  int64 report_instruction_count_;
};
class SpmdPartitioningVisitor;

// HLO module pass ("spmd-partitioning") that transforms sharded computations
// into SPMD computations for `num_partitions` partitions.
class SpmdPartitioner : public HloModulePass {
 public:
  // Defined out of line; collective ops creation is not supplied by the
  // caller here (presumably the default creator is used — confirm in .cc).
  SpmdPartitioner(int64 num_partitions, int64 num_replicas,
                  SpmdPartitionerOptions options);
  // Uses `collective_ops_creator` to emit cross-partition collectives.
  SpmdPartitioner(int64 num_partitions, int64 num_replicas,
                  SpmdPartitionerOptions options,
                  SPMDCollectiveOpsCreator collective_ops_creator)
      : num_partitions_(num_partitions),
        num_replicas_(num_replicas),
        options_(std::move(options)),
        collective_ops_creator_(std::move(collective_ops_creator)) {}

  absl::string_view name() const override { return "spmd-partitioning"; }
  StatusOr<bool> Run(HloModule* module) override;

  // Transforms the given computation with SPMD instructions, replacing it with
  // a new computation.
  StatusOr<bool> PartitionComputation(HloComputation* computation,
                                      const HloSharding& root_sharding,
                                      int64* next_channel_id,
                                      SpmdLogger* logger);

  // Creates all-gather based on HloSharding. Can be overridden to customize.
  // The default uses a single all-gather even if there are multiple sharded
  // dimensions, and adds potential reshapes and transposes to achieve that.
  // If it returns nullptr, the partitioner will fall back to all-reduce.
  virtual HloInstruction* AllGatherShards(SpmdBuilder* b,
                                          HloInstruction* operand,
                                          const HloSharding& sharding,
                                          int64 channel_id);

 protected:
  // Factory for the visitor that partitions `computation`. Virtual so that
  // subclasses can supply a specialized SpmdPartitioningVisitor.
  virtual std::unique_ptr<SpmdPartitioningVisitor> CreateVisitor(
      HloComputation* computation, int64 num_partitions, int64 num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options);

  // Verify that the sharding of instructions in the module are valid, and also
  // fill in missing sharding information.
  Status PreprocessSharding(HloModule* module);

  const int64 num_partitions_;
  const int64 num_replicas_;

  SpmdPartitionerOptions options_;
  SPMDCollectiveOpsCreator collective_ops_creator_;
};
// Class describes partition state of the data represented by an HLO created
// during SPMD partitioning pass.
//
// Data on some devices may include padding region, if the base (full) shape
// could not be evenly partitioned.
class PartitionedHlo {
 public:
  // Return value for ReshardAsWindowedInput which describes the resharded HLO,
  // the window for the user on the shard, and if necessary, the dynamic slice
  // offsets to be applied to the output of the op being sharded.
  struct WindowedInputShardReturnValue {
    HloInstruction* sharded_input;
    Window shard_window;
    absl::optional<std::vector<HloInstruction*>> dynamic_slice_index_on_output;
  };

  // A cache for resharding each partitioned HLO.
  struct ReshardCache {
    struct PerHloCache {
      // (target sharding, result) pairs; presumably filled by Reshard().
      std::vector<std::pair<HloSharding, PartitionedHlo>> reshard_cache;
      // (target sharding, window, result) tuples; presumably filled by
      // ReshardAsWindowedInput().
      std::vector<
          std::tuple<HloSharding, Window, WindowedInputShardReturnValue>>
          window_reshard_cache;
    };
    std::unordered_map<HloInstruction*, PerHloCache> per_hlo_cache;
  };

  // Shared context handed to every PartitionedHlo created by one visitor.
  struct PartitioningState {
    SpmdBuilder* b;
    HloModule* module;
    int64 num_replicas;
    HloInstruction* partition_id;
    SPMDCollectiveOpsCreator collective_ops_creator;
    int64* next_channel_id;
    ReshardCache* reshard_cache;
    SpmdPartitioner* partitioner;
  };

  // Wraps the SPMD instruction `hlo`, whose unpartitioned shape is
  // `base_shape`. CHECK-fails if `hlo` has no sharding. If `hlo` is
  // tuple-shaped but carries a non-tuple sharding, the sharding is expanded
  // to a tuple sharding because Reshard() relies on that invariant.
  PartitionedHlo(HloInstruction* hlo, Shape base_shape, PartitioningState state)
      : hlo_(hlo),
        // Both by-value parameters are sinks: move instead of copying.
        base_shape_(std::move(base_shape)),
        state_(std::move(state)) {
    CHECK(hlo->has_sharding())
        << "PartitionedHlo is missing sharding:" << hlo->ToString();
    // If the tuple shape instruction does not have a tuple sharding, reassign
    // to use the tuple sharding. Reshard() implementation assumes this.
    if (hlo_->shape().IsTuple() && !hlo_->sharding().IsTuple()) {
      hlo_->set_sharding(
          hlo_->sharding().GetTupleSharding(hlo_->shape()).ValueOrDie());
    }
  }

  // Reshards the current SPMD instruction to a new sharding. Could only modify
  // the reshard cache.
  PartitionedHlo Reshard(const HloSharding& target);

  // Pads the garbage area of the output with the provided value. Normally,
  // unevenly partitioned dimensions are padded on the right, but this function
  // allows specifying left-padded dimensions, which can be used during the
  // handling of kReverse, etc.
  PartitionedHlo PadWithValue(
      HloInstruction* pad_value,
      absl::Span<const int64> left_padded_dims = {}) const;

  // Returns the SPMD instruction.
  HloInstruction* hlo() const { return hlo_; }

  // Returns the sharding of the SPMD instruction.
  const HloSharding& sharding() const { return hlo_->sharding(); }

  // Original full shape of the data.
  const Shape& base_shape() const { return base_shape_; }

  // Returns the next channel id and advances the shared counter. `const`
  // only because the counter lives behind a pointer in state_.
  int64 NewChannel() const { return (*state_.next_channel_id)++; }

  // Reshards the HLO to a usable partitioned input for a windowed user. Could
  // only modify the reshard cache.
  absl::optional<WindowedInputShardReturnValue> ReshardAsWindowedInput(
      const Window& window, const HloSharding& target,
      HloInstruction* pad_value, bool mask_invalid_region = true);

  const PartitioningState& state() const { return state_; }

 private:
  // Same as Reshard except that it does not explicitly modify the reshard
  // cache, although it would indirectly modify by calling Replicate().
  PartitionedHlo ReshardNoCache(const HloSharding& target);

  // Helper function to replicate the data on all devices. Could only modify
  // the reshard cache.
  PartitionedHlo Replicate();

  // Helper function to broadcast data from a single device to all devices.
  PartitionedHlo Broadcast() const;

  // Helper function to reshard the tensor using AllToAll (instead of the
  // default of Replicate followed by Slice).
  PartitionedHlo ReshardWithAllToAll(const HloSharding& target,
                                     int64 source_dim, int64 target_dim) const;

  // Helper function to reshard the tensor using CollectivePermute.
  PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;

  // SPMD instruction.
  HloInstruction* hlo_;

  // The original shape of the data before SPMD transformation is applied.
  Shape base_shape_;

  PartitioningState state_;
};
// Groups the logical dimensions of a dot-general (batch, contracting and the
// per-operand non-contracting dimensions) with their physical dimension
// numbers in the LHS, RHS and output.
struct DotGeneralDimsMapping {
  // The dimension numbers for the operands and output corresponding to a
  // logical dimension (e.g., batch, contracting, non-contracting). If an
  // operand or the output doesn't have the logical dimension, it is set to
  // -1. Members default to -1 so a default-constructed mapping never holds
  // indeterminate values; the struct remains an aggregate (C++14).
  struct DimsMapping {
    int64 lhs = -1;
    int64 rhs = -1;
    int64 output = -1;
  };
  std::vector<DimsMapping> batch_dims;
  std::vector<DimsMapping> contracting_dims;
  std::vector<DimsMapping> lhs_non_contracting_dims;
  std::vector<DimsMapping> rhs_non_contracting_dims;
};
// DFS visitor that partitions one computation: for each original instruction
// it emits SPMD instructions into b_ and records the resulting PartitionedHlo.
class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
 public:
  SpmdPartitioningVisitor(
      HloComputation* computation, int64 num_partitions, int64 num_replicas,
      const SPMDCollectiveOpsCreator& collective_ops_creator,
      int64* next_channel_id, SpmdLogger* logger,
      SpmdPartitionerOptions options, SpmdPartitioner* partitioner);

  Status DefaultAction(HloInstruction* hlo) override;
  Status HandleAllReduce(HloInstruction* hlo) override;
  Status HandleBroadcast(HloInstruction* hlo) override;
  Status HandleConstant(HloInstruction* hlo) override;
  Status HandleCustomCall(HloInstruction* hlo) override;
  Status HandleDot(HloInstruction* hlo) override;
  Status HandleDynamicSlice(HloInstruction* hlo) override;
  Status HandleDynamicUpdateSlice(HloInstruction* hlo) override;
  Status HandleGather(HloInstruction* hlo) override;
  Status HandleGetTupleElement(HloInstruction* hlo) override;
  Status HandleInfeed(HloInstruction* hlo) override;
  Status HandleOutfeed(HloInstruction* hlo) override;
  Status HandlePad(HloInstruction* hlo) override;
  Status HandleParameter(HloInstruction* hlo) override;
  Status HandleReduce(HloInstruction* hlo) override;
  Status HandleReverse(HloInstruction* hlo) override;
  Status HandleWhile(HloInstruction* hlo) override;
  Status HandleConditional(HloInstruction* hlo) override;
  Status HandleReduceWindow(HloInstruction* hlo) override;
  Status HandleSelectAndScatter(HloInstruction* hlo) override;
  Status HandleTuple(HloInstruction* hlo) override;
  Status HandleRng(HloInstruction* hlo) override;
  Status HandleConvolution(HloInstruction* hlo) override;
  Status HandleConcatenate(HloInstruction* hlo) override;
  Status HandleScatter(HloInstruction* hlo) override;
  Status HandleSlice(HloInstruction* hlo) override;
  Status HandleSort(HloInstruction* hlo) override;
  Status HandleTranspose(HloInstruction* hlo) override;
  Status HandleReshape(HloInstruction* hlo) override;
  Status HandleIota(HloInstruction* hlo) override;
  Status HandlePartitionId(HloInstruction* hlo) override;

  // Handles convolution where both LHS and RHS operands are tiled.
  Status HandleConvolutionTiledLhsAndRhs(HloInstruction* hlo);

  // Implementation of dot partitioning given DotGeneralDimsMapping.
  Status HandleDotHelper(
      HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping,
      const std::function<StatusOr<HloInstruction*>(
          HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot);

  // Common handle for elementwise HLOs.
  Status HandleElementwise(HloInstruction* hlo);

  // Common handle for HLOs that run on a single device.
  Status HandleSingleDevice(const HloInstruction* hlo);

  // Returns the PartitionedHlo that corresponds to the original hlo.
  // CHECK-fails if none has been registered. Uses a single map lookup.
  PartitionedHlo& GetPartitionedHlo(const HloInstruction* hlo) {
    auto it = partitioned_instructions_.find(hlo);
    CHECK(it != partitioned_instructions_.end())
        << "No PartitionedHlo registered for " << hlo->name();
    return it->second;
  }

  // Sets the PartitionedHlo for the original hlo and marks the computation as
  // changed. CHECK-fails if a PartitionedHlo was already set for `hlo`.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const PartitionedHlo& partitioned_hlo) {
    // emplace() reports whether the key was new, avoiding a separate count().
    const bool inserted =
        partitioned_instructions_.emplace(hlo, partitioned_hlo).second;
    CHECK(inserted) << "PartitionedHlo already set for " << hlo->name();
    changed_ = true;
  }

  // Convenient wrapper that creates PartitionedHlo from the result of the func
  // and maps it to the given original hlo. The new instruction inherits the
  // sharding and metadata of `hlo`.
  void SetPartitionedHlo(const HloInstruction* hlo,
                         const std::function<HloInstruction*()>& func) {
    HloInstruction* new_hlo = func();
    new_hlo->set_sharding(hlo->sharding());
    new_hlo->set_metadata(hlo->metadata());
    // The PartitionedHlo overload above already sets changed_.
    SetPartitionedHlo(
        hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
  }

  // Returns the next channel id and advances the shared counter.
  int64 NewChannel() { return (*next_channel_id_)++; }

  // Snapshot of this visitor's shared context for constructing PartitionedHlo.
  PartitionedHlo::PartitioningState MakePartitioningState() {
    PartitionedHlo::PartitioningState state;
    state.b = &b_;
    state.module = module_;
    state.num_replicas = num_replicas_;
    state.partition_id = partition_id_;
    state.collective_ops_creator = collective_ops_creator_;
    state.next_channel_id = next_channel_id_;
    state.reshard_cache = &reshard_cache_;
    state.partitioner = partitioner_;
    return state;
  }

  SpmdBuilder* builder() { return &b_; }

  StatusOr<bool> DoPartition(HloComputation* computation,
                             const HloSharding& root_sharding);

 private:
  Status Preprocess(HloInstruction* hlo) override;
  Status Postprocess(HloInstruction* hlo) override;

  // Performs code motion for windowed dot-general loops in
  // windowed_dot_general_loops_. Invoked after the visitor finishes traversing
  // the graph.
  Status DoCodeMotionForWindowedDotGeneralLoops(HloComputation* computation);

  bool changed_;
  HloModule* module_;
  int64 num_partitions_;
  int64 num_replicas_;

  SPMDCollectiveOpsCreator collective_ops_creator_;

  // Tracks the next channel id to use for cross-partition all-reduce.
  int64* next_channel_id_;
  SpmdBuilder b_;

  HloInstruction* partition_id_;

  PartitionedHlo::ReshardCache reshard_cache_;

  // Mapping from the instruction in the original computation to the new SPMD
  // partitioned instruction.
  ConstHloInstructionMap<PartitionedHlo> partitioned_instructions_;

  // Information about a loop created for windowed dot-general. Used when
  // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor
  // finishes traversing the graph.
  struct WindowedDotGeneralLoop {
    HloInstruction* while_loop;
    int64 windowed_operand;
    bool windowed_in_contracting_dims;
    bool windowed_in_batch_dims;
  };
  std::vector<WindowedDotGeneralLoop> windowed_dot_general_loops_;

  HloInstruction* visiting_hlo_;
  SpmdLogger* logger_;
  const SpmdPartitionerOptions options_;
  SpmdPartitioner* partitioner_;
};
} // namespace spmd
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SPMD_PARTITIONER_H_