blob: 989ea771b51f2884280a827da5e41d7b76ba8338 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include <functional>
#include <numeric>
#include <queue>
#include <string>
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
using absl::StrCat;
namespace {
static const char kNameSeparator = '.';
// Retrieves the base name of an instruction or computation fully qualified
// name, using separator as boundary between the initial base name part, and
// the numeric identification.
std::string GetBaseName(const std::string& name, char separator) {
auto pos = name.rfind(separator);
CHECK_NE(pos, std::string::npos) << name;
return name.substr(0, pos);
}
// Generates a fully qualified computation/instruction name.
std::string GetFullName(const std::string& base_name, char separator,
int64_t id) {
const char separator_str[] = {separator, '\0'};
return StrCat(base_name, separator_str, id);
}
// Common function to standardize setting name and IDs on computation and
// instruction proto entities.
template <typename T>
void SetProtoIdAndName(T* entry, const std::string& base_name, char separator,
int64_t id) {
entry->set_id(id);
entry->set_name(GetFullName(base_name, separator, id));
}
bool InstrIsSetBound(const HloInstructionProto* instr_proto) {
HloOpcode opcode = StringToHloOpcode(instr_proto->opcode()).ValueOrDie();
if (opcode == HloOpcode::kCustomCall &&
instr_proto->custom_call_target() == "SetBound") {
return true;
}
return false;
}
} // namespace
namespace internal {

// Adds a kAddDependency instruction with result `shape` whose operands are
// `operand` followed by `token`. Errors are reported on `builder`.
XlaOp XlaBuilderFriend::BuildAddDependency(XlaBuilder* builder, XlaOp operand,
                                           XlaOp token, const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kAddDependency,
                                   {operand, token});
  });
}

// Adds a kFusion instruction of kind `fusion_kind` that calls
// `fused_computation` on `operands`. The instruction's shape is the result
// shape of `fused_computation`'s program shape.
XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
                                    absl::Span<const XlaOp> operands,
                                    absl::string_view fusion_kind,
                                    const XlaComputation& fused_computation) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_fusion_kind(std::string(fusion_kind));
    TF_ASSIGN_OR_RETURN(auto program_shape,
                        fused_computation.GetProgramShape());
    *instr.mutable_shape() = program_shape.result().ToProto();
    builder->AddCalledComputation(fused_computation, &instr);
    return builder->AddInstruction(std::move(instr), HloOpcode::kFusion,
                                   operands);
  });
}

// Adds a kBitcast of `operand` to `shape`.
XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand,
                                     const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast,
                                   {operand});
  });
}

// Adds an operand-less kPartitionId instruction with result `shape`.
XlaOp XlaBuilderFriend::BuildPartitionId(XlaBuilder* builder,
                                         const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr), HloOpcode::kPartitionId);
  });
}

// Adds an operand-less kRngGetAndUpdateState instruction carrying `delta`,
// with result `shape`.
XlaOp XlaBuilderFriend::BuildRngGetAndUpdateState(XlaBuilder* builder,
                                                  int64_t delta,
                                                  const Shape& shape) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_delta(delta);
    *instr.mutable_shape() = shape.ToProto();
    return builder->AddInstruction(std::move(instr),
                                   HloOpcode::kRngGetAndUpdateState);
  });
}

// Returns a mutable pointer to the instruction proto that `op` refers to in
// its own builder. CHECK-fails if the handle is unknown to that builder.
HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
  return GetInstructionByHandle(op.builder(), op.handle_);
}

// Returns a mutable pointer to the instruction proto with handle `handle` in
// `builder`. CHECK-fails if `builder` has no instruction with that handle.
HloInstructionProto* XlaBuilderFriend::GetInstructionByHandle(
    XlaBuilder* builder, int64_t handle) {
  // Use find() rather than operator[]: operator[] would insert a default
  // index (0) for an unknown handle and silently return an unrelated
  // instruction.
  auto it = builder->handle_to_index_.find(handle);
  CHECK(it != builder->handle_to_index_.end())
      << "No XlaOp with handle " << handle;
  return &builder->instructions_[it->second];
}

}  // namespace internal
// Operator sugar for building XLA ops. Each overload simply forwards to the
// correspondingly named free-function builder op.

// Unary negation: Neg(x).
XlaOp operator-(XlaOp x) {
  return Neg(x);
}

// Element-wise addition: Add(x, y).
XlaOp operator+(XlaOp x, XlaOp y) {
  return Add(x, y);
}

// Element-wise subtraction: Sub(x, y).
XlaOp operator-(XlaOp x, XlaOp y) {
  return Sub(x, y);
}

// Element-wise multiplication: Mul(x, y).
XlaOp operator*(XlaOp x, XlaOp y) {
  return Mul(x, y);
}

// Element-wise division: Div(x, y).
XlaOp operator/(XlaOp x, XlaOp y) {
  return Div(x, y);
}

// Element-wise remainder: Rem(x, y).
XlaOp operator%(XlaOp x, XlaOp y) {
  return Rem(x, y);
}

// Bitwise/logical not: Not(x).
XlaOp operator~(XlaOp x) {
  return Not(x);
}

// Bitwise/logical and: And(x, y).
XlaOp operator&(XlaOp x, XlaOp y) {
  return And(x, y);
}

// Bitwise/logical or: Or(x, y).
XlaOp operator|(XlaOp x, XlaOp y) {
  return Or(x, y);
}

// Bitwise/logical xor: Xor(x, y).
XlaOp operator^(XlaOp x, XlaOp y) {
  return Xor(x, y);
}

// Left shift: ShiftLeft(x, y).
XlaOp operator<<(XlaOp x, XlaOp y) {
  return ShiftLeft(x, y);
}
// Right shift of `x` by `y`. The shift flavor is chosen from x's element type:
// an arithmetic shift for signed integral types, a logical shift otherwise.
// A non-integral element type is reported as InvalidArgument on x's builder.
XlaOp operator>>(XlaOp x, XlaOp y) {
  XlaBuilder* const b = x.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* x_shape, b->GetShapePtr(x));
    if (!ShapeUtil::ElementIsIntegral(*x_shape)) {
      return InvalidArgument(
          "Argument to >> operator does not have an integral type (%s).",
          ShapeUtil::HumanString(*x_shape));
    }
    const bool is_signed = ShapeUtil::ElementIsSigned(*x_shape);
    return is_signed ? ShiftRightArithmetic(x, y) : ShiftRightLogical(x, y);
  });
}
// Returns a non-owning pointer to the shape recorded for `op`.
// Fails with the builder's first recorded error if there is one, with whatever
// CheckOpBuilder(op) reports, or with InvalidArgument if the handle is not
// known to this builder.
StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
  const auto index_it = handle_to_index_.find(op.handle());
  if (index_it != handle_to_index_.end()) {
    return instruction_shapes_.at(index_it->second).get();
  }
  return InvalidArgument("No XlaOp with handle %d", op.handle());
}
// Returns a copy of the shape recorded for `op`; fails in the same cases as
// GetShapePtr().
StatusOr<Shape> XlaBuilder::GetShape(XlaOp op) const {
  TF_ASSIGN_OR_RETURN(const Shape* recorded, GetShapePtr(op));
  return Shape(*recorded);
}
// Returns copies of the shapes of `operands`, in order. Fails on the first
// operand for which GetShapePtr() fails.
StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
    absl::Span<const XlaOp> operands) const {
  std::vector<Shape> shapes;
  shapes.reserve(operands.size());
  for (const XlaOp& op : operands) {
    TF_ASSIGN_OR_RETURN(const Shape* op_shape, GetShapePtr(op));
    shapes.emplace_back(*op_shape);
  }
  return shapes;
}
// Returns a multi-line, indented debug rendering of `op` and, recursively,
// its operands (see ToStringHelper).
std::string XlaBuilder::OpToString(XlaOp op) const {
  std::string rendered;
  ToStringHelper(&rendered, /*ident=*/0, op.handle());
  return rendered;
}
// Renders a ShapeProto compactly for debug output: array shapes as
// "[d0, d1, ...]" and tuple shapes as "(elem0, elem1, ...)", with each tuple
// element rendered recursively.
//
// A shape is treated as a tuple when its element type is TUPLE or when it has
// any tuple elements. The previous `tuple_shapes_size() > 1` test rendered
// one-element (and empty) tuples as "[]", losing the element shape.
static std::string ShapeToString(const xla::ShapeProto& shape) {
  if (shape.element_type() == xla::TUPLE || shape.tuple_shapes_size() > 0) {
    return absl::StrCat(
        "(",
        absl::StrJoin(shape.tuple_shapes(), ", ",
                      [](std::string* s, const xla::ShapeProto& subshape) {
                        absl::StrAppend(s, ShapeToString(subshape));
                      }),
        ")");
  }
  return absl::StrCat("[", absl::StrJoin(shape.dimensions(), ", "), "]");
}
void XlaBuilder::ToStringHelper(std::string* out, int ident,
int64_t op_handle) const {
const HloInstructionProto& instr =
*(LookUpInstructionByHandle(op_handle).ValueOrDie());
absl::StrAppend(out, std::string(ident, ' '), instr.opcode(),
", shape=", ShapeToString(instr.shape()));
if (instr.has_metadata()) {
absl::StrAppend(out, ", metadata={", instr.metadata().source_file(), ":",
instr.metadata().source_line(), "}");
}
if (instr.operand_ids_size()) {
absl::StrAppend(out, "\n");
}
absl::StrAppend(out, absl::StrJoin(instr.operand_ids(), "\n",
[&](std::string* s, int64_t subop) {
ToStringHelper(s, ident + 2, subop);
}));
}
// Creates a builder for a computation named `computation_name`.
XlaBuilder::XlaBuilder(const std::string& computation_name)
    : name_(computation_name) {}

// Members release their own resources.
XlaBuilder::~XlaBuilder() = default;
// Records a non-OK `error` on this builder and returns XlaOp(this).
// If die_immediately_on_error_ is set, the process is aborted via LOG(FATAL)
// instead. Only the first error, together with a backtrace captured at that
// point, is retained; subsequent errors are dropped.
// CHECK-fails if `error` is OK.
XlaOp XlaBuilder::ReportError(const Status& error) {
  CHECK(!error.ok());
  if (die_immediately_on_error_) {
    LOG(FATAL) << "error building computation: " << error;
  }
  const bool is_first_error = first_error_.ok();
  if (is_first_error) {
    first_error_ = error;
    // skip_count=1: presumably omits this frame from the captured backtrace.
    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
  }
  return XlaOp(this);
}
// Unwraps `op` when the builder is healthy and `op` succeeded.
// If the builder already holds an error, returns XlaOp(this) without looking
// at `op`. If `op` failed, records its status via ReportError().
XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
  if (first_error_.ok()) {
    if (op.ok()) {
      return op.ValueOrDie();
    }
    return ReportError(op.status());
  }
  return XlaOp(this);
}
// Invokes `op_creator` and forwards its result to the StatusOr overload.
// Note that `op_creator` is always invoked, even when the builder already
// holds an error.
XlaOp XlaBuilder::ReportErrorOrReturn(
    const std::function<StatusOr<XlaOp>()>& op_creator) {
  const StatusOr<XlaOp> created = op_creator();
  return ReportErrorOrReturn(created);
}
// Builds the ProgramShape of the computation rooted at `root_id`.
// The result shape is the root instruction's shape. Each kParameter
// instruction contributes its shape and name at the slot given by its
// parameter number. Returns the builder's first error if any, a lookup error
// for an unknown root, or an internal error if a parameter number falls
// outside [0, number of registered parameters).
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64_t root_id) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
                      LookUpInstructionByHandle(root_id));

  ProgramShape result;
  *result.mutable_result() = Shape(root_proto->shape());

  // Pre-size the parameter lists so they can be filled by parameter number.
  const int64_t num_params = parameter_numbers_.size();
  for (int64_t p = 0; p < num_params; ++p) {
    result.add_parameters();
    result.add_parameter_names();
  }

  // Parameter numbers are unique (enforced by XlaBuilder::Parameter()), so a
  // range check on each one is enough to prove they are contiguous from 0.
  const auto param_opcode = HloOpcodeString(HloOpcode::kParameter);
  for (const HloInstructionProto& instr : instructions_) {
    if (instr.opcode() != param_opcode) {
      continue;
    }
    const int64_t param_no = instr.parameter_number();
    TF_RET_CHECK(param_no >= 0 && param_no < num_params)
        << "invalid parameter number: " << param_no;
    *result.mutable_parameters(param_no) = Shape(instr.shape());
    *result.mutable_parameter_names(param_no) = instr.name();
  }
  return result;
}
// Returns the ProgramShape of the computation rooted at the most recently
// added instruction. Fails (internal error) if no instruction was added.
StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
  TF_RET_CHECK(!instructions_.empty());
  const int64_t last_id = instructions_.back().id();
  return GetProgramShape(last_id);
}
// Returns the ProgramShape of the computation rooted at `root`, which must
// belong to this builder.
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
  if (root.builder_ == this) {
    return GetProgramShape(root.handle());
  }
  return InvalidArgument("Given root operation is not in this computation.");
}
// Recursively decides whether the value of instruction `op_handle` can be
// treated as a constant, clearing `*is_constant` as soon as a non-constant
// instruction is reached.
//
// The walk is depth-first over operand ids. A handle is inserted into
// `visited` only after its operands have been explored, and the function
// returns immediately for handles already in `visited` or once `*is_constant`
// is false. `depth` only controls the indentation of the VLOG output.
//
// Computations called by the visited instructions are not inspected (see the
// TODOs below): ops that depend on them are either treated as non-constant or
// only checked through their operands.
void XlaBuilder::IsConstantVisitor(const int64_t op_handle, int depth,
absl::flat_hash_set<int64_t>* visited,
bool* is_constant) const {
if (visited->contains(op_handle) || !*is_constant) {
return;
}
const HloInstructionProto& instr =
*(LookUpInstructionByHandle(op_handle).ValueOrDie());
// Shape-less copy used only for more compact debug logging below.
HloInstructionProto to_print(instr);
to_print.clear_shape();
const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
const std::string indent =
absl::StrJoin(std::vector<absl::string_view>(depth, " "), "");
if (VLOG_IS_ON(2)) {
VLOG(2) << indent << "Visiting:";
for (const auto& l : absl::StrSplit(to_print.DebugString(), '\n')) {
VLOG(2) << indent << l;
}
}
switch (opcode) {
default:
// Any other opcode is treated as constant iff all of its operands are.
for (const int64_t operand_id : instr.operand_ids()) {
IsConstantVisitor(operand_id, depth + 1, visited, is_constant);
}
// TODO(b/32495713): We aren't checking the called computations.
break;
case HloOpcode::kGetDimensionSize:
// GetDimensionSize is always considered constant in XLA -- if a dynamic
// dimension is present, -1 is returned.
break;
// Non functional ops.
case HloOpcode::kRng:
case HloOpcode::kAllReduce:
case HloOpcode::kReduceScatter:
// TODO(b/33009255): Implement constant folding for cross replica sum.
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kCall:
// TODO(b/32495713): We aren't checking the to_apply computation itself,
// so we conservatively say that computations containing the Call op
// cannot be constant. We cannot set is_functional=false in other similar
// cases since we're already relying on IsConstant to return true.
//
// The opcodes from kRng through kCall above fall through into the
// kCustomCall case. Only a custom call targeting "SetBound" stays constant;
// everything else (including those fall-through opcodes, whose
// custom_call_target() is presumably empty) continues on to be marked
// non-constant.
case HloOpcode::kCustomCall:
if (instr.custom_call_target() == "SetBound") {
// Set bound is considered constant -- the bound is used as the value.
break;
}
ABSL_FALLTHROUGH_INTENDED;
case HloOpcode::kWhile:
// TODO(b/32495713): We aren't checking the condition and body
// computations themselves.
case HloOpcode::kScatter:
// TODO(b/32495713): We aren't checking the embedded computation in
// Scatter.
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kParameter:
*is_constant = false;
break;
case HloOpcode::kGetTupleElement: {
// If the operand is a kTuple instruction, only the selected tuple element
// has to be constant; otherwise all operands are checked.
const HloInstructionProto& operand_instr =
*(LookUpInstructionByHandle(instr.operand_ids(0)).ValueOrDie());
if (HloOpcodeString(HloOpcode::kTuple) == operand_instr.opcode()) {
IsConstantVisitor(operand_instr.operand_ids(instr.tuple_index()),
depth + 1, visited, is_constant);
} else {
for (const int64_t operand_id : instr.operand_ids()) {
IsConstantVisitor(operand_id, depth + 1, visited, is_constant);
}
}
}
}
if (VLOG_IS_ON(1) && !*is_constant) {
VLOG(1) << indent << "Non-constant: ";
for (const auto& l : absl::StrSplit(to_print.DebugString(), '\n')) {
VLOG(1) << indent << l;
}
}
// Mark as fully visited only after the operands have been explored.
visited->insert(op_handle);
}
// Marks dimension `target_dim_num` of the sub-shape at `target_param_index`
// of parameter `target_param_num` as dynamic. It also records, in the
// builder's dynamic parameter binding, that the runtime size of that
// dimension is supplied by the sub-shape at `dynamic_size_param_index` of
// parameter `dynamic_size_param_num`.
//
// Both the parameter's instruction proto and the cached instruction shape are
// updated. Returns InvalidArgument if no parameter numbered
// `target_param_num` exists; otherwise returns the status of
// DynamicParameterBinding::Bind().
Status XlaBuilder::SetDynamicBinding(int64_t dynamic_size_param_num,
                                     ShapeIndex dynamic_size_param_index,
                                     int64_t target_param_num,
                                     ShapeIndex target_param_index,
                                     int64_t target_dim_num) {
  bool param_exists = false;
  for (size_t instr_index = 0; instr_index < instructions_.size();
       ++instr_index) {
    HloInstructionProto& instr = instructions_[instr_index];
    if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
        instr.parameter_number() == target_param_num) {
      param_exists = true;
      Shape param_shape(instr.shape());
      Shape* param_shape_ptr = &param_shape;
      // Descend into the (possibly nested) tuple element to modify. The loop
      // variable is named distinctly so it cannot shadow `instr_index`, which
      // is used again below to update the cached shape.
      for (int64_t tuple_index : target_param_index) {
        param_shape_ptr = param_shape_ptr->mutable_tuple_shapes(tuple_index);
      }
      param_shape_ptr->set_dynamic_dimension(target_dim_num,
                                             /*is_dynamic=*/true);
      // Keep the proto and the cached Shape in sync.
      *instr.mutable_shape() = param_shape.ToProto();
      instruction_shapes_[instr_index] =
          absl::make_unique<Shape>(std::move(param_shape));
    }
  }
  if (!param_exists) {
    return InvalidArgument(
        "Asked to mark parameter %lld as dynamic sized parameter, but the "
        "parameter doesn't exist",
        target_param_num);
  }
  TF_RETURN_IF_ERROR(dynamic_parameter_binding_.Bind(
      DynamicParameterBinding::DynamicParameter{dynamic_size_param_num,
                                                dynamic_size_param_index},
      DynamicParameterBinding::DynamicDimension{
          target_param_num, target_param_index, target_dim_num}));
  return ::tensorflow::OkStatus();
}
// Sets frontend attribute `attribute` to `value` on the instruction `op`,
// overwriting any existing value for that key. Fails if `op` cannot be looked
// up as a mutable instruction of this builder.
Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
                                                   std::string attribute,
                                                   std::string value) {
  TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
  auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
  // Both strings are owned by-value parameters, so move them into the map
  // rather than copying the key.
  (*frontend_attributes->mutable_map())[std::move(attribute)] =
      std::move(value);
  return ::tensorflow::OkStatus();
}
// Builds this (sub-)builder's computation. On failure, reports the error,
// annotated with "error from: <name>", on the parent builder and returns a
// default-constructed XlaComputation. Expects a parent builder (DCHECK only).
XlaComputation XlaBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  auto built = Build();
  if (built.ok()) {
    return built.ConsumeValueOrDie();
  }
  parent_builder_->ReportError(
      AddStatus(built.status(), absl::StrCat("error from: ", name_)));
  return {};
}
// Returns OK if no error has been reported on this builder. Otherwise returns
// the first reported error with the backtrace captured at report time
// appended to it.
Status XlaBuilder::GetCurrentStatus() const {
  if (first_error_.ok()) {
    return ::tensorflow::OkStatus();
  }
  std::string backtrace;
  first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
  return AppendStatus(first_error_, backtrace);
}
// Builds the computation rooted at the most recently added instruction.
// Returns the builder's first recorded error if there is one, and
// InvalidArgument if no instruction has been added yet (calling back() on the
// empty instruction list would be undefined behavior).
StatusOr<XlaComputation> XlaBuilder::Build(bool remove_dynamic_dimensions) {
  TF_RETURN_IF_ERROR(GetCurrentStatus());
  if (instructions_.empty()) {
    return InvalidArgument(
        "Cannot build computation %s: no instructions have been added.",
        name_);
  }
  return Build(instructions_.back().id(), remove_dynamic_dimensions);
}
// Builds the computation rooted at `root`, which must belong to this builder.
StatusOr<XlaComputation> XlaBuilder::Build(XlaOp root,
                                           bool remove_dynamic_dimensions) {
  if (root.builder_ == this) {
    return Build(root.handle(), remove_dynamic_dimensions);
  }
  return InvalidArgument("Given root operation is not in this computation.");
}
// Builds an XlaComputation whose entry computation is rooted at the
// instruction with handle `root_id` and contains every instruction added to
// this builder.
//
// If `remove_dynamic_dimensions` is true, all dimensions of every
// instruction's shape (cached Shape and proto copy) are first marked static.
// The module's computations are the embedded computations followed by the
// entry computation. On success the builder's instructions, shapes, handle
// map, embedded computations and parameter numbers are cleared; note that
// input_output_aliases_ and dynamic_parameter_binding_ are not cleared here.
//
// NOTE(review): if PopulateInputOutputAlias fails, the instruction protos
// have already been swapped out of `instructions_`, leaving the builder
// partially consumed.
StatusOr<XlaComputation> XlaBuilder::Build(int64_t root_id,
bool remove_dynamic_dimensions) {
TF_RETURN_IF_ERROR(GetCurrentStatus());
// TODO(b/121223198): XLA backend cannot handle dynamic dimensions yet, remove
// all dynamic dimensions before building xla program until we have support in
// the backend.
if (remove_dynamic_dimensions) {
std::function<void(Shape*)> remove_dynamic_dimension = [&](Shape* shape) {
if (shape->tuple_shapes_size() != 0) {
for (int i = 0; i < shape->tuple_shapes_size(); ++i) {
remove_dynamic_dimension(shape->mutable_tuple_shapes(i));
}
}
for (int64_t i = 0; i < shape->dimensions_size(); ++i) {
shape->set_dynamic_dimension(i, false);
}
};
for (size_t index = 0; index < instructions_.size(); ++index) {
remove_dynamic_dimension(instruction_shapes_[index].get());
// Keep the proto copy of the shape in sync with the cached Shape.
*instructions_[index].mutable_shape() =
instruction_shapes_[index]->ToProto();
}
}
HloComputationProto entry;
SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
*entry.mutable_program_shape() = program_shape.ToProto();
entry.set_root_id(root_id);
for (auto& instruction : instructions_) {
// Ensures that the instruction names are unique among the whole graph.
instruction.set_name(
GetFullName(instruction.name(), kNameSeparator, instruction.id()));
// Swap moves the proto into the entry computation without copying.
entry.add_instructions()->Swap(&instruction);
}
XlaComputation computation(entry.id());
HloModuleProto* module = computation.mutable_proto();
module->set_name(entry.name());
module->set_id(entry.id());
module->set_entry_computation_name(entry.name());
module->set_entry_computation_id(entry.id());
*module->mutable_host_program_shape() = entry.program_shape();
for (auto& e : embedded_) {
module->add_computations()->Swap(&e.second);
}
module->add_computations()->Swap(&entry);
if (!input_output_aliases_.empty()) {
TF_RETURN_IF_ERROR(
PopulateInputOutputAlias(module, program_shape, input_output_aliases_));
}
*(module->mutable_dynamic_parameter_binding()) =
dynamic_parameter_binding_.ToProto();
// Clear data held by this builder.
this->instructions_.clear();
this->instruction_shapes_.clear();
this->handle_to_index_.clear();
this->embedded_.clear();
this->parameter_numbers_.clear();
// Explicit move: the return type (StatusOr) differs from the local's type.
return std::move(computation);
}
// Validates `input_output_aliases` against `program_shape` and stores the
// resulting alias config in `module`.
// Returns InvalidArgument if an alias names a parameter number outside
// [0, parameters_size()) or a shape index that is invalid for that parameter;
// otherwise forwards any error from HloInputOutputAliasConfig::SetUpAlias().
/* static */ Status XlaBuilder::PopulateInputOutputAlias(
    HloModuleProto* module, const ProgramShape& program_shape,
    const std::vector<InputOutputAlias>& input_output_aliases) {
  HloInputOutputAliasConfig config(program_shape.result());
  for (const auto& alias : input_output_aliases) {
    // The HloInputOutputAliasConfig does not do parameter validation as it only
    // carries the result shape. Maybe it should be constructed with a
    // ProgramShape to allow full validation. We will still get an error when
    // trying to compile the HLO module, but would be better to have validation
    // at this stage.
    //
    // A negative parameter number must be rejected too; otherwise it would
    // index program_shape.parameters() out of bounds below.
    if (alias.param_number < 0 ||
        alias.param_number >= program_shape.parameters_size()) {
      return InvalidArgument("Invalid parameter number %ld (total %ld)",
                             alias.param_number,
                             program_shape.parameters_size());
    }
    const Shape& parameter_shape = program_shape.parameters(alias.param_number);
    if (!ShapeUtil::IndexIsValid(parameter_shape, alias.param_index)) {
      return InvalidArgument("Invalid parameter %ld index: %s",
                             alias.param_number,
                             alias.param_index.ToString().c_str());
    }
    TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
                                         alias.param_index, alias.kind));
  }
  *module->mutable_input_output_alias() = config.ToProto();
  return ::tensorflow::OkStatus();
}
// Adds a kBroadcast of `operand` to `shape`. Operand dimension i maps to
// output dimension broadcast_dimensions[i]. Returns the builder's first error
// if one was already reported.
StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
    const Shape& shape, XlaOp operand,
    absl::Span<const int64_t> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(first_error_);
  HloInstructionProto proto;
  *proto.mutable_shape() = shape.ToProto();
  for (const int64_t output_dim : broadcast_dimensions) {
    proto.add_dimensions(output_dim);
  }
  return AddInstruction(std::move(proto), HloOpcode::kBroadcast, {operand});
}
// Explicitly broadcasts `operand` to the dimensions of `output_shape`, keeping
// the operand's element type.
//
// A scalar operand is broadcast directly. Otherwise the operand must have the
// same rank as `output_shape` (CHECK-fail if not). Each operand dimension
// must either equal the corresponding output dimension or be 1 (TF_RET_CHECK
// error if not). The size-1 ("degenerate") dimensions that differ are removed
// with a reshape, and the remaining dimensions are then broadcast in place.
StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
XlaOp operand) {
TF_RETURN_IF_ERROR(first_error_);
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
CHECK(ShapeUtil::IsScalar(*operand_shape) ||
operand_shape->rank() == output_shape.rank());
// Target shape: output dimensions with the operand's element type.
Shape broadcast_shape =
ShapeUtil::ChangeElementType(output_shape, operand_shape->element_type());
// Do explicit broadcast for scalar.
if (ShapeUtil::IsScalar(*operand_shape)) {
return InDimBroadcast(broadcast_shape, operand, {});
}
// Do explicit broadcast for degenerate broadcast.
// Dimensions that already match are kept (and recorded as the broadcast
// mapping); mismatched dimensions must be 1 and are dropped by the reshape.
std::vector<int64_t> broadcast_dimensions;
std::vector<int64_t> reshaped_dimensions;
for (int i = 0; i < operand_shape->rank(); i++) {
if (operand_shape->dimensions(i) == output_shape.dimensions(i)) {
broadcast_dimensions.push_back(i);
reshaped_dimensions.push_back(operand_shape->dimensions(i));
} else {
TF_RET_CHECK(operand_shape->dimensions(i) == 1)
<< "An explicit broadcast sequence requires the broadcasted "
"dimensions to be trivial; operand shape: "
<< *operand_shape << "; output_shape: " << output_shape;
}
}
Shape reshaped_shape =
ShapeUtil::MakeShape(operand_shape->element_type(), reshaped_dimensions);
// Carry the operand's dynamic-dimension flags over to the matching
// dimensions of the reshaped shape (pairs are operand dim, reshaped dim).
std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
ShapeUtil::DimensionsUnmodifiedByReshape(*operand_shape, reshaped_shape);
for (auto& unmodified : unmodified_dims) {
if (operand_shape->is_dynamic_dimension(unmodified.first)) {
reshaped_shape.set_dynamic_dimension(unmodified.second, true);
}
}
// Eliminate the size one dimensions.
TF_ASSIGN_OR_RETURN(
XlaOp reshaped_operand,
ReshapeInternal(reshaped_shape, operand, /*inferred_dimension=*/-1));
// Broadcast 'reshape' up to the larger size.
return InDimBroadcast(broadcast_shape, reshaped_operand,
broadcast_dimensions);
}
// Adds unary op `unop` applied to `operand`. The result shape comes from
// ShapeInference::InferUnaryOpShape.
XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferUnaryOpShape(unop, *input_shape));
    return AddOpWithShape(unop, result_shape, {operand});
  });
}
// Adds binary op `binop` on `lhs` and `rhs`, making implicit broadcasting
// explicit. The result shape comes from ShapeInference::InferBinaryOpShape.
//
// Broadcasting happens in two steps:
//  1. If the ranks differ and `broadcast_dimensions` is non-empty, the
//     lower-rank operand is in-dim broadcast to the result rank, with its
//     dimension i mapped to broadcast_dimensions[i].
//  2. Any operand whose dimensions still differ from the result shape is
//     expanded with AddBroadcastSequence (scalar / size-1 broadcasting).
//
// For kCompare, `direction` is required; if `type` is absent, Compare()
// picks the default comparison type for lhs's element type. For any other
// opcode, passing a `direction` is an InvalidArgument error.
XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
absl::Span<const int64_t> broadcast_dimensions,
absl::optional<ComparisonDirection> direction,
absl::optional<Comparison::Type> type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferBinaryOpShape(
binop, *lhs_shape, *rhs_shape, broadcast_dimensions));
const int64_t lhs_rank = lhs_shape->rank();
const int64_t rhs_rank = rhs_shape->rank();
XlaOp updated_lhs = lhs;
XlaOp updated_rhs = rhs;
if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
const bool should_broadcast_lhs = lhs_rank < rhs_rank;
XlaOp from = should_broadcast_lhs ? lhs : rhs;
const Shape& from_shape = should_broadcast_lhs ? *lhs_shape : *rhs_shape;
// Start from the result shape, then overwrite the dimensions that the
// lower-rank operand maps onto with that operand's own sizes and
// dynamic flags (which may still be 1, handled in step 2 below).
std::vector<int64_t> to_size;
std::vector<bool> to_size_is_dynamic;
const auto rank = shape.rank();
to_size.reserve(rank);
to_size_is_dynamic.reserve(rank);
for (int i = 0; i < rank; i++) {
to_size.push_back(shape.dimensions(i));
to_size_is_dynamic.push_back(shape.is_dynamic_dimension(i));
}
for (int64_t from_dim = 0; from_dim < from_shape.rank(); from_dim++) {
int64_t to_dim = broadcast_dimensions[from_dim];
to_size[to_dim] = from_shape.dimensions(from_dim);
to_size_is_dynamic[to_dim] = from_shape.is_dynamic_dimension(from_dim);
}
// Binding the temporary to a const reference extends its lifetime.
const Shape& broadcasted_shape = ShapeUtil::MakeShape(
from_shape.element_type(), to_size, to_size_is_dynamic);
TF_ASSIGN_OR_RETURN(
XlaOp broadcasted_operand,
InDimBroadcast(broadcasted_shape, from, broadcast_dimensions));
updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs;
updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs;
}
TF_ASSIGN_OR_RETURN(const Shape* updated_lhs_shape,
GetShapePtr(updated_lhs));
if (!ShapeUtil::SameDimensions(shape, *updated_lhs_shape)) {
TF_ASSIGN_OR_RETURN(updated_lhs,
AddBroadcastSequence(shape, updated_lhs));
}
TF_ASSIGN_OR_RETURN(const Shape* updated_rhs_shape,
GetShapePtr(updated_rhs));
if (!ShapeUtil::SameDimensions(shape, *updated_rhs_shape)) {
TF_ASSIGN_OR_RETURN(updated_rhs,
AddBroadcastSequence(shape, updated_rhs));
}
if (binop == HloOpcode::kCompare) {
if (!direction.has_value()) {
return InvalidArgument(
"kCompare expects a ComparisonDirection, but none provided.");
}
if (type == absl::nullopt) {
return Compare(shape, updated_lhs, updated_rhs, *direction);
} else {
return Compare(shape, updated_lhs, updated_rhs, *direction, *type);
}
}
if (direction.has_value()) {
return InvalidArgument(
"A comparison direction is provided for a non-compare opcode: %s.",
HloOpcodeString(binop));
}
return BinaryOpNoBroadcast(binop, shape, updated_lhs, updated_rhs);
});
}
// Adds binary op `binop` on (`lhs`, `rhs`) with the given result `shape`.
// No broadcasting or shape inference is performed here.
XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                      XlaOp lhs, XlaOp rhs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto proto;
    *proto.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(proto), binop, {lhs, rhs});
  });
}
// Adds a kCompare of `lhs` and `rhs` using `direction` and the default
// comparison type for lhs's element type.
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction) {
  // Only the element type is needed, so borrow the recorded shape via
  // GetShapePtr() instead of deep-copying it through GetShape().
  TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
  return Compare(shape, lhs, rhs, direction,
                 Comparison::DefaultComparisonType(lhs_shape->element_type()));
}
// Adds a kCompare instruction of `lhs` and `rhs` with result `shape`, storing
// the comparison direction and type in their string forms.
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                    ComparisonDirection direction,
                                    Comparison::Type type) {
  HloInstructionProto proto;
  *proto.mutable_shape() = shape.ToProto();
  proto.set_comparison_direction(ComparisonDirectionToString(direction));
  proto.set_comparison_type(ComparisonTypeToString(type));
  return AddInstruction(std::move(proto), HloOpcode::kCompare, {lhs, rhs});
}
// Adds ternary op `triop` on (`lhs`, `rhs`, `ehs`).
//
// For kSelect and kClamp, scalar operands are implicitly broadcast to the
// shape of the non-scalar array operand(s). If more than one operand is a
// non-scalar array, they must all have identical dimensions (TF_RET_CHECK
// error otherwise). The result shape comes from
// ShapeInference::InferTernaryOpShape.
//
// NOTE: an inference failure is re-reported as InvalidArgument. Only the
// original error's message is kept, not its error code.
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
XlaOp updated_lhs = lhs;
XlaOp updated_rhs = rhs;
XlaOp updated_ehs = ehs;
// The client API supports implicit broadcast for kSelect and kClamp, but
// XLA does not support implicit broadcast. Make implicit broadcast explicit
// and update the operands.
if (triop == HloOpcode::kSelect || triop == HloOpcode::kClamp) {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(ehs));
// The shape of the (single, common) non-scalar array operand, if any.
absl::optional<Shape> non_scalar_shape;
for (const Shape* shape : {lhs_shape, rhs_shape, ehs_shape}) {
if (shape->IsArray() && shape->rank() != 0) {
if (non_scalar_shape.has_value()) {
// TODO(jpienaar): The case where we need to compute the broadcasted
// shape by considering multiple of the shapes is not implemented.
// Consider reusing getBroadcastedType from mlir/Dialect/Traits.h.
TF_RET_CHECK(non_scalar_shape.value().dimensions() ==
shape->dimensions())
<< "Unimplemented implicit broadcast.";
} else {
non_scalar_shape = *shape;
}
}
}
if (non_scalar_shape.has_value()) {
if (ShapeUtil::IsScalar(*lhs_shape)) {
TF_ASSIGN_OR_RETURN(updated_lhs,
AddBroadcastSequence(*non_scalar_shape, lhs));
}
if (ShapeUtil::IsScalar(*rhs_shape)) {
TF_ASSIGN_OR_RETURN(updated_rhs,
AddBroadcastSequence(*non_scalar_shape, rhs));
}
if (ShapeUtil::IsScalar(*ehs_shape)) {
TF_ASSIGN_OR_RETURN(updated_ehs,
AddBroadcastSequence(*non_scalar_shape, ehs));
}
}
}
// Shapes of the (possibly broadcast) operands actually passed to the op.
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(updated_lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(updated_rhs));
TF_ASSIGN_OR_RETURN(const Shape* ehs_shape, GetShapePtr(updated_ehs));
StatusOr<const Shape> status_or_shape = ShapeInference::InferTernaryOpShape(
triop, *lhs_shape, *rhs_shape, *ehs_shape);
if (!status_or_shape.status().ok()) {
return InvalidArgument(
"%s Input scalar shapes may have been changed to non-scalar shapes.",
status_or_shape.status().error_message());
}
return AddOpWithShape(triop, status_or_shape.ValueOrDie(),
{updated_lhs, updated_rhs, updated_ehs});
});
}
// Adds a constant holding `literal`.
// An array literal with more than one element whose elements all equal the
// first one is emitted as a scalar constant broadcast to the literal's
// dimensions, which avoids embedding the full literal. All other literals
// become a single kConstant.
XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    const bool is_uniform_array = literal.shape().IsArray() &&
                                  literal.element_count() > 1 &&
                                  literal.IsAllFirst();
    if (!is_uniform_array) {
      HloInstructionProto proto;
      *proto.mutable_shape() = literal.shape().ToProto();
      *proto.mutable_literal() = literal.ToProto();
      return AddInstruction(std::move(proto), HloOpcode::kConstant);
    }
    Literal first_element = LiteralUtil::GetFirstScalarLiteral(literal);
    HloInstructionProto scalar_proto;
    *scalar_proto.mutable_shape() = first_element.shape().ToProto();
    *scalar_proto.mutable_literal() = first_element.ToProto();
    TF_ASSIGN_OR_RETURN(
        XlaOp scalar_op,
        AddInstruction(std::move(scalar_proto), HloOpcode::kConstant));
    return Broadcast(scalar_op, literal.shape().dimensions());
  });
}
// Adds a kIota instruction of `shape` that counts along `iota_dimension`.
XlaOp XlaBuilder::Iota(const Shape& shape, int64_t iota_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto proto;
    proto.add_dimensions(iota_dimension);
    *proto.mutable_shape() = shape.ToProto();
    return AddInstruction(std::move(proto), HloOpcode::kIota);
  });
}
// Adds a rank-1 iota of element type `type` with `size` elements.
XlaOp XlaBuilder::Iota(PrimitiveType type, int64_t size) {
  const Shape vector_shape = ShapeUtil::MakeShape(type, {size});
  return Iota(vector_shape, /*iota_dimension=*/0);
}
// Adds a kCall of `computation` on `operands`. The result shape is inferred
// from the operand shapes and the called computation's program shape, and
// `computation` is registered as the called computation.
XlaOp XlaBuilder::Call(const XlaComputation& computation,
                       absl::Span<const XlaOp> operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    // Shape inference takes pointers; they point into `operand_shapes`, which
    // outlives every use below.
    std::vector<const Shape*> operand_shape_ptrs;
    operand_shape_ptrs.reserve(operand_shapes.size());
    for (const Shape& operand_shape : operand_shapes) {
      operand_shape_ptrs.push_back(&operand_shape);
    }
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
                                         operand_shape_ptrs,
                                         /*to_apply=*/called_program_shape));
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    AddCalledComputation(computation, &instr);
    return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
  });
}
// Adds parameter number `parameter_number` with the given `shape` and `name`.
// When `replicated_at_leaf_buffers` is non-empty, it is recorded as the
// parameter's per-leaf replication info. Registering the same parameter
// number twice is an InvalidArgument error.
XlaOp XlaBuilder::Parameter(
    int64_t parameter_number, const Shape& shape, const std::string& name,
    const std::vector<bool>& replicated_at_leaf_buffers) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    const bool newly_registered =
        parameter_numbers_.insert(parameter_number).second;
    if (!newly_registered) {
      return InvalidArgument("parameter %d already registered",
                             parameter_number);
    }
    HloInstructionProto proto;
    proto.set_parameter_number(parameter_number);
    proto.set_name(name);
    *proto.mutable_shape() = shape.ToProto();
    if (!replicated_at_leaf_buffers.empty()) {
      auto* replication = proto.mutable_parameter_replication();
      for (const bool leaf_replicated : replicated_at_leaf_buffers) {
        replication->add_replicated_at_leaf_buffers(leaf_replicated);
      }
    }
    return AddInstruction(std::move(proto), HloOpcode::kParameter);
  });
}
// Broadcasts `operand` by prepending new major dimensions of sizes
// `broadcast_sizes`. The output shape comes from
// ShapeInference::InferBroadcastShape.
XlaOp XlaBuilder::Broadcast(XlaOp operand,
                            absl::Span<const int64_t> broadcast_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape& shape,
        ShapeInference::InferBroadcastShape(*operand_shape, broadcast_sizes));
    // The client-level op only adds dimensions on the left, whereas the HLO
    // broadcast maps each operand dimension to an arbitrary output dimension.
    // The operand's dimensions therefore map, in order, onto the last
    // `operand_rank` output dimensions: offset, offset + 1, ...
    const int64_t operand_rank = operand_shape->rank();
    std::vector<int64_t> dimensions(operand_rank);
    std::iota(dimensions.begin(), dimensions.end(),
              shape.rank() - operand_rank);
    return InDimBroadcast(shape, operand, dimensions);
  });
}
// Broadcasts `operand` to a shape whose dimension sizes are `out_dim_size`.
// Operand dimension i is mapped to output dimension broadcast_dimensions[i].
// Dynamic-dimension flags of the operand are carried over to the mapped
// output dimensions.
//
// Degenerate broadcasts are handled in two steps: an in-dim broadcast to an
// intermediate shape that keeps the operand's own sizes on mapped dimensions,
// followed by AddBroadcastSequence when that intermediate shape differs from
// the requested output shape.
XlaOp XlaBuilder::BroadcastInDim(
    XlaOp operand, const absl::Span<const int64_t> out_dim_size,
    const absl::Span<const int64_t> broadcast_dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    // Output shape, in the case of degenerate broadcast, the out_dim_size is
    // not necessarily the same as the dimension sizes of the output shape.
    TF_ASSIGN_OR_RETURN(auto output_shape,
                        ShapeUtil::MakeValidatedShape(
                            operand_shape->element_type(), out_dim_size));
    const int64_t broadcast_rank = broadcast_dimensions.size();
    if (operand_shape->rank() != broadcast_rank) {
      return InvalidArgument(
          "Size of broadcast_dimensions has to match operand's rank; operand "
          "rank: %lld, size of broadcast_dimensions %u.",
          operand_shape->rank(), broadcast_dimensions.size());
    }

    const int64_t num_dims = out_dim_size.size();
    for (int i = 0; i < broadcast_rank; i++) {
      // Valid output dimension indices are [0, num_dims). The upper bound must
      // be exclusive. Otherwise an index equal to num_dims would reach
      // set_dynamic_dimension below and index the output shape out of range.
      if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] >= num_dims) {
        return InvalidArgument("Broadcast dimension %lld is out of bound",
                               broadcast_dimensions[i]);
      }
      output_shape.set_dynamic_dimension(
          broadcast_dimensions[i], operand_shape->is_dynamic_dimension(i));
    }

    TF_RETURN_IF_ERROR(ShapeInference::InferBroadcastShape(
                           *operand_shape, output_shape, broadcast_dimensions)
                           .status());

    // Intermediate shape: the requested output sizes, except that every mapped
    // dimension keeps the operand's own size.
    std::vector<int64_t> in_dim_size(out_dim_size.begin(), out_dim_size.end());
    for (int i = 0; i < broadcast_rank; i++) {
      in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i);
    }
    const auto& in_dim_shape =
        ShapeUtil::MakeShape(operand_shape->element_type(), in_dim_size);
    TF_ASSIGN_OR_RETURN(
        XlaOp in_dim_broadcast,
        InDimBroadcast(in_dim_shape, operand, broadcast_dimensions));

    // If broadcast is not degenerate, return broadcasted result.
    if (ShapeUtil::Equal(in_dim_shape, output_shape)) {
      return in_dim_broadcast;
    }

    // Otherwise handle degenerate broadcast case.
    return AddBroadcastSequence(output_shape, in_dim_broadcast);
  });
}
// Adds a kReshape instruction producing `shape` from `operand`.
// If the builder has already recorded an error, that error is returned and
// nothing is added. An `inferred_dimension` other than -1 is stored as the
// instruction's single `dimensions` entry.
StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
                                            int64_t inferred_dimension) {
  TF_RETURN_IF_ERROR(first_error_);

  HloInstructionProto reshape_proto;
  *reshape_proto.mutable_shape() = shape.ToProto();
  const bool has_inferred_dimension = inferred_dimension != -1;
  if (has_inferred_dimension) {
    reshape_proto.add_dimensions(inferred_dimension);
  }
  return AddInstruction(std::move(reshape_proto), HloOpcode::kReshape,
                        {operand});
}
// Static slice of `operand`. The result shape is computed and validated by
// ShapeInference::InferSliceShape before the instruction is emitted.
XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64_t> start_indices,
                        absl::Span<const int64_t> limit_indices,
                        absl::Span<const int64_t> strides) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape sliced_shape,
        ShapeInference::InferSliceShape(*input_shape, start_indices,
                                        limit_indices, strides));
    return SliceInternal(sliced_shape, operand, start_indices, limit_indices,
                         strides);
  });
}
// Emits a kSlice instruction with one slice_dimensions entry per element of
// `start_indices`. The entry takes its start, limit and stride from the
// matching positions of the three spans. Callers are expected to pass spans
// of equal length (the public Slice validates them through shape inference).
StatusOr<XlaOp> XlaBuilder::SliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64_t> start_indices,
    absl::Span<const int64_t> limit_indices,
    absl::Span<const int64_t> strides) {
  HloInstructionProto slice_proto;
  *slice_proto.mutable_shape() = shape.ToProto();

  const size_t num_slice_dims = start_indices.size();
  for (size_t dim = 0; dim < num_slice_dims; ++dim) {
    auto* dim_config = slice_proto.add_slice_dimensions();
    dim_config->set_start(start_indices[dim]);
    dim_config->set_limit(limit_indices[dim]);
    dim_config->set_stride(strides[dim]);
  }
  return AddInstruction(std::move(slice_proto), HloOpcode::kSlice, {operand});
}
// Slices `operand` along the single dimension `dimno` using `start_index`,
// `limit_index` and `stride`. Every other dimension is kept whole: start 0,
// the full size as limit, and stride 1.
// Returns InvalidArgument if `dimno` is not a valid dimension of `operand`.
XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64_t start_index,
                             int64_t limit_index, int64_t stride,
                             int64_t dimno) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    const int64_t rank = shape->rank();
    // `dimno` is used as a direct index into the vectors below. An
    // out-of-range value would write out of bounds, so reject it up front.
    if (dimno < 0 || dimno >= rank) {
      return InvalidArgument(
          "SliceInDim dimension %d is out of range for operand of rank %d",
          dimno, rank);
    }
    std::vector<int64_t> starts(rank, 0);
    std::vector<int64_t> limits(shape->dimensions().begin(),
                                shape->dimensions().end());
    std::vector<int64_t> strides(rank, 1);
    starts[dimno] = start_index;
    limits[dimno] = limit_index;
    strides[dimno] = stride;
    return Slice(operand, starts, limits, strides);
  });
}
// Slices `operand` starting at the runtime positions given by
// `start_indices`, extracting `slice_sizes` elements per dimension. The
// result shape comes from ShapeInference::InferDynamicSliceShape.
XlaOp XlaBuilder::DynamicSlice(XlaOp operand,
                               absl::Span<const XlaOp> start_indices,
                               absl::Span<const int64_t> slice_sizes) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    // Shape inference takes the index shapes by value-span. The pointer
    // vector the previous version built here was never used, so it is gone.
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    TF_ASSIGN_OR_RETURN(Shape shape,
                        ShapeInference::InferDynamicSliceShape(
                            *operand_shape, start_indices_shapes, slice_sizes));
    return DynamicSliceInternal(shape, operand, start_indices, slice_sizes);
  });
}
// Emits a kDynamicSlice instruction. `slice_sizes` are recorded in the
// instruction's dynamic_slice_sizes. The operand list is `operand` followed
// by the start-index ops in the order given.
StatusOr<XlaOp> XlaBuilder::DynamicSliceInternal(
    const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
    absl::Span<const int64_t> slice_sizes) {
  HloInstructionProto dynamic_slice_proto;
  *dynamic_slice_proto.mutable_shape() = shape.ToProto();
  for (size_t i = 0; i < slice_sizes.size(); ++i) {
    dynamic_slice_proto.add_dynamic_slice_sizes(slice_sizes[i]);
  }

  std::vector<XlaOp> all_operands;
  all_operands.reserve(1 + start_indices.size());
  all_operands.push_back(operand);
  for (const XlaOp& start_index : start_indices) {
    all_operands.push_back(start_index);
  }
  return AddInstruction(std::move(dynamic_slice_proto),
                        HloOpcode::kDynamicSlice, all_operands);
}
// Writes `update` into `operand` starting at the runtime positions given by
// `start_indices`. The result shape comes from
// ShapeInference::InferDynamicUpdateSliceShape.
XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                     absl::Span<const XlaOp> start_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* update_shape, GetShapePtr(update));
    // Shape inference consumes the index shapes directly. The pointer vector
    // previously built here was never used, so it has been removed.
    TF_ASSIGN_OR_RETURN(const auto& start_indices_shapes,
                        GetOperandShapes(start_indices));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferDynamicUpdateSliceShape(
                         *operand_shape, *update_shape, start_indices_shapes));
    return DynamicUpdateSliceInternal(shape, operand, update, start_indices);
  });
}
// Emits a kDynamicUpdateSlice instruction. The operand list is `operand`,
// then `update`, then the start-index ops in the order given.
StatusOr<XlaOp> XlaBuilder::DynamicUpdateSliceInternal(
    const Shape& shape, XlaOp operand, XlaOp update,
    absl::Span<const XlaOp> start_indices) {
  std::vector<XlaOp> all_operands;
  all_operands.reserve(2 + start_indices.size());
  all_operands.push_back(operand);
  all_operands.push_back(update);
  all_operands.insert(all_operands.end(), start_indices.begin(),
                      start_indices.end());

  HloInstructionProto update_proto;
  *update_proto.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(update_proto),
                        HloOpcode::kDynamicUpdateSlice, all_operands);
}
// Concatenates `operands` along `dimension`. The result shape is validated and
// computed by ShapeInference::InferConcatOpShape.
XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
                              int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    // Shape inference wants pointers; they refer into `operand_shapes`, which
    // outlives the call below.
    std::vector<const Shape*> shape_ptrs;
    shape_ptrs.reserve(operand_shapes.size());
    for (const Shape& operand_shape : operand_shapes) {
      shape_ptrs.push_back(&operand_shape);
    }
    TF_ASSIGN_OR_RETURN(
        Shape concat_shape,
        ShapeInference::InferConcatOpShape(shape_ptrs, dimension));
    return ConcatInDimInternal(concat_shape, operands, dimension);
  });
}
// Emits a kConcatenate instruction over `operands`. The concatenation axis is
// stored as the single entry of the instruction's `dimensions`.
StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
    const Shape& shape, absl::Span<const XlaOp> operands, int64_t dimension) {
  HloInstructionProto concat_proto;
  *concat_proto.mutable_shape() = shape.ToProto();
  concat_proto.add_dimensions(dimension);
  return AddInstruction(std::move(concat_proto), HloOpcode::kConcatenate,
                        operands);
}
// Pads `operand` with `padding_value` as described by `padding_config`. The
// result shape is computed by ShapeInference::InferPadShape.
XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
                      const PaddingConfig& padding_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* pad_value_shape,
                        GetShapePtr(padding_value));
    TF_ASSIGN_OR_RETURN(Shape padded_shape,
                        ShapeInference::InferPadShape(
                            *operand_shape, *pad_value_shape, padding_config));
    return PadInternal(padded_shape, operand, padding_value, padding_config);
  });
}
// Pads `operand` along the single dimension `dimno`: `pad_lo` elements of
// `padding_value` before it and `pad_hi` after it. All other dimensions get
// no padding.
// Returns InvalidArgument if `dimno` is not a valid dimension of `operand`.
XlaOp XlaBuilder::PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
                           int64_t pad_lo, int64_t pad_hi) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    const int64_t rank = shape->rank();
    // An out-of-range `dimno` would index the padding config's repeated
    // dimensions field out of bounds. Report it as an error instead.
    if (dimno < 0 || dimno >= rank) {
      return InvalidArgument(
          "PadInDim dimension %d is out of range for operand of rank %d", dimno,
          rank);
    }
    PaddingConfig padding_config = MakeNoPaddingConfig(rank);
    auto* dims = padding_config.mutable_dimensions(dimno);
    dims->set_edge_padding_low(pad_lo);
    dims->set_edge_padding_high(pad_hi);
    return Pad(operand, padding_value, padding_config);
  });
}
// Emits a kPad instruction with `padding_config`. The operands are `operand`
// followed by `padding_value`.
StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
                                        XlaOp padding_value,
                                        const PaddingConfig& padding_config) {
  HloInstructionProto pad_proto;
  *pad_proto.mutable_padding_config() = padding_config;
  *pad_proto.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(pad_proto), HloOpcode::kPad,
                        {operand, padding_value});
}
// Reshapes `operand` to `new_sizes` after first transposing it into the
// dimension order given by `dimensions`. The transpose is skipped when
// `dimensions` is the identity permutation. `inferred_dimension` is passed to
// both shape inference and ReshapeInternal.
XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64_t> dimensions,
                          absl::Span<const int64_t> new_sizes,
                          int64_t inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape reshaped_shape,
                        ShapeInference::InferReshapeShape(
                            *operand_shape, dimensions, new_sizes,
                            inferred_dimension));
    XlaOp reshape_input = operand;
    if (!IsIdentityPermutation(dimensions)) {
      reshape_input = Transpose(operand, dimensions);
    }
    return ReshapeInternal(reshaped_shape, reshape_input, inferred_dimension);
  });
}
// Reshapes `operand` to `new_sizes` without reordering its dimensions. It
// forwards to the permuting overload with the identity order 0..rank-1.
XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64_t> new_sizes,
                          int64_t inferred_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    const int rank = shape->dimensions_size();
    std::vector<int64_t> identity_order;
    identity_order.reserve(rank);
    for (int d = 0; d < rank; ++d) {
      identity_order.push_back(d);
    }
    return Reshape(operand, identity_order, new_sizes, inferred_dimension);
  });
}
// Reshapes `operand` directly to the caller-supplied `shape`, without running
// reshape shape inference. Errors from ReshapeInternal (including a
// previously recorded builder error) are reported through
// ReportErrorOrReturn.
XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
                          int64_t inferred_dimension) {
  return ReportErrorOrReturn(
      [&]() -> StatusOr<XlaOp> {
        return ReshapeInternal(shape, operand, inferred_dimension);
      });
}
// Reshapes `operand` to a shape whose sizes are given at runtime by the
// `dim_sizes` ops. `new_size_bounds` and `dims_are_dynamic` are forwarded to
// ShapeInference::InferDynamicReshapeShape to build the static result shape.
// The emitted kDynamicReshape takes `operand` followed by the `dim_sizes` ops.
XlaOp XlaBuilder::DynamicReshape(XlaOp operand,
                                 absl::Span<const XlaOp> dim_sizes,
                                 absl::Span<const int64_t> new_size_bounds,
                                 const std::vector<bool>& dims_are_dynamic) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes,
                        GetOperandShapes(dim_sizes));
    std::vector<const Shape*> size_shape_ptrs;
    size_shape_ptrs.reserve(dim_size_shapes.size());
    for (const Shape& size_shape : dim_size_shapes) {
      size_shape_ptrs.push_back(&size_shape);
    }
    TF_ASSIGN_OR_RETURN(const Shape result_shape,
                        ShapeInference::InferDynamicReshapeShape(
                            *operand_shape, size_shape_ptrs, new_size_bounds,
                            dims_are_dynamic));
    TF_RETURN_IF_ERROR(first_error_);

    std::vector<XlaOp> all_operands = {operand};
    all_operands.insert(all_operands.end(), dim_sizes.begin(),
                        dim_sizes.end());

    HloInstructionProto reshape_proto;
    *reshape_proto.mutable_shape() = result_shape.ToProto();
    return AddInstruction(std::move(reshape_proto),
                          HloOpcode::kDynamicReshape, all_operands);
  });
}
// Collapses the consecutive, increasing run of dimensions `dimensions` of
// `operand` into one dimension whose size is the product of their sizes. It
// is implemented as a Reshape. A span of zero or one dimension is a no-op
// that returns `operand` unchanged.
// Returns InvalidArgument if the dimensions are not consecutive or fall
// outside the operand's rank.
XlaOp XlaBuilder::Collapse(XlaOp operand,
                           absl::Span<const int64_t> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (dimensions.size() <= 1) {
      // Not collapsing anything, trivially we can return the operand versus
      // enqueueing a trivial reshape.
      return operand;
    }

    // Out-of-order collapse is not supported.
    // Checks that the collapsed dimensions are in order and consecutive.
    for (absl::Span<const int64_t>::size_type i = 1; i < dimensions.size();
         ++i) {
      if (dimensions[i] - 1 != dimensions[i - 1]) {
        return InvalidArgument(
            "Collapsed dimensions are not in consecutive order.");
      }
    }

    TF_ASSIGN_OR_RETURN(const Shape* original_shape, GetShapePtr(operand));
    VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
    VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ",");

    // The run is consecutive, so checking its two ends bounds every element.
    // Without this check, a negative first dimension makes the loop below
    // call back() on an empty vector. A run that extends past the last
    // dimension would be silently truncated.
    const int64_t rank = original_shape->rank();
    if (dimensions.front() < 0 || dimensions.back() >= rank) {
      return InvalidArgument(
          "Collapsed dimensions [%s] are out of range for operand of rank %d.",
          absl::StrJoin(dimensions, ","), rank);
    }

    // Create a new sizes vector from the old shape, replacing the collapsed
    // dimensions by the product of their sizes.
    std::vector<int64_t> new_sizes;
    for (int64_t i = 0; i < rank; ++i) {
      if (i <= dimensions.front() || i > dimensions.back()) {
        new_sizes.push_back(original_shape->dimensions(i));
      } else {
        new_sizes.back() *= original_shape->dimensions(i);
      }
    }

    VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]";

    return Reshape(operand, new_sizes);
  });
}
// Builds a trivial computation (builder name "dummy") with a single parameter
// 0, named "p", of type `shape`. That parameter is also the computation's
// result, so the computation returns its input unchanged.
static StatusOr<XlaComputation> PassthroughComputation(
    const xla::Shape& shape) {
  XlaBuilder builder("dummy");
  const XlaOp identity = Parameter(&builder, 0, shape, "p");
  return builder.Build(identity);
}
// Elementwise select between `on_true` and `on_false` under `pred`. Both
// branches must agree on whether they are tuples. Non-tuple operands produce
// a kSelect instruction. Tuple operands are instead routed through a
// Conditional whose two branch computations return their argument unchanged.
XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* true_shape, GetShapePtr(on_true));
    TF_ASSIGN_OR_RETURN(const Shape* false_shape, GetShapePtr(on_false));
    TF_RET_CHECK(true_shape->IsTuple() == false_shape->IsTuple());
    if (!true_shape->IsTuple()) {
      return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false);
    }
    TF_ASSIGN_OR_RETURN(XlaComputation true_branch,
                        PassthroughComputation(*true_shape));
    TF_ASSIGN_OR_RETURN(XlaComputation false_branch,
                        PassthroughComputation(*false_shape));
    return Conditional(pred, on_true, true_branch, on_false, false_branch);
  });
}
// Packs `elements` into a tuple. The tuple shape is produced by
// ShapeInference::InferVariadicOpShape for kTuple.
XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const auto& element_shapes, GetOperandShapes(elements));
    std::vector<const Shape*> element_shape_ptrs;
    element_shape_ptrs.reserve(element_shapes.size());
    for (const Shape& element_shape : element_shapes) {
      element_shape_ptrs.push_back(&element_shape);
    }
    TF_ASSIGN_OR_RETURN(const Shape tuple_shape,
                        ShapeInference::InferVariadicOpShape(
                            HloOpcode::kTuple, element_shape_ptrs));
    return TupleInternal(tuple_shape, elements);
  });
}
// Emits a kTuple instruction of the given `shape` over `elements`.
StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
                                          absl::Span<const XlaOp> elements) {
  HloInstructionProto tuple_proto;
  *tuple_proto.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(tuple_proto), HloOpcode::kTuple, elements);
}
// Extracts element `index` of the tuple-shaped `tuple_data`.
// Returns InvalidArgument if `tuple_data` is not a tuple or if `index` is
// outside [0, element count).
XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64_t index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
    if (!tuple_shape->IsTuple()) {
      return InvalidArgument(
          "Operand to GetTupleElement() is not a tuple; got %s",
          ShapeUtil::HumanString(*tuple_shape));
    }
    const int64_t element_count = ShapeUtil::TupleElementCount(*tuple_shape);
    const bool index_in_range = index >= 0 && index < element_count;
    if (!index_in_range) {
      return InvalidArgument(
          "GetTupleElement() index (%d) out of range for tuple shape %s", index,
          ShapeUtil::HumanString(*tuple_shape));
    }
    const Shape& element_shape =
        ShapeUtil::GetTupleElementShape(*tuple_shape, index);
    return GetTupleElementInternal(element_shape, tuple_data, index);
  });
}
// Emits a kGetTupleElement instruction reading element `index` of
// `tuple_data`. `shape` is the element's shape.
StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
                                                    XlaOp tuple_data,
                                                    int64_t index) {
  HloInstructionProto gte_proto;
  gte_proto.set_tuple_index(index);
  *gte_proto.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(gte_proto), HloOpcode::kGetTupleElement,
                        {tuple_data});
}
// Standard dot product, expressed as a DotGeneral. `lhs` contracts on
// dimension 0 when it has exactly one dimension and on dimension 1
// otherwise. `rhs` always contracts on dimension 0.
XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
                      const PrecisionConfig* precision_config,
                      absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    int64_t lhs_contracting_dim = 1;
    if (lhs_shape->dimensions_size() == 1) {
      lhs_contracting_dim = 0;
    }
    DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(lhs_contracting_dim);
    dnums.add_rhs_contracting_dimensions(0);
    return DotGeneral(lhs, rhs, dnums, precision_config,
                      preferred_element_type);
  });
}
// General dot product with explicit contracting/batch dimension numbers. The
// result shape comes from ShapeInference::InferDotOpShape, which also
// receives `preferred_element_type`.
XlaOp XlaBuilder::DotGeneral(
    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferDotOpShape(
                            *lhs_shape, *rhs_shape, dimension_numbers,
                            preferred_element_type));
    return DotGeneralInternal(result_shape, lhs, rhs, dimension_numbers,
                              precision_config);
  });
}
// Emits a kDot instruction. `precision_config` is copied into the
// instruction only when non-null.
StatusOr<XlaOp> XlaBuilder::DotGeneralInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs,
    const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config) {
  HloInstructionProto dot_proto;
  *dot_proto.mutable_shape() = shape.ToProto();
  *dot_proto.mutable_dot_dimension_numbers() = dimension_numbers;
  if (precision_config) {
    *dot_proto.mutable_precision_config() = *precision_config;
  }
  return AddInstruction(std::move(dot_proto), HloOpcode::kDot, {lhs, rhs});
}
// Validates the convolution arguments. `lhs_shape` and `rhs_shape` must have
// the same rank, and that rank must be at least 2. Each of the input, kernel
// and output spatial-dimension lists in `dimension_numbers` must contain
// exactly rank - 2 entries, each a valid dimension index. Returns
// InvalidArgument describing the first violation found, or OK.
Status XlaBuilder::VerifyConvolution(
    const Shape& lhs_shape, const Shape& rhs_shape,
    const ConvolutionDimensionNumbers& dimension_numbers) const {
  if (lhs_shape.rank() != rhs_shape.rank()) {
    return InvalidArgument(
        "Convolution arguments must have same number of "
        "dimensions. Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  const int64_t num_dims = lhs_shape.rank();
  // All but two dimensions are spatial, so rank 2 (zero spatial dimensions)
  // is the minimum accepted. The message must state the same bound that is
  // checked.
  if (num_dims < 2) {
    return InvalidArgument(
        "Convolution expects argument arrays with >= 2 dimensions. "
        "Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape));
  }
  const int64_t num_spatial_dims = num_dims - 2;

  const auto check_spatial_dimensions =
      [&](absl::string_view field_name,
          absl::Span<const int64_t> numbers) -> Status {
    // Convert the count once so comparisons stay signed.
    const int64_t num_numbers = static_cast<int64_t>(numbers.size());
    if (num_numbers != num_spatial_dims) {
      return InvalidArgument("Expected %d elements for %s, but got %d.",
                             num_spatial_dims, field_name, num_numbers);
    }
    for (int64_t i = 0; i < num_numbers; ++i) {
      if (numbers[i] < 0 || numbers[i] >= num_dims) {
        return InvalidArgument("Convolution %s[%d] is out of bounds: %d",
                               field_name, i, numbers[i]);
      }
    }
    return ::tensorflow::OkStatus();
  };
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("input_spatial_dimensions",
                               dimension_numbers.input_spatial_dimensions()));
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("kernel_spatial_dimensions",
                               dimension_numbers.kernel_spatial_dimensions()));
  return check_spatial_dimensions(
      "output_spatial_dimensions",
      dimension_numbers.output_spatial_dimensions());
}
// Convolution with symbolic `padding`, using the default dimension numbers
// for as many spatial dimensions as there are `window_strides`.
XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64_t> window_strides,
                       Padding padding, int64_t feature_group_count,
                       int64_t batch_group_count,
                       const PrecisionConfig* precision_config,
                       absl::optional<PrimitiveType> preferred_element_type) {
  const ConvolutionDimensionNumbers default_dnums =
      CreateDefaultConvDimensionNumbers(window_strides.size());
  return ConvWithGeneralDimensions(lhs, rhs, window_strides, padding,
                                   default_dnums, feature_group_count,
                                   batch_group_count, precision_config,
                                   preferred_element_type);
}
// Convolution with explicit (low, high) `padding` per spatial dimension. It
// uses the default dimension numbers for as many spatial dimensions as there
// are `window_strides`.
XlaOp XlaBuilder::ConvWithGeneralPadding(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  const ConvolutionDimensionNumbers default_dnums =
      CreateDefaultConvDimensionNumbers(window_strides.size());
  return ConvGeneral(lhs, rhs, window_strides, padding, default_dnums,
                     feature_group_count, batch_group_count, precision_config,
                     preferred_element_type);
}
// Convolution with caller-supplied dimension numbers and symbolic `padding`.
// After validating the arguments, `padding` is turned into explicit
// (low, high) pairs by MakePadding. MakePadding is given the input's extents
// along its spatial dimensions and the kernel's extents along its spatial
// dimensions. The result is forwarded to ConvGeneral.
XlaOp XlaBuilder::ConvWithGeneralDimensions(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_RETURN_IF_ERROR(
        VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));

    std::vector<int64_t> base_area_dimensions;
    base_area_dimensions.reserve(
        dimension_numbers.input_spatial_dimensions_size());
    for (const int64_t input_dim :
         dimension_numbers.input_spatial_dimensions()) {
      base_area_dimensions.push_back(lhs_shape->dimensions(input_dim));
    }

    std::vector<int64_t> window_dimensions;
    window_dimensions.reserve(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (const int64_t kernel_dim :
         dimension_numbers.kernel_spatial_dimensions()) {
      window_dimensions.push_back(rhs_shape->dimensions(kernel_dim));
    }

    const auto explicit_padding = MakePadding(
        base_area_dimensions, window_dimensions, window_strides, padding);
    return ConvGeneral(lhs, rhs, window_strides, explicit_padding,
                       dimension_numbers, feature_group_count,
                       batch_group_count, precision_config,
                       preferred_element_type);
  });
}
// Convolution with explicit padding and dimension numbers but no dilation.
// It forwards to ConvGeneralDilated with empty lhs and rhs dilation spans.
XlaOp XlaBuilder::ConvGeneral(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  const absl::Span<const int64_t> no_dilation{};
  return ConvGeneralDilated(lhs, rhs, window_strides, padding,
                            /*lhs_dilation=*/no_dilation,
                            /*rhs_dilation=*/no_dilation, dimension_numbers,
                            feature_group_count, batch_group_count,
                            precision_config, preferred_element_type);
}
// Fully general convolution. The arguments are first validated with
// VerifyConvolution. The window is then derived from the kernel's spatial
// extents plus the strides, padding and dilations, and the result shape comes
// from ShapeInference::InferConvolveShape. The instruction itself is emitted
// by ConvGeneralDilatedInternal.
XlaOp XlaBuilder::ConvGeneralDilated(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
    TF_RETURN_IF_ERROR(
        VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));

    std::vector<int64_t> window_dimensions;
    window_dimensions.reserve(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (const int64_t kernel_dim :
         dimension_numbers.kernel_spatial_dimensions()) {
      window_dimensions.push_back(rhs_shape->dimensions(kernel_dim));
    }

    TF_ASSIGN_OR_RETURN(Window conv_window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding,
                            lhs_dilation, rhs_dilation));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferConvolveShape(
                            *lhs_shape, *rhs_shape, feature_group_count,
                            batch_group_count, conv_window, dimension_numbers,
                            preferred_element_type));
    return ConvGeneralDilatedInternal(
        result_shape, lhs, rhs, conv_window, window_strides, padding,
        lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
        batch_group_count, precision_config);
  });
}
// Builds (but does not add) the instruction proto shared by the dynamic
// convolution custom calls. The proto carries the inferred result shape, the
// window, the dimension numbers, the group counts, `padding_type` and the
// optional precision config. Callers set the custom-call target and the
// operands.
// Returns InvalidArgument if the shapes or dimension numbers are invalid.
StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
  TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
  // Validate the dimension numbers before they are used to index rhs_shape
  // below. This matches the validation ConvGeneralDilated performs for
  // static convolutions.
  TF_RETURN_IF_ERROR(
      VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers));
  std::vector<int64_t> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64_t>::size_type i = 0; i < window_dimensions.size();
       ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
                                         window_dimensions, window_strides,
                                         padding, lhs_dilation, rhs_dilation));
  TF_ASSIGN_OR_RETURN(
      Shape shape,
      ShapeInference::InferConvolveShape(
          *lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
          window, dimension_numbers, preferred_element_type));

  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();

  *instr.mutable_window() = window;
  *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
  instr.set_feature_group_count(feature_group_count);
  instr.set_batch_group_count(batch_group_count);
  instr.set_padding_type(padding_type);

  if (precision_config != nullptr) {
    *instr.mutable_precision_config() = *precision_config;
  }
  return std::move(instr);
}
// Emits a kCustomCall with target "DynamicConvolutionInputGrad". It is built
// from the shared dynamic-convolution proto and takes operands
// {input_sizes, lhs, rhs}.
XlaOp XlaBuilder::DynamicConvInputGrad(
    XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto input_grad_proto,
        DynamicConvInstruction(lhs, rhs, window_strides, padding,
                               lhs_dilation, rhs_dilation, dimension_numbers,
                               feature_group_count, batch_group_count,
                               precision_config, padding_type,
                               preferred_element_type));
    constexpr char kTarget[] = "DynamicConvolutionInputGrad";
    input_grad_proto.set_custom_call_target(kTarget);
    return AddInstruction(std::move(input_grad_proto), HloOpcode::kCustomCall,
                          {input_sizes, lhs, rhs});
  });
}
// Emits a kCustomCall with target "DynamicConvolutionKernelGrad" over
// {activations, gradients}. All dynamic-dimension flags on its result shape
// are cleared.
XlaOp XlaBuilder::DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto kernel_grad_proto,
        DynamicConvInstruction(activations, gradients, window_strides, padding,
                               lhs_dilation, rhs_dilation, dimension_numbers,
                               feature_group_count, batch_group_count,
                               precision_config, padding_type,
                               preferred_element_type));
    constexpr char kTarget[] = "DynamicConvolutionKernelGrad";
    kernel_grad_proto.set_custom_call_target(kTarget);
    // The kernel gradient has the kernel's shape, which has no dynamic sizes,
    // so drop any dynamic-dimension markers produced by shape inference.
    kernel_grad_proto.mutable_shape()->clear_is_dynamic_dimension();
    return AddInstruction(std::move(kernel_grad_proto),
                          HloOpcode::kCustomCall, {activations, gradients});
  });
}
// Emits a kCustomCall with target "DynamicConvolutionForward" over {lhs, rhs},
// built from the shared dynamic-convolution proto.
XlaOp XlaBuilder::DynamicConvForward(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto forward_proto,
        DynamicConvInstruction(lhs, rhs, window_strides, padding,
                               lhs_dilation, rhs_dilation, dimension_numbers,
                               feature_group_count, batch_group_count,
                               precision_config, padding_type,
                               preferred_element_type));
    constexpr char kTarget[] = "DynamicConvolutionForward";
    forward_proto.set_custom_call_target(kTarget);
    return AddInstruction(std::move(forward_proto), HloOpcode::kCustomCall,
                          {lhs, rhs});
  });
}
// Emits a kConvolution instruction from an already-inferred `shape` and
// `window`. The instruction records only the window, dimension numbers, group
// counts and, when non-null, the precision config. The stride/padding/
// dilation spans are accepted for interface compatibility but are not read
// here.
StatusOr<XlaOp> XlaBuilder::ConvGeneralDilatedInternal(
    const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config) {
  HloInstructionProto conv_proto;
  *conv_proto.mutable_shape() = shape.ToProto();
  *conv_proto.mutable_window() = window;
  *conv_proto.mutable_convolution_dimension_numbers() = dimension_numbers;
  conv_proto.set_feature_group_count(feature_group_count);
  conv_proto.set_batch_group_count(batch_group_count);
  if (precision_config) {
    *conv_proto.mutable_precision_config() = *precision_config;
  }
  return AddInstruction(std::move(conv_proto), HloOpcode::kConvolution,
                        {lhs, rhs});
}
// Applies an FFT of kind `fft_type` with lengths `fft_length` to `operand`.
// The result shape comes from ShapeInference::InferFftShape.
XlaOp XlaBuilder::Fft(XlaOp operand, const FftType fft_type,
                      const absl::Span<const int64_t> fft_length) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape output_shape,
        ShapeInference::InferFftShape(*input_shape, fft_type, fft_length));
    return FftInternal(output_shape, operand, fft_type, fft_length);
  });
}
// Emits a kFft instruction. `fft_type` is recorded on the instruction, and
// each entry of `fft_length` is appended to its fft_length list in order.
StatusOr<XlaOp> XlaBuilder::FftInternal(
    const Shape& shape, XlaOp operand, const FftType fft_type,
    const absl::Span<const int64_t> fft_length) {
  HloInstructionProto fft_proto;
  *fft_proto.mutable_shape() = shape.ToProto();
  fft_proto.set_fft_type(fft_type);
  for (size_t i = 0; i < fft_length.size(); ++i) {
    fft_proto.add_fft_length(fft_length[i]);
  }
  return AddInstruction(std::move(fft_proto), HloOpcode::kFft, {operand});
}
// Emits a kTriangularSolve instruction over operands {a, b}. `options` is
// moved into the instruction's triangular_solve_options.
StatusOr<XlaOp> XlaBuilder::TriangularSolveInternal(
    const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) {
  HloInstructionProto solve_proto;
  *solve_proto.mutable_shape() = shape.ToProto();
  *solve_proto.mutable_triangular_solve_options() = std::move(options);
  return AddInstruction(std::move(solve_proto), HloOpcode::kTriangularSolve,
                        {a, b});
}
// Emits a kCholesky instruction over `a`. `lower` is recorded in the
// instruction's cholesky_options.
StatusOr<XlaOp> XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a,
                                             bool lower) {
  HloInstructionProto cholesky_proto;
  cholesky_proto.mutable_cholesky_options()->set_lower(lower);
  *cholesky_proto.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(cholesky_proto), HloOpcode::kCholesky, {a});
}
// Emits an infeed of data with shape `shape` and returns an op holding just
// that data.
//
// `shape` must carry a layout. The infeed instruction itself produces a
// (shape, token) tuple. A fresh kAfterAll token is generated as its single
// operand, and a kGetTupleElement at index 0 extracts the data part.
// Array-shaped infeeds with OTHER (tiled) sharding are rejected, as is
// REPLICATED sharding.
XlaOp XlaBuilder::Infeed(const Shape& shape, const std::string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    // The instruction's own shape is the data shape paired with a token.
    const Shape infeed_instruction_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
    *instr.mutable_shape() = infeed_instruction_shape.ToProto();
    instr.set_infeed_config(config);

    if (shape.IsArray() && sharding() &&
        sharding()->type() == OpSharding::OTHER) {
      // TODO(b/110793772): Support tiled array-shaped infeeds.
      return InvalidArgument(
          "Tiled sharding is not yet supported for array-shaped infeeds");
    }

    if (sharding() && sharding()->type() == OpSharding::REPLICATED) {
      return InvalidArgument(
          "Replicated sharding is not yet supported for infeeds");
    }

    // Infeed takes a single token operand. Generate the token to pass to the
    // infeed.
    XlaOp token;
    auto make_token = [&]() {
      HloInstructionProto token_instr;
      *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
      return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
    };
    if (sharding()) {
      // Arbitrarily assign token to device 0.
      // The scoped assignment overrides the builder's sharding only while
      // the token instruction is added in this block.
      OpSharding sharding = sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this, sharding);
      TF_ASSIGN_OR_RETURN(token, make_token());
    } else {
      TF_ASSIGN_OR_RETURN(token, make_token());
    }

    // The sharding is set by the client according to the data tuple shape.
    // However, the shape of the infeed instruction is a tuple containing the
    // data and a token. For tuple sharding type, the sharding must be changed
    // to accommodate the token.
    XlaOp infeed;
    if (sharding() && sharding()->type() == OpSharding::TUPLE) {
      // TODO(b/80000000): Remove this when clients have been updated to handle
      // tokens.
      OpSharding infeed_instruction_sharding = *sharding();
      // Arbitrarily assign the token to device 0.
      // Appending one more tuple sharding gives the extra token element of
      // the (data, token) result its own entry.
      *infeed_instruction_sharding.add_tuple_shardings() =
          sharding_builder::AssignDevice(0);
      XlaScopedShardingAssignment scoped_sharding(this,
                                                  infeed_instruction_sharding);
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    } else {
      TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr),
                                                 HloOpcode::kInfeed, {token}));
    }

    // The infeed instruction produces a tuple of the infed data and a token
    // type. Return XLA op containing the data.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto infeed_data;
    *infeed_data.mutable_shape() = shape.ToProto();
    infeed_data.set_tuple_index(0);
    return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
                          {infeed});
  });
}
// Builds an Infeed that consumes `token` and yields a (shape, token) tuple.
// `shape` must carry a layout. Array shapes may not use tiled (OTHER)
// sharding, and no infeed may use REPLICATED sharding.
XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
                                  const std::string& config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Given shape to Infeed must have a layout");
    }
    if (sharding()) {
      const auto sharding_type = sharding()->type();
      if (sharding_type == OpSharding::OTHER && shape.IsArray()) {
        // TODO(b/110793772): Support tiled array-shaped infeeds.
        return InvalidArgument(
            "Tiled sharding is not yet supported for array-shaped infeeds");
      }
      if (sharding_type == OpSharding::REPLICATED) {
        return InvalidArgument(
            "Replicated sharding is not yet supported for infeeds");
      }
    }
    // The emitted instruction produces the requested data plus a new token.
    const Shape data_and_token_shape =
        ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
    return InfeedWithTokenInternal(data_and_token_shape, token, config);
  });
}
// Emits the kInfeed instruction. Its result shape is the full (data, token)
// tuple, `config` is stored as the infeed config, and `token` is the only
// operand.
StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
    const Shape& infeed_instruction_shape, XlaOp token,
    const std::string& config) {
  HloInstructionProto infeed_proto;
  infeed_proto.set_infeed_config(config);
  *infeed_proto.mutable_shape() = infeed_instruction_shape.ToProto();
  return AddInstruction(std::move(infeed_proto), HloOpcode::kInfeed, {token});
}
// Emits an Outfeed of `operand` using an implicitly created AfterAll token.
// Legacy callers expect a nil (empty tuple) result rather than the outfeed's
// token, so a dummy empty tuple is emitted afterwards. Errors are recorded on
// the builder; nothing is returned.
void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
                         const std::string& outfeed_config) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Validate the requested outfeed shape against the operand's shape.
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*operand_shape, shape_with_layout)) {
      return InvalidArgument(
          "Outfeed shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*operand_shape));
    }

    // Outfeed takes a token as its second operand; create one first.
    HloInstructionProto after_all_proto;
    *after_all_proto.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(
        XlaOp token,
        AddInstruction(std::move(after_all_proto), HloOpcode::kAfterAll, {}));

    HloInstructionProto outfeed_proto;
    *outfeed_proto.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    *outfeed_proto.mutable_outfeed_shape() = shape_with_layout.ToProto();
    outfeed_proto.set_outfeed_config(outfeed_config);
    TF_RETURN_IF_ERROR(AddInstruction(std::move(outfeed_proto),
                                      HloOpcode::kOutfeed, {operand, token})
                           .status());

    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    // The dummy nil tuple is meant to have no sharding, so it is built under a
    // scoped OpSharding() assignment.
    HloInstructionProto nil_proto;
    *nil_proto.mutable_shape() = ShapeUtil::MakeNil().ToProto();
    XlaScopedShardingAssignment no_sharding(this, OpSharding());
    return AddInstruction(std::move(nil_proto), HloOpcode::kTuple, {});
  });
}
// Emits an Outfeed of `operand` ordered after `token`. `shape_with_layout`
// must have a layout and be compatible with the operand's shape.
XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
                                   const Shape& shape_with_layout,
                                   const std::string& outfeed_config) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Given shape to Outfeed must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    const bool shapes_compatible =
        ShapeUtil::Compatible(*operand_shape, shape_with_layout);
    if (shapes_compatible) {
      return OutfeedWithTokenInternal(operand, token, shape_with_layout,
                                      outfeed_config);
    }
    return InvalidArgument(
        "Outfeed shape %s must be compatible with operand shape %s",
        ShapeUtil::HumanStringWithLayout(shape_with_layout),
        ShapeUtil::HumanStringWithLayout(*operand_shape));
  });
}
// Emits the kOutfeed instruction. It produces a token, records
// `shape_with_layout` as the outfeed shape, and takes (operand, token) as
// operands.
StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
    XlaOp operand, XlaOp token, const Shape& shape_with_layout,
    const std::string& outfeed_config) {
  HloInstructionProto outfeed_proto;
  outfeed_proto.set_outfeed_config(outfeed_config);
  *outfeed_proto.mutable_outfeed_shape() = shape_with_layout.ToProto();
  *outfeed_proto.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
  return AddInstruction(std::move(outfeed_proto), HloOpcode::kOutfeed,
                        {operand, token});
}
// Creates a fresh token by emitting an operand-less kAfterAll instruction.
XlaOp XlaBuilder::CreateToken() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto after_all_proto;
    *after_all_proto.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(after_all_proto), HloOpcode::kAfterAll);
  });
}
// Joins one or more tokens into a single token via kAfterAll. Every operand
// must have token shape; an empty list is rejected.
XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (tokens.empty()) {
      return InvalidArgument("AfterAll requires at least one operand");
    }
    int position = 0;
    for (const XlaOp& token : tokens) {
      TF_ASSIGN_OR_RETURN(const Shape* token_shape, GetShapePtr(token));
      if (!token_shape->IsToken()) {
        return InvalidArgument(
            "All operands to AfterAll must be tokens; operand %d has shape %s",
            position, ShapeUtil::HumanString(*token_shape));
      }
      ++position;
    }
    HloInstructionProto after_all_proto;
    *after_all_proto.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    return AddInstruction(std::move(after_all_proto), HloOpcode::kAfterAll,
                          tokens);
  });
}
// Validates a custom call's arguments and forwards to CustomCallInternal.
// Target names starting with '$' are reserved. When
// `operand_shapes_with_layout` is given, the result shape and every listed
// operand shape must have a layout, and there must be one shape per operand.
XlaOp XlaBuilder::CustomCall(
    const std::string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const std::string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, absl::optional<Window> window,
    absl::optional<ConvolutionDimensionNumbers> dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with '$' "
          "are reserved for internal use.",
          call_target_name);
    }
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      const absl::Span<const Shape> constrained_shapes =
          *operand_shapes_with_layout;
      if (operands.size() != constrained_shapes.size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom call "
            "with constrained layout; given %d shapes, expected %d",
            constrained_shapes.size(), operands.size());
      }
      for (int64_t idx = 0, count = constrained_shapes.size(); idx < count;
           ++idx) {
        if (!LayoutUtil::HasLayout(constrained_shapes[idx])) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              idx);
        }
      }
    }
    return CustomCallInternal(call_target_name, operands, shape, opaque,
                              operand_shapes_with_layout, has_side_effect,
                              output_operand_aliasing, literal, window, dnums,
                              schedule, api_version);
  });
}
// Fills in a kCustomCall HloInstructionProto from already-validated arguments
// and adds it with `operands`. Optional parts (layout constraints, literal,
// window, convolution dimension numbers) are written only when provided.
StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
    const std::string& call_target_name, absl::Span<const XlaOp> operands,
    const Shape& shape, const std::string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, absl::optional<Window> window,
    absl::optional<ConvolutionDimensionNumbers> dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  // Bit of a hack: cudnn conv custom-calls are created through this API. Give
  // them a user-friendly name. (This has no effect on correctness, it's just
  // cosmetic.)
  struct CudnnFriendlyName {
    const char* target;
    const char* name;
  };
  static constexpr CudnnFriendlyName kCudnnFriendlyNames[] = {
      {"__cudnn$convForward", "cudnn-conv"},
      {"__cudnn$convBackwardInput", "cudnn-conv-bw-input"},
      {"__cudnn$convBackwardFilter", "cudnn-conv-bw-filter"},
      {"__cudnn$convBiasActivationForward", "cudnn-conv-bias-activation"},
  };
  HloInstructionProto instr;
  for (const CudnnFriendlyName& entry : kCudnnFriendlyNames) {
    if (call_target_name == entry.target) {
      instr.set_name(entry.name);
      break;
    }
  }

  // Scalar attributes of the call.
  *instr.mutable_shape() = shape.ToProto();
  instr.set_custom_call_target(call_target_name);
  instr.set_backend_config(opaque);
  instr.set_custom_call_has_side_effect(has_side_effect);
  instr.set_custom_call_schedule(schedule);
  instr.set_custom_call_api_version(api_version);

  if (operand_shapes_with_layout.has_value()) {
    instr.set_constrain_layout(true);
    for (const Shape& constrained : *operand_shapes_with_layout) {
      *instr.add_operand_shapes_with_layout() = constrained.ToProto();
    }
  }
  if (literal != nullptr) {
    *instr.mutable_literal() = literal->ToProto();
  }
  // Each entry maps an output ShapeIndex to (operand number, operand
  // ShapeIndex).
  for (const auto& alias : output_operand_aliasing) {
    const ShapeIndex& output_index = alias.first;
    const int64_t operand_number = alias.second.first;
    const ShapeIndex& operand_index = alias.second.second;
    auto* aliasing = instr.add_custom_call_output_operand_aliasing();
    aliasing->set_operand_index(operand_number);
    for (int64_t component : operand_index) {
      aliasing->add_operand_shape_index(component);
    }
    for (int64_t component : output_index) {
      aliasing->add_output_shape_index(component);
    }
  }
  if (window.has_value()) {
    *instr.mutable_window() = *window;
  }
  if (dnums.has_value()) {
    *instr.mutable_convolution_dimension_numbers() = *dnums;
  }
  return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
}
// Emits a kCustomCall that carries a called computation (`computation`).
// Target names starting with '$' are reserved. When
// `operand_shapes_with_layout` is given, the result shape and every listed
// operand shape must have a layout, and there must be one shape per operand.
// `has_side_effect` is recorded on the instruction, as in CustomCallInternal;
// previously this overload silently dropped it.
XlaOp XlaBuilder::CustomCall(
    const std::string& call_target_name, absl::Span<const XlaOp> operands,
    const XlaComputation& computation, const Shape& shape,
    const std::string& opaque,
    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
    bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    if (absl::StartsWith(call_target_name, "$")) {
      return InvalidArgument(
          "Invalid custom_call_target \"%s\": Call targets that start with '$' "
          "are reserved for internal use.",
          call_target_name);
    }
    *instr.mutable_shape() = shape.ToProto();
    instr.set_custom_call_target(call_target_name);
    instr.set_backend_config(opaque);
    if (literal != nullptr) {
      *instr.mutable_literal() = literal->ToProto();
    }
    if (operand_shapes_with_layout.has_value()) {
      if (!LayoutUtil::HasLayout(shape)) {
        return InvalidArgument(
            "Result shape must have layout for custom call with constrained "
            "layout.");
      }
      if (operands.size() != operand_shapes_with_layout->size()) {
        return InvalidArgument(
            "Must specify a shape with layout for each operand for custom call "
            "with constrained layout; given %d shapes, expected %d",
            operand_shapes_with_layout->size(), operands.size());
      }
      instr.set_constrain_layout(true);
      int64_t operand_num = 0;
      for (const Shape& operand_shape : *operand_shapes_with_layout) {
        if (!LayoutUtil::HasLayout(operand_shape)) {
          return InvalidArgument(
              "No layout specified for operand %d for custom call with "
              "constrained layout.",
              operand_num);
        }
        *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
        ++operand_num;
      }
    }
    AddCalledComputation(computation, &instr);
    // Each entry maps an output ShapeIndex to (operand number, operand
    // ShapeIndex).
    for (const auto& pair : output_operand_aliasing) {
      auto aliasing = instr.add_custom_call_output_operand_aliasing();
      aliasing->set_operand_index(pair.second.first);
      for (int64_t index : pair.second.second) {
        aliasing->add_operand_shape_index(index);
      }
      for (int64_t index : pair.first) {
        aliasing->add_output_shape_index(index);
      }
    }
    // Fix: this overload previously ignored `has_side_effect`.
    instr.set_custom_call_has_side_effect(has_side_effect);
    instr.set_custom_call_schedule(schedule);
    instr.set_custom_call_api_version(api_version);
    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
  });
}
// Wraps `operand` in a kOptimizationBarrier whose result shape is the
// operand's shape.
XlaOp XlaBuilder::OptimizationBarrier(XlaOp operand) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    HloInstructionProto barrier_proto;
    *barrier_proto.mutable_shape() = operand_shape->ToProto();
    return AddInstruction(std::move(barrier_proto),
                          HloOpcode::kOptimizationBarrier, {operand});
  });
}
// Permutes the dimensions of `operand` according to `permutation`, using
// shape inference to validate the permutation and compute the result shape.
XlaOp XlaBuilder::Transpose(XlaOp operand,
                            absl::Span<const int64_t> permutation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape transposed_shape,
        ShapeInference::InferTransposeShape(*input_shape, permutation));
    return TransposeInternal(transposed_shape, operand, permutation);
  });
}
// Emits kTranspose with the given result `shape`. The permutation is stored,
// in order, as the instruction's dimensions list.
StatusOr<XlaOp> XlaBuilder::TransposeInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64_t> permutation) {
  HloInstructionProto transpose_proto;
  *transpose_proto.mutable_shape() = shape.ToProto();
  for (size_t i = 0; i < permutation.size(); ++i) {
    transpose_proto.add_dimensions(permutation[i]);
  }
  return AddInstruction(std::move(transpose_proto), HloOpcode::kTranspose,
                        {operand});
}
// Reverses `operand` along each of `dimensions`, using shape inference to
// validate them and compute the result shape.
XlaOp XlaBuilder::Rev(XlaOp operand, absl::Span<const int64_t> dimensions) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape reversed_shape,
        ShapeInference::InferReverseShape(*input_shape, dimensions));
    return RevInternal(reversed_shape, operand, dimensions);
  });
}
// Emits kReverse with the given result `shape`. The reversed axes are stored,
// in order, as the instruction's dimensions list.
StatusOr<XlaOp> XlaBuilder::RevInternal(const Shape& shape, XlaOp operand,
                                        absl::Span<const int64_t> dimensions) {
  HloInstructionProto reverse_proto;
  *reverse_proto.mutable_shape() = shape.ToProto();
  for (size_t i = 0; i < dimensions.size(); ++i) {
    reverse_proto.add_dimensions(dimensions[i]);
  }
  return AddInstruction(std::move(reverse_proto), HloOpcode::kReverse,
                        {operand});
}
// Sorts `operands` along `dimension` using `comparator`. The result shape
// comes from variadic shape inference for kSort. A `dimension` of -1 is
// resolved by SortInternal.
XlaOp XlaBuilder::Sort(absl::Span<const XlaOp> operands,
                       const XlaComputation& comparator, int64_t dimension,
                       bool is_stable) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(std::vector<Shape> operand_shapes,
                        GetOperandShapes(operands));
    std::vector<const Shape*> shape_ptrs;
    shape_ptrs.reserve(operand_shapes.size());
    for (const Shape& operand_shape : operand_shapes) {
      shape_ptrs.push_back(&operand_shape);
    }
    TF_ASSIGN_OR_RETURN(
        Shape sorted_shape,
        ShapeInference::InferVariadicOpShape(HloOpcode::kSort, shape_ptrs));
    return SortInternal(sorted_shape, operands, comparator, dimension,
                        is_stable);
  });
}
// Emits kSort with the given result `shape` and `comparator` as the called
// computation. A `dimension` of -1 selects the last dimension of the first
// (keys) operand. Resolving it requires at least one operand, so an empty
// operand list is rejected instead of being indexed out of bounds.
StatusOr<XlaOp> XlaBuilder::SortInternal(const Shape& shape,
                                         absl::Span<const XlaOp> operands,
                                         const XlaComputation& comparator,
                                         int64_t dimension, bool is_stable) {
  if (operands.empty()) {
    return InvalidArgument("Sort requires at least one operand");
  }
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  instr.set_is_stable(is_stable);
  if (dimension == -1) {
    TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0]));
    dimension = keys_shape->rank() - 1;
  }
  instr.add_dimensions(dimension);
  AddCalledComputation(comparator, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kSort, operands);
}
// Converts `operand` element-wise to `new_element_type` (kConvert). The
// result shape comes from ShapeInference::InferConvertShape.
XlaOp XlaBuilder::ConvertElementType(XlaOp operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* source_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape converted_shape,
        ShapeInference::InferConvertShape(*source_shape, new_element_type));
    return AddOpWithShape(HloOpcode::kConvert, converted_shape, {operand});
  });
}
// Reinterprets the bits of `operand` as `new_element_type`. The result shape
// comes from ShapeInference::InferBitcastConvertShape.
XlaOp XlaBuilder::BitcastConvertType(XlaOp operand,
                                     PrimitiveType new_element_type) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* source_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape bitcast_shape,
                        ShapeInference::InferBitcastConvertShape(
                            *source_shape, new_element_type));
    return BitcastConvertTypeInternal(bitcast_shape, operand);
  });
}
// Emits kBitcastConvert of `operand` with the already-inferred result `shape`.
StatusOr<XlaOp> XlaBuilder::BitcastConvertTypeInternal(const Shape& shape,
                                                       XlaOp operand) {
  HloInstructionProto bitcast_proto;
  *bitcast_proto.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(bitcast_proto), HloOpcode::kBitcastConvert,
                        {operand});
}
// Clamps `operand` to [min, max] by emitting the ternary kClamp op with
// operands in (min, operand, max) order.
XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) {
  const HloOpcode clamp_opcode = HloOpcode::kClamp;
  return TernaryOp(clamp_opcode, min, operand, max);
}
// Applies `computation` element-wise across `operands` (kMap). Operands whose
// rank or dimensions differ from the inferred output shape are broadcast to
// it first. Non-empty `static_operands` are not supported.
XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
                      const XlaComputation& computation,
                      absl::Span<const int64_t> dimensions,
                      absl::Span<const XlaOp> static_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!static_operands.empty()) {
      return Unimplemented("static_operands is not supported in Map");
    }
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    std::vector<const Shape*> shape_ptrs;
    shape_ptrs.reserve(operand_shapes.size());
    for (const Shape& operand_shape : operand_shapes) {
      shape_ptrs.push_back(&operand_shape);
    }
    TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferMapShape(
                            shape_ptrs, called_program_shape, dimensions));

    HloInstructionProto map_proto;
    *map_proto.mutable_shape() = inferred_shape.ToProto();
    // Operands are broadcast against the shape as recorded in the proto.
    Shape output_shape(map_proto.shape());
    const int64_t output_rank = output_shape.rank();
    AddCalledComputation(computation, &map_proto);

    std::vector<XlaOp> broadcast_operands(operands.begin(), operands.end());
    for (XlaOp& current : broadcast_operands) {
      TF_ASSIGN_OR_RETURN(const Shape* current_shape, GetShapePtr(current));
      if (current_shape->rank() != output_rank) {
        TF_ASSIGN_OR_RETURN(current,
                            InDimBroadcast(output_shape, current, {}));
        TF_ASSIGN_OR_RETURN(current_shape, GetShapePtr(current));
      }
      if (!ShapeUtil::SameDimensions(output_shape, *current_shape)) {
        TF_ASSIGN_OR_RETURN(current,
                            AddBroadcastSequence(output_shape, current));
      }
    }
    return AddInstruction(std::move(map_proto), HloOpcode::kMap,
                          broadcast_operands);
  });
}
// Validates the parameter count for `distribution` and emits a kRng op with
// result `shape`. RNG_NORMAL and RNG_UNIFORM each take exactly two
// parameters. Any other distribution is reported as an InvalidArgument error
// on the builder rather than aborting the process with LOG(FATAL).
XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
                        absl::Span<const XlaOp> parameters,
                        const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Check the number of parameters per RNG distribution.
    switch (distribution) {
      case RandomDistribution::RNG_NORMAL:
      case RandomDistribution::RNG_UNIFORM:
        if (parameters.size() != 2) {
          return InvalidArgument(
              "RNG distribution (%s) expects 2 parameters, but got %ld",
              RandomDistribution_Name(distribution), parameters.size());
        }
        break;
      default:
        // An unsupported distribution is a caller error; surface it through
        // the builder's error status instead of crashing.
        return InvalidArgument("Unhandled RNG distribution %d (%s)",
                               static_cast<int>(distribution),
                               RandomDistribution_Name(distribution));
    }
    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    return RngOpInternal(distribution, parameters, shape);
  });
}
// Emits kRng with the requested distribution, result shape and parameters.
StatusOr<XlaOp> XlaBuilder::RngOpInternal(RandomDistribution distribution,
                                          absl::Span<const XlaOp> parameters,
                                          const Shape& shape) {
  HloInstructionProto rng_proto;
  rng_proto.set_distribution(distribution);
  *rng_proto.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(rng_proto), HloOpcode::kRng, parameters);
}
// Normal-distribution RNG; `mu` and `sigma` are passed as the two parameters.
XlaOp XlaBuilder::RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape) {
  const XlaOp normal_params[] = {mu, sigma};
  return RngOp(RandomDistribution::RNG_NORMAL, normal_params, shape);
}
// Uniform-distribution RNG; `a` and `b` are passed as the two parameters.
XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
  const XlaOp uniform_params[] = {a, b};
  return RngOp(RandomDistribution::RNG_UNIFORM, uniform_params, shape);
}
// Emits kRngBitGenerator returning (new_state, bits). The bits use `shape`'s
// dimensions with an unsigned element type of matching width: 32-bit element
// types map to U32 and 64-bit ones to U64. Other element types are rejected.
XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
                                  XlaOp initial_state, const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state));
    PrimitiveType bits_type = PrimitiveType::PRIMITIVE_TYPE_INVALID;
    switch (shape.element_type()) {
      case PrimitiveType::F32:
      case PrimitiveType::S32:
      case PrimitiveType::U32:
        bits_type = PrimitiveType::U32;
        break;
      case PrimitiveType::F64:
      case PrimitiveType::S64:
      case PrimitiveType::U64:
        bits_type = PrimitiveType::U64;
        break;
      default:
        return InvalidArgument("Unsupported shape for RngBitGenerator: %s",
                               PrimitiveType_Name(shape.element_type()));
    }
    Shape bits_shape = shape;
    bits_shape.set_element_type(bits_type);
    return RngBitGeneratorInternal(
        ShapeUtil::MakeTupleShapeWithPtrs({&state_shape, &bits_shape}),
        algorithm, initial_state);
  });
}
// Emits kRngBitGenerator with the full (state, bits) tuple shape and the
// chosen algorithm, consuming `initial_state`.
StatusOr<XlaOp> XlaBuilder::RngBitGeneratorInternal(
    const Shape& full_result_shape, RandomAlgorithm algorithm,
    XlaOp initial_state) {
  HloInstructionProto generator_proto;
  generator_proto.set_rng_algorithm(algorithm);
  *generator_proto.mutable_shape() = full_result_shape.ToProto();
  return AddInstruction(std::move(generator_proto),
                        HloOpcode::kRngBitGenerator, {initial_state});
}
// Builds a kWhile loop. The result shape comes from
// ShapeInference::InferWhileShape over the condition, body and init shapes.
XlaOp XlaBuilder::While(const XlaComputation& condition,
                        const XlaComputation& body, XlaOp init) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const ProgramShape& body_shape,
                        body.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const ProgramShape& cond_shape,
                        condition.GetProgramShape());
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init));
    TF_ASSIGN_OR_RETURN(
        Shape loop_shape,
        ShapeInference::InferWhileShape(cond_shape, body_shape, *init_shape));
    return WhileInternal(loop_shape, condition, body, init);
  });
}
// Emits kWhile. The called computations are registered body first, then
// condition, as the called-computation list requires.
StatusOr<XlaOp> XlaBuilder::WhileInternal(const Shape& shape,
                                          const XlaComputation& condition,
                                          const XlaComputation& body,
                                          XlaOp init) {
  HloInstructionProto while_proto;
  *while_proto.mutable_shape() = shape.ToProto();
  AddCalledComputation(body, &while_proto);
  AddCalledComputation(condition, &while_proto);
  return AddInstruction(std::move(while_proto), HloOpcode::kWhile, {init});
}
// Builds a kGather of `input` at `start_indices`. The result shape comes from
// ShapeInference::InferGatherShape.
XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices,
                         const GatherDimensionNumbers& dimension_numbers,
                         absl::Span<const int64_t> slice_sizes,
                         bool indices_are_sorted) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(input));
    TF_ASSIGN_OR_RETURN(const Shape* indices_shape,
                        GetShapePtr(start_indices));
    TF_ASSIGN_OR_RETURN(
        Shape gathered_shape,
        ShapeInference::InferGatherShape(*operand_shape, *indices_shape,
                                         dimension_numbers, slice_sizes));
    return GatherInternal(gathered_shape, input, start_indices,
                          dimension_numbers, slice_sizes, indices_are_sorted);
  });
}
// Emits kGather with operands (input, start_indices). Slice sizes are
// recorded in order.
StatusOr<XlaOp> XlaBuilder::GatherInternal(
    const Shape& shape, XlaOp input, XlaOp start_indices,
    const GatherDimensionNumbers& dimension_numbers,
    absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
  HloInstructionProto gather_proto;
  *gather_proto.mutable_shape() = shape.ToProto();
  *gather_proto.mutable_gather_dimension_numbers() = dimension_numbers;
  gather_proto.set_indices_are_sorted(indices_are_sorted);
  for (size_t i = 0; i < slice_sizes.size(); ++i) {
    gather_proto.add_gather_slice_sizes(slice_sizes[i]);
  }
  return AddInstruction(std::move(gather_proto), HloOpcode::kGather,
                        {input, start_indices});
}
// Single-input convenience overload: wraps `input` and `updates` in one-element
// spans and forwards to the variadic Scatter.
XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                          const XlaComputation& update_computation,
                          const ScatterDimensionNumbers& dimension_numbers,
                          bool indices_are_sorted, bool unique_indices) {
  const absl::Span<const XlaOp> single_input(&input, 1);
  const absl::Span<const XlaOp> single_update(&updates, 1);
  return Scatter(single_input, scatter_indices, single_update,
                 update_computation, dimension_numbers, indices_are_sorted,
                 unique_indices);
}
// Variadic scatter. `inputs` must be non-empty and paired one-to-one with
// `updates`. Shape inference sees operand shapes in the order
// [inputs..., scatter_indices, updates...].
XlaOp XlaBuilder::Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
                          absl::Span<const XlaOp> updates,
                          const XlaComputation& update_computation,
                          const ScatterDimensionNumbers& dimension_numbers,
                          bool indices_are_sorted, bool unique_indices) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (inputs.empty()) {
      return InvalidArgument("Scatter inputs cannot be empty.");
    }
    if (inputs.size() != updates.size()) {
      return InvalidArgument(
          "Scatter should have same number of inputs and updates: %d vs %d.",
          inputs.size(), updates.size());
    }
    absl::InlinedVector<const Shape*, 3> all_shapes;
    all_shapes.reserve(inputs.size() + 1 + updates.size());
    for (const XlaOp& input_op : inputs) {
      TF_ASSIGN_OR_RETURN(const Shape* input_op_shape, GetShapePtr(input_op));
      all_shapes.push_back(input_op_shape);
    }
    TF_ASSIGN_OR_RETURN(const Shape* indices_shape,
                        GetShapePtr(scatter_indices));
    all_shapes.push_back(indices_shape);
    for (const XlaOp& update_op : updates) {
      TF_ASSIGN_OR_RETURN(const Shape* update_op_shape, GetShapePtr(update_op));
      all_shapes.push_back(update_op_shape);
    }
    TF_ASSIGN_OR_RETURN(const ProgramShape& update_program_shape,
                        update_computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        Shape scattered_shape,
        ShapeInference::InferScatterShape(all_shapes, update_program_shape,
                                          dimension_numbers));
    return ScatterInternal(scattered_shape, inputs, scatter_indices, updates,
                           update_computation, dimension_numbers,
                           indices_are_sorted, unique_indices);
  });
}
// Emits kScatter with result `shape`. Operands are laid out as
// [inputs..., scatter_indices, updates...] and `update_computation` is the
// called computation. Errors propagate through the returned StatusOr, like the
// other *Internal helpers. A nested ReportErrorOrReturn would record the error
// on the builder and hand callers an OK StatusOr holding an invalid XlaOp.
StatusOr<XlaOp> XlaBuilder::ScatterInternal(
    const Shape& shape, absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
    absl::Span<const XlaOp> updates, const XlaComputation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
    bool unique_indices) {
  HloInstructionProto instr;
  instr.set_indices_are_sorted(indices_are_sorted);
  instr.set_unique_indices(unique_indices);
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_scatter_dimension_numbers() = dimension_numbers;
  AddCalledComputation(update_computation, &instr);
  absl::InlinedVector<XlaOp, 3> operands;
  operands.reserve(inputs.size() + 1 + updates.size());
  absl::c_copy(inputs, std::back_inserter(operands));
  operands.push_back(scatter_indices);
  absl::c_copy(updates, std::back_inserter(operands));
  return AddInstruction(std::move(instr), HloOpcode::kScatter, operands);
}
// Two-way conditional selected by a scalar PRED `predicate`. Branch 0 applies
// `true_computation` to `true_operand`; branch 1 applies `false_computation`
// to `false_operand`.
XlaOp XlaBuilder::Conditional(XlaOp predicate, XlaOp true_operand,
                              const XlaComputation& true_computation,
                              XlaOp false_operand,
                              const XlaComputation& false_computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* predicate_shape,
                        GetShapePtr(predicate));
    const bool is_scalar_pred = ShapeUtil::IsScalar(*predicate_shape) &&
                                predicate_shape->element_type() == PRED;
    if (!is_scalar_pred) {
      return InvalidArgument(
          "Argument to predicated-Conditional is not a scalar of PRED type "
          "(%s).",
          ShapeUtil::HumanString(*predicate_shape));
    }
    // The index of true_computation must be 0 and that of false computation
    // must be 1.
    const XlaComputation* const branches[] = {&true_computation,
                                              &false_computation};
    const XlaOp branch_args[] = {true_operand, false_operand};
    return ConditionalImpl(predicate, branches, branch_args);
  });
}
// N-way conditional selected by a scalar S32 `branch_index`. Each entry of
// `branch_computations` is paired with the same-position entry of
// `branch_operands`.
XlaOp XlaBuilder::Conditional(
    XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const xla::Shape* index_shape,
                        GetShapePtr(branch_index));
    const bool is_scalar_s32 = ShapeUtil::IsScalar(*index_shape) &&
                               index_shape->element_type() == S32;
    if (!is_scalar_s32) {
      return InvalidArgument(
          "Argument to indexed-Conditional is not a scalar of S32 type (%s).",
          ShapeUtil::HumanString(*index_shape));
    }
    return ConditionalImpl(branch_index, branch_computations, branch_operands);
  });
}
// Shared implementation of both Conditional overloads. It infers the result
// shape from the branch index, the branch computations and their operands,
// registers each branch computation in order, and emits kConditional with
// operands [branch_index, branch_operands...]. Branch computations and branch
// operands must be paired one-to-one. The count is checked before the
// per-branch loop so a mismatch cannot index out of bounds.
XlaOp XlaBuilder::ConditionalImpl(
    XlaOp branch_index,
    absl::Span<const XlaComputation* const> branch_computations,
    absl::Span<const XlaOp> branch_operands) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* branch_index_shape,
                        GetShapePtr(branch_index));
    if (branch_computations.size() != branch_operands.size()) {
      return InvalidArgument(
          "Conditional requires one operand per branch computation; got %d "
          "branch computations and %d branch operands",
          branch_computations.size(), branch_operands.size());
    }
    std::vector<Shape> branch_operand_shapes(branch_operands.size());
    std::vector<ProgramShape> branch_computation_shapes(
        branch_computations.size());
    for (int j = 0, end = branch_operands.size(); j < end; ++j) {
      TF_ASSIGN_OR_RETURN(branch_operand_shapes[j],
                          GetShape(branch_operands[j]));
      TF_ASSIGN_OR_RETURN(branch_computation_shapes[j],
                          branch_computations[j]->GetProgramShape());
    }
    TF_ASSIGN_OR_RETURN(const Shape shape,
                        ShapeInference::InferConditionalShape(
                            *branch_index_shape, branch_computation_shapes,
                            branch_operand_shapes));
    *instr.mutable_shape() = shape.ToProto();
    for (const XlaComputation* branch_computation : branch_computations) {
      AddCalledComputation(*branch_computation, &instr);
    }
    std::vector<XlaOp> operands(1, branch_index);
    for (const XlaOp branch_operand : branch_operands) {
      operands.emplace_back(branch_operand);
    }
    return AddInstruction(std::move(instr), HloOpcode::kConditional,
                          absl::MakeSpan(operands));
  });
}
// Returns OK iff `op` was created by this builder. An uninitialized XlaOp
// (one with no owning builder) is reported as InvalidArgument up front,
// before `op.builder()` is queried. NOTE(review): querying builder() on such
// an op is assumed to be unusable (CHECK-fail or null dereference).
Status XlaBuilder::CheckOpBuilder(XlaOp op) const {
  if (op.IsUninitialized()) {
    return InvalidArgument("invalid XlaOp with handle %d", op.handle());
  }
  if (this != op.builder()) {
    return InvalidArgument(
        "XlaOp with handle %d is built by builder '%s', but is trying to use "
        "it in builder '%s'",
        op.handle(), op.builder()->name(), name());
  }
  return ::tensorflow::OkStatus();
}
// Single-operand convenience overload: forwards one operand and one init
// value to the variadic Reduce.
XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
                         const XlaComputation& computation,
                         absl::Span<const int64_t> dimensions_to_reduce) {
  const XlaOp operand_list[] = {operand};
  const XlaOp init_list[] = {init_value};
  return Reduce(absl::Span<const XlaOp>(operand_list),
                absl::Span<const XlaOp>(init_list), computation,
                dimensions_to_reduce);
}
// Variadic reduce. The HLO operand list is [operands..., init_values...]. The
// result shape comes from ShapeInference::InferReduceShape.
XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
                         absl::Span<const XlaOp> init_values,
                         const XlaComputation& computation,
                         absl::Span<const int64_t> dimensions_to_reduce) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const ProgramShape& reducer_shape,
                        computation.GetProgramShape());
    std::vector<XlaOp> combined(operands.begin(), operands.end());
    combined.insert(combined.end(), init_values.begin(), init_values.end());
    TF_ASSIGN_OR_RETURN(const auto& combined_shapes,
                        GetOperandShapes(combined));
    std::vector<const Shape*> shape_ptrs;
    shape_ptrs.reserve(combined_shapes.size());
    for (const Shape& combined_shape : combined_shapes) {
      shape_ptrs.push_back(&combined_shape);
    }
    TF_ASSIGN_OR_RETURN(Shape reduced_shape,
                        ShapeInference::InferReduceShape(
                            shape_ptrs, dimensions_to_reduce, reducer_shape));
    return ReduceInternal(reduced_shape, combined, computation,
                          dimensions_to_reduce);
  });
}
// Emits kReduce with result `shape` over `all_operands` (operands followed by
// init values), recording the reduced dimensions and the reducer computation.
// Errors propagate through the returned StatusOr, like the other *Internal
// helpers. A nested ReportErrorOrReturn would record the error on the builder
// and return an OK StatusOr holding an invalid XlaOp.
StatusOr<XlaOp> XlaBuilder::ReduceInternal(
    const Shape& shape, absl::Span<const XlaOp> all_operands,
    const XlaComputation& computation,
    absl::Span<const int64_t> dimensions_to_reduce) {
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  for (int64_t dim : dimensions_to_reduce) {
    instr.add_dimensions(dim);
  }
  AddCalledComputation(computation, &instr);
  return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands);
}
// Reduces `operand` over every dimension, 0 through rank-1.
XlaOp XlaBuilder::ReduceAll(XlaOp operand, XlaOp init_value,
                            const XlaComputation& computation) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    const int64_t rank = operand_shape->rank();
    std::vector<int64_t> every_dimension;
    every_dimension.reserve(rank);
    for (int64_t dim = 0; dim < rank; ++dim) {
      every_dimension.push_back(dim);
    }
    return Reduce(operand, init_value, computation, every_dimension);
  });
}
// Single-operand convenience overload: forwards to the variadic ReduceWindow
// with one-element operand and init-value spans.
XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
                               const XlaComputation& computation,
                               absl::Span<const int64_t> window_dimensions,
                               absl::Span<const int64_t> window_strides,
                               Padding padding) {
  const absl::Span<const XlaOp> operand_span = absl::MakeConstSpan(&operand, 1);
  const absl::Span<const XlaOp> init_span = absl::MakeConstSpan(&init_value, 1);
  return ReduceWindow(operand_span, init_span, computation, window_dimensions,
                      window_strides, padding);
}
// Variadic reduce-window with Padding-style padding. Padding values are
// validated for every operand and computed from the last operand's dimensions.
// If any dynamic dimension has a non-trivial window under kSame padding, the
// op is emitted as a "DynamicReduceWindowSamePadding" custom call so the
// padding can be resolved later. Otherwise it forwards to
// ReduceWindowWithGeneralPadding. An empty operand list is reported as
// InvalidArgument instead of CHECK-failing.
XlaOp XlaBuilder::ReduceWindow(absl::Span<const XlaOp> operands,
                               absl::Span<const XlaOp> init_values,
                               const XlaComputation& computation,
                               absl::Span<const int64_t> window_dimensions,
                               absl::Span<const int64_t> window_strides,
                               Padding padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (operands.empty()) {
      return InvalidArgument("ReduceWindow requires at least one operand");
    }
    const Shape* operand_shape = nullptr;
    for (const auto& operand : operands) {
      TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand));
      TF_RETURN_IF_ERROR(ValidatePaddingValues(
          operand_shape->dimensions(), window_dimensions, window_strides));
    }
    // NOTE(review): padding uses the last operand's dimensions; this presumes
    // all operands share dimensions (assumed to be enforced by shape
    // inference, TODO confirm).
    std::vector<std::pair<int64_t, int64_t>> padding_values =
        MakePadding(operand_shape->dimensions(), window_dimensions,
                    window_strides, padding);
    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides, padding_values,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));
    PaddingType padding_type = PADDING_INVALID;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      if (operand_shape->is_dynamic_dimension(i) &&
          !window_util::IsTrivialWindowDimension(window.dimensions(i)) &&
          padding == Padding::kSame) {
        // SAME padding can create dynamic padding sizes. The padding size
        // need to be rewritten by dynamic padder using HloInstructions. We
        // create a CustomCall to handle this.
        padding_type = PADDING_SAME;
      }
    }
    if (padding_type == PADDING_SAME) {
      TF_ASSIGN_OR_RETURN(
          HloInstructionProto instr,
          ReduceWindowInternal(operands, init_values, computation,
                               window_dimensions, window_strides, {}, {},
                               padding_values));
      instr.set_custom_call_target("DynamicReduceWindowSamePadding");
      std::vector<XlaOp> args;
      args.insert(args.end(), operands.begin(), operands.end());
      args.insert(args.end(), init_values.begin(), init_values.end());
      return AddInstruction(std::move(instr), HloOpcode::kCustomCall, args);
    }
    return ReduceWindowWithGeneralPadding(
        operands, init_values, computation, window_dimensions, window_strides,
        /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
  });
}
// Reduce-window with explicit (low, high) padding and optional base/window
// dilations. Each operand must have exactly one init value. The count is
// checked first, because the single-operand path reads init_values[0]. With
// one operand the result shape is inferred here and the shape-taking
// ReduceWindowInternal overload emits the op. Otherwise the proto-building
// overload is used and the op's operands are [operands..., init_values...].
XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const int64_t> base_dilations,
    absl::Span<const int64_t> window_dilations,
    absl::Span<const std::pair<int64_t, int64_t>> padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (operands.size() != init_values.size()) {
      return InvalidArgument(
          "ReduceWindow requires one init value per operand; got %d operands "
          "and %d init values",
          operands.size(), init_values.size());
    }
    if (operands.size() == 1) {
      const auto& operand = operands[0];
      const auto& init_value = init_values[0];
      std::vector<const Shape*> operand_shapes;
      std::vector<const Shape*> init_shapes;
      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
      operand_shapes.push_back(operand_shape);
      TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
      init_shapes.push_back(init_shape);
      TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                          computation.GetProgramShape());
      TF_ASSIGN_OR_RETURN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              window_dimensions, window_strides, padding,
                              /*lhs_dilation=*/base_dilations,
                              /*rhs_dilation=*/window_dilations));
      TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
                                           absl::MakeSpan(operand_shapes),
                                           absl::MakeSpan(init_shapes), window,
                                           to_apply_shape));
      return ReduceWindowInternal(shape, operands[0], init_values[0],
                                  computation, window);
    }
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto instr,
        ReduceWindowInternal(operands, init_values, computation,
                             window_dimensions, window_strides, base_dilations,
                             window_dilations, padding));
    std::vector<XlaOp> args;
    args.insert(args.end(), operands.begin(), operands.end());
    args.insert(args.end(), init_values.begin(), init_values.end());
    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, args);
  });
}
// Builds, but does not add to the builder, the HloInstructionProto for a
// variadic reduce-window.
//
// The window is inferred from the dimensions, strides, padding and
// dilations. The result shape is inferred from every operand and init value
// shape. `computation` is attached as the reducer.
//
// Returns InvalidArgument when `operands` and `init_values` differ in
// length, since each operand is paired with the init value at the same index.
StatusOr<HloInstructionProto> XlaBuilder::ReduceWindowInternal(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const int64_t> base_dilations,
    absl::Span<const int64_t> window_dilations,
    absl::Span<const std::pair<int64_t, int64_t>> padding) {
  // init_values is indexed by operand position below; guard against reading
  // past its end (or silently ignoring extra init values).
  if (operands.size() != init_values.size()) {
    return InvalidArgument(
        "ReduceWindow expects one init value per operand, but got %d "
        "operands and %d init values",
        operands.size(), init_values.size());
  }
  std::vector<const Shape*> operand_shapes, init_shapes;
  operand_shapes.reserve(operands.size());
  init_shapes.reserve(init_values.size());
  for (size_t i = 0; i < operands.size(); ++i) {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operands[i]));
    operand_shapes.push_back(operand_shape);
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_values[i]));
    init_shapes.push_back(init_shape);
  }
  TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                      computation.GetProgramShape());
  TF_ASSIGN_OR_RETURN(auto window,
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/base_dilations,
                          /*rhs_dilation=*/window_dilations));
  TF_ASSIGN_OR_RETURN(Shape shape,
                      ShapeInference::InferReduceWindowShape(
                          absl::MakeSpan(operand_shapes),
                          absl::MakeSpan(init_shapes), window, to_apply_shape));
  HloInstructionProto instr;
  *instr.mutable_shape() = shape.ToProto();
  *instr.mutable_window() = std::move(window);
  AddCalledComputation(computation, &instr);
  return instr;
}
// Adds a single-operand kReduceWindow whose result `shape` and `window` were
// already inferred by the caller; `computation` is attached as the reducer.
StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
    const Shape& shape, XlaOp operand, XlaOp init_value,
    const XlaComputation& computation, Window window) {
  HloInstructionProto reduce_window_proto;
  *reduce_window_proto.mutable_shape() = shape.ToProto();
  *reduce_window_proto.mutable_window() = std::move(window);
  AddCalledComputation(computation, &reduce_window_proto);
  return AddInstruction(std::move(reduce_window_proto),
                        HloOpcode::kReduceWindow, {operand, init_value});
}
// Adds a kBatchNormTraining over (operand, scale, offset). The result shape
// comes from ShapeInference; epsilon and feature_index are recorded on the
// instruction.
XlaOp XlaBuilder::BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                                    float epsilon, int64_t feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* gamma_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* beta_shape, GetShapePtr(offset));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferBatchNormTrainingShape(
                            *input_shape, *gamma_shape, *beta_shape,
                            feature_index));

    HloInstructionProto proto;
    *proto.mutable_shape() = result_shape.ToProto();
    proto.set_epsilon(epsilon);
    proto.set_feature_index(feature_index);
    return AddInstruction(std::move(proto), HloOpcode::kBatchNormTraining,
                          {operand, scale, offset});
  });
}
// Adds a kBatchNormInference over (operand, scale, offset, mean, variance).
// The result shape comes from ShapeInference; epsilon and feature_index are
// recorded on the instruction.
XlaOp XlaBuilder::BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
                                     XlaOp mean, XlaOp variance, float epsilon,
                                     int64_t feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* gamma_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* beta_shape, GetShapePtr(offset));
    TF_ASSIGN_OR_RETURN(const Shape* mu_shape, GetShapePtr(mean));
    TF_ASSIGN_OR_RETURN(const Shape* var_shape, GetShapePtr(variance));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferBatchNormInferenceShape(
                            *input_shape, *gamma_shape, *beta_shape,
                            *mu_shape, *var_shape, feature_index));

    HloInstructionProto proto;
    *proto.mutable_shape() = result_shape.ToProto();
    proto.set_epsilon(epsilon);
    proto.set_feature_index(feature_index);
    return AddInstruction(std::move(proto), HloOpcode::kBatchNormInference,
                          {operand, scale, offset, mean, variance});
  });
}
// Adds a kBatchNormGrad over (operand, scale, batch_mean, batch_var,
// grad_output). The result shape comes from ShapeInference; epsilon and
// feature_index are recorded on the instruction.
XlaOp XlaBuilder::BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                                XlaOp batch_var, XlaOp grad_output,
                                float epsilon, int64_t feature_index) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* gamma_shape, GetShapePtr(scale));
    TF_ASSIGN_OR_RETURN(const Shape* mu_shape, GetShapePtr(batch_mean));
    TF_ASSIGN_OR_RETURN(const Shape* var_shape, GetShapePtr(batch_var));
    TF_ASSIGN_OR_RETURN(const Shape* grad_shape, GetShapePtr(grad_output));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferBatchNormGradShape(
                            *input_shape, *gamma_shape, *mu_shape, *var_shape,
                            *grad_shape, feature_index));

    HloInstructionProto proto;
    *proto.mutable_shape() = result_shape.ToProto();
    proto.set_epsilon(epsilon);
    proto.set_feature_index(feature_index);
    return AddInstruction(std::move(proto), HloOpcode::kBatchNormGrad,
                          {operand, scale, batch_mean, batch_var, grad_output});
  });
}
// Adds a kAllGather of `operand` along `all_gather_dimension`.
//
// The result shape is inferred from the operand shape and `shard_count`. An
// explicit `layout` overrides the inferred layout and marks the layout as
// constrained. Replica groups, the channel id, and use_global_device_ids are
// copied onto the instruction when provided.
XlaOp XlaBuilder::AllGather(XlaOp operand, int64_t all_gather_dimension,
                            int64_t shard_count,
                            absl::Span<const ReplicaGroup> replica_groups,
                            const absl::optional<ChannelHandle>& channel_id,
                            const absl::optional<Layout>& layout,
                            const absl::optional<bool> use_global_device_ids) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferAllGatherShape(
                            {operand_shape}, all_gather_dimension,
                            shard_count));

    HloInstructionProto proto;
    if (layout.has_value()) {
      *result_shape.mutable_layout() = layout.value();
      proto.set_constrain_layout(true);
    }
    *proto.mutable_shape() = result_shape.ToProto();
    proto.add_dimensions(all_gather_dimension);
    for (const ReplicaGroup& replica_group : replica_groups) {
      *proto.add_replica_groups() = replica_group;
    }
    if (channel_id.has_value()) {
      proto.set_channel_id(channel_id->handle());
    }
    if (use_global_device_ids.has_value()) {
      proto.set_use_global_device_ids(use_global_device_ids.value());
    }
    return AddInstruction(std::move(proto), HloOpcode::kAllGather, {operand});
  });
}
// Sums `operand` across replicas by building a scalar (x, y) reducer and
// delegating to AllReduce.
//
// For a tuple operand, the reducer's scalar type is taken from the first
// tuple element; empty tuples are rejected as Unimplemented. PRED values are
// combined with Or rather than Add.
XlaOp XlaBuilder::CrossReplicaSum(
    XlaOp operand, absl::Span<const ReplicaGroup> replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand));
    if (shape->IsTuple() && shape->tuple_shapes_size() == 0) {
      return Unimplemented("0 element tuple CrossReplicaSum is not supported");
    }
    const Shape& element_shape =
        shape->IsTuple() ? shape->tuple_shapes(0) : *shape;
    const Shape scalar_shape =
        ShapeUtil::MakeShape(element_shape.element_type(), {});

    // Reducer computation: (x, y) -> x + y, or x | y for predicates.
    auto reducer_builder = CreateSubBuilder("sum");
    XlaOp lhs =
        reducer_builder->Parameter(/*parameter_number=*/0, scalar_shape, "x");
    XlaOp rhs =
        reducer_builder->Parameter(/*parameter_number=*/1, scalar_shape, "y");
    const bool is_pred = scalar_shape.element_type() == PRED;
    if (is_pred) {
      Or(lhs, rhs);
    } else {
      Add(lhs, rhs);
    }
    TF_ASSIGN_OR_RETURN(auto reducer, reducer_builder->Build());
    return AllReduce(operand, reducer, replica_groups,
                     /*channel_id=*/absl::nullopt);
  });
}
// Adds a kAllReduce of `operand` using `computation` as the reduction
// function, over the given replica groups and optional channel.
//
// A tuple operand is flattened: each element is extracted with
// GetTupleElement and passed as a separate instruction operand. All elements
// must share the first element's element type, and an empty tuple is
// rejected as Unimplemented.
//
// If `shape_with_layout` is given, it must have a layout and be compatible
// with the operand shape. It then becomes the instruction's shape (its single
// element for a one-element tuple), and the layout is marked as constrained.
// Otherwise the inferred shape is used.
//
// When a tuple operand yields a non-tuple inferred shape (the single-element
// case), the result is re-wrapped into a one-element tuple.
XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
                            absl::Span<const ReplicaGroup> replica_groups,
                            const absl::optional<ChannelHandle>& channel_id,
                            const absl::optional<Shape>& shape_with_layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    // Flattened view of the operand: one entry per tuple element, or the
    // operand itself when it is not a tuple.
    std::vector<const Shape*> operand_shapes;
    std::vector<XlaOp> operands;
    if (operand_shape->IsTuple()) {
      if (operand_shape->tuple_shapes_size() == 0) {
        return Unimplemented("0 element tuple AllReduce is not supported");
      }
      for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
        // A single reducer computation is applied to every element, so all
        // elements must share one element type.
        if (operand_shape->tuple_shapes(i).element_type() !=
            operand_shape->tuple_shapes(0).element_type()) {
          return Unimplemented(
              "All the shapes of a tuple input of AllReduce must have the same "
              "element type");
        }
        operand_shapes.push_back(&operand_shape->tuple_shapes(i));
        // Note: this adds a GetTupleElement instruction to the builder.
        operands.push_back(GetTupleElement(operand, i));
      }
    } else {
      operand_shapes.push_back(operand_shape);
      operands.push_back(operand);
    }
    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferAllReduceShape(operand_shapes));
    if (shape_with_layout) {
      // A caller-supplied shape must carry a layout and match the operand
      // shape (ignoring layout).
      if (!LayoutUtil::HasLayout(*shape_with_layout)) {
        return InvalidArgument("shape_with_layout must have the layout set: %s",
                               shape_with_layout->ToString());
      }
      if (!ShapeUtil::Compatible(*shape_with_layout, *operand_shape)) {
        return InvalidArgument(
            "Provided shape_with_layout must be compatible with the "
            "operand shape: %s vs %s",
            shape_with_layout->ToString(), operand_shape->ToString());
      }
      instr.set_constrain_layout(true);
      if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
        // For a single-element tuple, take the tuple element shape.
        TF_RET_CHECK(shape_with_layout->tuple_shapes_size() == 1);
        *instr.mutable_shape() = shape_with_layout->tuple_shapes(0).ToProto();
      } else {
        *instr.mutable_shape() = shape_with_layout->ToProto();
      }
    } else {
      *instr.mutable_shape() = inferred_shape.ToProto();
    }

    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }

    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }

    // Attach the reduction function to the instruction.
    AddCalledComputation(computation, &instr);

    TF_ASSIGN_OR_RETURN(
        auto all_reduce,
        AddInstruction(std::move(instr), HloOpcode::kAllReduce, operands));
    if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
      // For a single-element tuple, wrap the result into a tuple.
      TF_RET_CHECK(operand_shapes.size() == 1);
      TF_RET_CHECK(ShapeUtil::Compatible(*operand_shapes[0], inferred_shape));
      return Tuple({all_reduce});
    }
    return all_reduce;
  });
}
// Adds a kReduceScatter of `operand` along `scatter_dimension`, using
// `computation` as the reduction function.
//
// A tuple operand is flattened into its elements (via GetTupleElement). All
// elements must share the first element's element type, and an empty tuple
// is rejected as Unimplemented. The flattened elements become the
// instruction's operands, so they agree with the operand shapes used for
// shape inference. When a tuple operand yields a non-tuple inferred shape
// (the single-element case), the result is re-wrapped into a one-element
// tuple.
//
// An explicit `layout` overrides the inferred layout and marks the layout as
// constrained. Replica groups, the channel id, and use_global_device_ids are
// copied onto the instruction when provided.
XlaOp XlaBuilder::ReduceScatter(
    XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
    int64_t shard_count, absl::Span<const ReplicaGroup> replica_groups,
    const absl::optional<ChannelHandle>& channel_id,
    const absl::optional<Layout>& layout,
    const absl::optional<bool> use_global_device_ids) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<const Shape*> operand_shapes;
    std::vector<XlaOp> operands;
    if (operand_shape->IsTuple()) {
      if (operand_shape->tuple_shapes_size() == 0) {
        return Unimplemented("0 element tuple ReduceScatter is not supported");
      }
      for (int i = 0; i < operand_shape->tuple_shapes_size(); ++i) {
        if (operand_shape->tuple_shapes(i).element_type() !=
            operand_shape->tuple_shapes(0).element_type()) {
          return Unimplemented(
              "All the shapes of a tuple input of ReduceScatter must have "
              "the same "
              "element type");
        }
        operand_shapes.push_back(&operand_shape->tuple_shapes(i));
        operands.push_back(GetTupleElement(operand, i));
      }
    } else {
      operand_shapes.push_back(operand_shape);
      operands.push_back(operand);
    }

    TF_ASSIGN_OR_RETURN(Shape inferred_shape,
                        ShapeInference::InferReduceScatterShape(
                            operand_shapes, scatter_dimension, shard_count));
    if (layout) {
      *inferred_shape.mutable_layout() = *layout;
      instr.set_constrain_layout(true);
    }
    *instr.mutable_shape() = inferred_shape.ToProto();

    AddCalledComputation(computation, &instr);

    instr.add_dimensions(scatter_dimension);
    for (const ReplicaGroup& group : replica_groups) {
      *instr.add_replica_groups() = group;
    }
    if (channel_id.has_value()) {
      instr.set_channel_id(channel_id->handle());
    }
    if (use_global_device_ids.has_value()) {
      instr.set_use_global_device_ids(use_global_device_ids.value());
    }

    // Use the flattened operands. The shape above was inferred per tuple
    // element, so passing the tuple itself would make the instruction's
    // operands disagree with its shape.
    TF_ASSIGN_OR_RETURN(
        XlaOp reduce_scatter,
        AddInstruction(std::move(instr), HloOpcode::kReduceScatter, operands));
    if (operand_shape->IsTuple() && !inferred_shape.IsTuple()) {
      // A single-element tuple input produced a non-tuple result; wrap it so
      // a tuple input yields a tuple output.
      TF_RET_CHECK(operand_shapes.size() == 1);
      return Tuple({reduce_scatter});
    }
    return reduce_scatter;
  });
}
// Dispatches an all-to-all to the array or tuple implementation.
//
// Without a layout constraint, the array form is used. With one, the tuple
// form is used instead, because the array all_to_all may need to violate the
// layout constraint to be legal.
XlaOp XlaBuilder::AllToAll(XlaOp operand, int64_t split_dimension,
                           int64_t concat_dimension, int64_t split_count,
                           absl::Span<const ReplicaGroup> replica_groups,
                           const absl::optional<Layout>& layout) {
  if (!layout.has_value()) {
    return AllToAllArray(operand, split_dimension, concat_dimension,
                         split_count, replica_groups);
  }
  return AllToAllTuple(operand, split_dimension, concat_dimension, split_count,
                       replica_groups, layout);
}
// Array-form all-to-all.
//
// Emits one kAllToAll over the whole operand; its shape equals the operand
// shape and its split dimension is recorded. If split and concat dimensions
// differ, the result is then rearranged with reshape -> transpose -> reshape
// into the shape returned by ShapeInference::InferAllToAllShape.
XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
                                int64_t concat_dimension, int64_t split_count,
                                absl::Span<const ReplicaGroup> replica_groups) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    // Validates the arguments and yields the final result shape.
    TF_ASSIGN_OR_RETURN(
        const Shape all_to_all_shape,
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count));
    HloInstructionProto instr;
    *instr.mutable_shape() = operand_shape->ToProto();
    // With no replica groups given, use a single group with ids
    // 0 .. split_count - 1.
    if (replica_groups.empty()) {
      auto* group = instr.add_replica_groups();
      for (int64_t i = 0; i < split_count; ++i) {
        group->add_replica_ids(i);
      }
    } else {
      for (const ReplicaGroup& group : replica_groups) {
        *instr.add_replica_groups() = group;
      }
    }
    instr.add_dimensions(split_dimension);
    TF_ASSIGN_OR_RETURN(
        XlaOp all_to_all,
        AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand}));
    // No rearrangement is emitted when splitting and concatenating along the
    // same dimension.
    if (split_dimension == concat_dimension) {
      return all_to_all;
    }
    // Reshape: replace split_dimension (size d) with two dimensions
    // [split_count, d / split_count], giving the block index its own axis.
    DimensionVector sizes;
    for (int64_t i = 0; i < operand_shape->rank(); ++i) {
      if (i != split_dimension) {
        sizes.push_back(operand_shape->dimensions(i));
        continue;
      }
      sizes.push_back(split_count);
      sizes.push_back(operand_shape->dimensions(i) / split_count);
    }
    all_to_all = Reshape(all_to_all, sizes);

    // Transpose: every original dimension i >= split_dimension moved to i + 1
    // after the reshape. The block-index axis (at position split_dimension)
    // is placed immediately before concat_dimension.
    std::vector<int64_t> permutation;
    const auto rank = operand_shape->rank();
    permutation.reserve(rank + 1);
    for (int64_t i = 0; i < rank; ++i) {
      int64_t dim_after_reshape = i >= split_dimension ? i + 1 : i;
      if (i == concat_dimension) {
        permutation.push_back(split_dimension);
      }
      permutation.push_back(dim_after_reshape);
    }
    all_to_all = Transpose(all_to_all, permutation);
    // Reshape: merge the block-index axis into concat_dimension.
    return Reshape(all_to_all_shape, all_to_all);
  });
}
// Tuple-form all-to-all over already-split `operands`: emits one kAllToAll
// whose result is a tuple of the received parts.
//
// If `layout` is set, it is applied to every element of the (non-nested)
// result tuple, and the layout is marked as constrained. Every element's rank
// must match the layout's rank, otherwise InvalidArgument is returned.
XlaOp XlaBuilder::AllToAllTuple(absl::Span<const XlaOp> operands,
                                absl::Span<const ReplicaGroup> replica_groups,
                                const absl::optional<Layout>& layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto operand_shapes, this->GetOperandShapes(operands));
    std::vector<const Shape*> shape_ptrs;
    shape_ptrs.reserve(operand_shapes.size());
    for (const Shape& operand_shape : operand_shapes) {
      shape_ptrs.push_back(&operand_shape);
    }
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferAllToAllTupleShape(shape_ptrs));

    HloInstructionProto proto;
    if (layout) {
      TF_RET_CHECK(result_shape.IsTuple() &&
                   !ShapeUtil::IsNestedTuple(result_shape));
      const int64_t layout_rank = layout->minor_to_major().size();
      const int64_t element_count = ShapeUtil::TupleElementCount(result_shape);
      for (int64_t i = 0; i < element_count; ++i) {
        Shape* element = result_shape.mutable_tuple_shapes(i);
        if (layout_rank != element->rank()) {
          return InvalidArgument(
              "Provided layout must be compatible with the operands' shape. "
              "The layout is %s, but operand %d has shape %s.",
              layout->ToString(), i, element->ToString());
        }
        *element->mutable_layout() = *layout;
      }
      proto.set_constrain_layout(true);
    }
    *proto.mutable_shape() = result_shape.ToProto();

    for (const ReplicaGroup& replica_group : replica_groups) {
      *proto.add_replica_groups() = replica_group;
    }
    return AddInstruction(std::move(proto), HloOpcode::kAllToAll, operands);
  });
}
// All-to-all built from the tuple form.
//
// `operand` is explicitly sliced into `split_count` equal blocks along
// `split_dimension`. The blocks are exchanged with the span-based
// AllToAllTuple (which applies `layout`, if set), and the received tuple
// elements are concatenated along `concat_dimension`.
XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64_t split_dimension,
                                int64_t concat_dimension, int64_t split_count,
                                absl::Span<const ReplicaGroup> replica_groups,
                                const absl::optional<Layout>& layout) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

    // The HloInstruction for Alltoall currently only handles the data
    // communication: it accepts N already split parts and scatters them to N
    // cores, and each core gathers the N received parts into a tuple as the
    // output. So here we explicitly split the operand before the hlo alltoall,
    // and concat the tuple elements.
    //
    // First, run shape inference to make sure the shapes are valid.
    TF_RETURN_IF_ERROR(
        ShapeInference::InferAllToAllShape(*operand_shape, split_dimension,
                                           concat_dimension, split_count)
            .status());

    // Split into N parts. Counters are int64_t to match split_count; an int
    // counter would overflow (UB) for counts beyond INT_MAX.
    std::vector<XlaOp> slices;
    slices.reserve(split_count);
    const int64_t block_size =
        operand_shape->dimensions(split_dimension) / split_count;
    for (int64_t i = 0; i < split_count; ++i) {
      slices.push_back(SliceInDim(operand, /*start_index=*/i * block_size,
                                  /*limit_index=*/(i + 1) * block_size,
                                  /*stride=*/1, /*dimno=*/split_dimension));
    }

    // Handle data communication.
    XlaOp alltoall = this->AllToAllTuple(slices, replica_groups, layout);

    // Concat the N received parts.
    std::vector<XlaOp> received;
    received.reserve(split_count);
    for (int64_t i = 0; i < split_count; ++i) {
      received.push_back(this->GetTupleElement(alltoall, i));
    }
    return this->ConcatInDim(received, concat_dimension);
  });
}
// Adds a kCollectivePermute of `operand`. Each (source, target) pair is
// recorded on the instruction in the given order.
XlaOp XlaBuilder::CollectivePermute(
    XlaOp operand,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(
        Shape result_shape,
        ShapeInference::InferCollectivePermuteShape({operand_shape}));

    HloInstructionProto proto;
    *proto.mutable_shape() = result_shape.ToProto();
    for (const std::pair<int64_t, int64_t>& route : source_target_pairs) {
      const int64_t source = route.first;
      const int64_t target = route.second;
      auto* route_proto = proto.add_source_target_pairs();
      route_proto->set_source(source);
      route_proto->set_target(target);
    }
    return AddInstruction(std::move(proto), HloOpcode::kCollectivePermute,
                          {operand});
  });
}
// Adds an operand-less kReplicaId producing a scalar U32.
XlaOp XlaBuilder::ReplicaId() {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    const Shape u32_scalar = ShapeUtil::MakeShape(U32, {});
    HloInstructionProto proto;
    *proto.mutable_shape() = u32_scalar.ToProto();
    return AddInstruction(std::move(proto), HloOpcode::kReplicaId, {});
  });
}
// Select-and-scatter with a symbolic `padding` mode.
//
// The padding mode is converted to explicit padding values. If SAME padding
// meets a dynamic operand dimension whose window is non-trivial, the op is
// emitted as a "DynamicSelectAndScatterSamePadding" CustomCall. Otherwise it
// goes through SelectAndScatterWithGeneralPadding.
XlaOp XlaBuilder::SelectAndScatter(XlaOp operand, const XlaComputation& select,
                                   absl::Span<const int64_t> window_dimensions,
                                   absl::Span<const int64_t> window_strides,
                                   Padding padding, XlaOp source,
                                   XlaOp init_value,
                                   const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    std::vector<std::pair<int64_t, int64_t>> explicit_padding =
        MakePadding(operand_shape->dimensions(), window_dimensions,
                    window_strides, padding);
    TF_ASSIGN_OR_RETURN(auto window,
                        ShapeInference::InferWindowFromDimensions(
                            window_dimensions, window_strides,
                            explicit_padding,
                            /*lhs_dilation=*/{},
                            /*rhs_dilation=*/{}));

    // SAME padding can create dynamic padding sizes. The padding size needs
    // to be rewritten by the dynamic padder using HloInstructions, so such
    // cases are emitted as a CustomCall.
    bool has_dynamic_same_padding = false;
    if (padding == Padding::kSame) {
      for (int64_t dim = 0; dim < operand_shape->rank(); ++dim) {
        if (operand_shape->is_dynamic_dimension(dim) &&
            !window_util::IsTrivialWindowDimension(window.dimensions(dim))) {
          has_dynamic_same_padding = true;
          break;
        }
      }
    }

    if (!has_dynamic_same_padding) {
      return SelectAndScatterWithGeneralPadding(
          operand, select, window_dimensions, window_strides,
          explicit_padding, source, init_value, scatter);
    }
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto custom_call,
        SelectAndScatterInternal(operand, select, window_dimensions,
                                 window_strides, explicit_padding, source,
                                 init_value, scatter));
    custom_call.set_custom_call_target("DynamicSelectAndScatterSamePadding");
    return AddInstruction(std::move(custom_call), HloOpcode::kCustomCall,
                          {operand, source, init_value});
  });
}
// Builds, but does not add, the HloInstructionProto for a select-and-scatter.
//
// The window is inferred from the dimensions, strides and explicit padding.
// The result shape is inferred from the operand, source and init shapes
// plus both computations' program shapes. `select` and then `scatter` are
// attached as called computations.
StatusOr<HloInstructionProto> XlaBuilder::SelectAndScatterInternal(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  HloInstructionProto proto;
  TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
  TF_ASSIGN_OR_RETURN(const Shape* src_shape, GetShapePtr(source));
  TF_ASSIGN_OR_RETURN(const Shape* init_value_shape, GetShapePtr(init_value));
  TF_ASSIGN_OR_RETURN(const ProgramShape& select_program_shape,
                      select.GetProgramShape());
  TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_program_shape,
                      scatter.GetProgramShape());
  TF_ASSIGN_OR_RETURN(*proto.mutable_window(),
                      ShapeInference::InferWindowFromDimensions(
                          window_dimensions, window_strides, padding,
                          /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
  TF_ASSIGN_OR_RETURN(
      Shape result_shape,
      ShapeInference::InferSelectAndScatterShape(
          *input_shape, select_program_shape, proto.window(), *src_shape,
          *init_value_shape, scatter_program_shape));
  *proto.mutable_shape() = result_shape.ToProto();

  AddCalledComputation(select, &proto);
  AddCalledComputation(scatter, &proto);
  return proto;
}
// Adds a kSelectAndScatter with explicit per-dimension padding. The
// instruction's operands are (operand, source, init_value), in that order.
XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        HloInstructionProto proto,
        SelectAndScatterInternal(operand, select, window_dimensions,
                                 window_strides, padding, source, init_value,
                                 scatter));
    const std::vector<XlaOp> instr_operands = {operand, source, init_value};
    return AddInstruction(std::move(proto), HloOpcode::kSelectAndScatter,
                          instr_operands);
  });
}
// Infers the result shape of a reduce-precision op on `operand`, then
// delegates instruction construction to ReducePrecisionInternal.
XlaOp XlaBuilder::ReducePrecision(XlaOp operand, const int exponent_bits,
                                  const int mantissa_bits) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferReducePrecisionShape(
                            *input_shape, exponent_bits, mantissa_bits));
    return ReducePrecisionInternal(result_shape, operand, exponent_bits,
                                   mantissa_bits);
  });
}
// Adds a kReducePrecision with a precomputed result `shape`, recording the
// requested exponent and mantissa bit counts.
StatusOr<XlaOp> XlaBuilder::ReducePrecisionInternal(const Shape& shape,
                                                    XlaOp operand,
                                                    const int exponent_bits,
                                                    const int mantissa_bits) {
  HloInstructionProto proto;
  *proto.mutable_shape() = shape.ToProto();
  proto.set_exponent_bits(exponent_bits);
  proto.set_mantissa_bits(mantissa_bits);
  return AddInstruction(std::move(proto), HloOpcode::kReducePrecision,
                        {operand});
}
// Token-less Send: synthesizes a token with kAfterAll and forwards to
// SendWithToken. Errors are recorded through ReportErrorOrReturn.
void XlaBuilder::Send(XlaOp operand, const ChannelHandle& handle) {
  ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // The Send HLO takes a data operand and a token; generate the token here.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto after_all;
    *after_all.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(
        XlaOp fresh_token,
        AddInstruction(std::move(after_all), HloOpcode::kAfterAll, {}));
    return SendWithToken(operand, fresh_token, handle);
  });
}
// Adds a device-to-device kSend/kSendDone pair for `operand` on `handle`.
//
// kSend produces {aliased operand, U32 context, token}. The returned op is
// the kSendDone, which produces a token. Non device-to-device channels are
// rejected with InvalidArgument.
XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Send must use a device-to-device channel");
    }

    TF_ASSIGN_OR_RETURN(const Shape* data_shape, GetShapePtr(operand));
    const Shape token_shape = ShapeUtil::MakeTokenShape();
    const Shape send_shape = ShapeUtil::MakeTupleShape(
        {*data_shape, ShapeUtil::MakeShape(U32, {}), token_shape});

    HloInstructionProto send_proto;
    *send_proto.mutable_shape() = send_shape.ToProto();
    send_proto.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(XlaOp send_op,
                        AddInstruction(std::move(send_proto), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto done_proto;
    *done_proto.mutable_shape() = token_shape.ToProto();
    done_proto.set_channel_id(handle.handle());
    return AddInstruction(std::move(done_proto), HloOpcode::kSendDone,
                          {send_op});
  });
}
// Token-less Recv: synthesizes a token with kAfterAll, calls RecvWithToken,
// and returns element 0 (the received data) of the (data, token) tuple that
// RecvWithToken produces.
XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // The Recv HLO takes a single token operand; generate it here for the
    // Recv and RecvDone instructions.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto after_all;
    *after_all.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
    TF_ASSIGN_OR_RETURN(
        XlaOp fresh_token,
        AddInstruction(std::move(after_all), HloOpcode::kAfterAll, {}));
    XlaOp recv_done = RecvWithToken(fresh_token, shape, handle);

    // Strip the token and hand back only the data element.
    // TODO(b/80000000): Remove this when clients have been updated to handle
    // tokens.
    HloInstructionProto data_element;
    *data_element.mutable_shape() = shape.ToProto();
    data_element.set_tuple_index(0);
    return AddInstruction(std::move(data_element),
                          HloOpcode::kGetTupleElement, {recv_done});
  });
}
// Adds a device-to-device kRecv/kRecvDone pair on `handle`.
//
// kRecv produces {receive buffer, U32 context, token}. The returned op is
// the kRecvDone, whose shape is the tuple {shape, token}. Non
// device-to-device channels are rejected with InvalidArgument.
XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape,
                                const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (handle.type() != ChannelHandle::DEVICE_TO_DEVICE) {
      return InvalidArgument("Recv must use a device-to-device channel");
    }

    const Shape token_shape = ShapeUtil::MakeTokenShape();

    HloInstructionProto recv_proto;
    *recv_proto.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), token_shape})
            .ToProto();
    recv_proto.set_channel_id(handle.handle());
    TF_ASSIGN_OR_RETURN(
        XlaOp recv_op,
        AddInstruction(std::move(recv_proto), HloOpcode::kRecv, {token}));

    HloInstructionProto done_proto;
    *done_proto.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, token_shape}).ToProto();
    done_proto.set_channel_id(handle.handle());
    return AddInstruction(std::move(done_proto), HloOpcode::kRecvDone,
                          {recv_op});
  });
}
// Adds a host-transfer kSend/kSendDone pair sending `operand` to the host.
//
// Requires `shape_with_layout` to have a layout and be compatible with the
// operand, the operand to be an array, and `handle` to be a device-to-host
// channel. Each violation is reported as InvalidArgument, in that order.
// Returns the kSendDone (a token).
XlaOp XlaBuilder::SendToHost(XlaOp operand, XlaOp token,
                             const Shape& shape_with_layout,
                             const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape_with_layout)) {
      return InvalidArgument("Shape passed to SendToHost must have a layout");
    }
    TF_ASSIGN_OR_RETURN(const Shape* data_shape, GetShapePtr(operand));
    if (!ShapeUtil::Compatible(*data_shape, shape_with_layout)) {
      return InvalidArgument(
          "SendToHost shape %s must be compatible with operand shape %s",
          ShapeUtil::HumanStringWithLayout(shape_with_layout),
          ShapeUtil::HumanStringWithLayout(*data_shape));
    }
    // TODO(b/111544877): Support tuple shapes.
    if (!data_shape->IsArray()) {
      return InvalidArgument("SendToHost only supports array shapes, shape: %s",
                             ShapeUtil::HumanString(*data_shape));
    }
    if (handle.type() != ChannelHandle::DEVICE_TO_HOST) {
      return InvalidArgument("SendToHost must use a device-to-host channel");
    }

    const Shape token_shape = ShapeUtil::MakeTokenShape();

    // kSend yields {aliased operand, U32 context, token}.
    HloInstructionProto send_proto;
    *send_proto.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape_with_layout, ShapeUtil::MakeShape(U32, {}), token_shape})
            .ToProto();
    send_proto.set_channel_id(handle.handle());
    send_proto.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(XlaOp send_op,
                        AddInstruction(std::move(send_proto), HloOpcode::kSend,
                                       {operand, token}));

    HloInstructionProto done_proto;
    *done_proto.mutable_shape() = token_shape.ToProto();
    done_proto.set_channel_id(handle.handle());
    done_proto.set_is_host_transfer(true);
    return AddInstruction(std::move(done_proto), HloOpcode::kSendDone,
                          {send_op});
  });
}
// Adds a host-transfer kRecv/kRecvDone pair receiving `shape` from the host.
//
// Requires `shape` to have a layout and be an array, and `handle` to be a
// host-to-device channel; violations are reported as InvalidArgument.
// Returns the kRecvDone, whose shape is the tuple {shape, token}.
XlaOp XlaBuilder::RecvFromHost(XlaOp token, const Shape& shape,
                               const ChannelHandle& handle) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (!LayoutUtil::HasLayout(shape)) {
      return InvalidArgument("Shape passed to RecvFromHost must have a layout");
    }
    // TODO(b/111544877): Support tuple shapes.
    if (!shape.IsArray()) {
      return InvalidArgument(
          "RecvFromHost only supports array shapes, shape: %s",
          ShapeUtil::HumanString(shape));
    }
    if (handle.type() != ChannelHandle::HOST_TO_DEVICE) {
      return InvalidArgument("RecvFromHost must use a host-to-device channel");
    }

    const Shape token_shape = ShapeUtil::MakeTokenShape();

    // kRecv yields {receive buffer, U32 context, token}.
    HloInstructionProto recv_proto;
    *recv_proto.mutable_shape() =
        ShapeUtil::MakeTupleShape(
            {shape, ShapeUtil::MakeShape(U32, {}), token_shape})
            .ToProto();
    recv_proto.set_channel_id(handle.handle());
    recv_proto.set_is_host_transfer(true);
    TF_ASSIGN_OR_RETURN(
        XlaOp recv_op,
        AddInstruction(std::move(recv_proto), HloOpcode::kRecv, {token}));

    HloInstructionProto done_proto;
    *done_proto.mutable_shape() =
        ShapeUtil::MakeTupleShape({shape, token_shape}).ToProto();
    done_proto.set_channel_id(handle.handle());
    done_proto.set_is_host_transfer(true);
    return AddInstruction(std::move(done_proto), HloOpcode::kRecvDone,
                          {recv_op});
  });
}
// Returns the size of `dimension` of `operand`.
//
// Shape inference always runs first. A dynamic dimension yields a
// kGetDimensionSize instruction; a static dimension is folded to an int32
// constant holding its size.
XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferGetDimensionSizeShape(
                            *input_shape, dimension));
    if (input_shape->is_dynamic_dimension(dimension)) {
      HloInstructionProto proto;
      *proto.mutable_shape() = result_shape.ToProto();
      proto.add_dimensions(dimension);
      return AddInstruction(std::move(proto), HloOpcode::kGetDimensionSize,
                            {operand});
    }
    // The size of a static dimension is known now; emit it as a constant.
    return ConstantR0<int32_t>(this, input_shape->dimensions(dimension));
  });
}
// Marks `dimension` of `operand` as static by setting its size to the
// dimension's static bound via SetDimensionSizeInternal.
//
// Returns InvalidArgument (through ReportErrorOrReturn) if `dimension` is not
// a valid index into the operand's dimensions.
XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    // Validate before indexing: both dimensions(dimension) and
    // set_dynamic_dimension(dimension, ...) below index by `dimension`
    // without bounds checking at this call site.
    if (dimension < 0 || dimension >= operand_shape->dimensions_size()) {
      return InvalidArgument(
          "RemoveDynamicDimension: dimension %d is out of range for shape %s",
          dimension, ShapeUtil::HumanString(*operand_shape));
    }

    Shape shape = *operand_shape;
    shape.set_dynamic_dimension(dimension, false);
    // Setting an op's dynamic dimension to its static size removes the dynamic
    // dimension.
    XlaOp static_size =
        ConstantR0<int32_t>(this, operand_shape->dimensions(dimension));
    return SetDimensionSizeInternal(shape, operand, static_size, dimension);
  });
}
// Sets the runtime size of `dimension` of `operand` to `val`. The result
// shape is inferred first; construction is delegated to
// SetDimensionSizeInternal.
XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val,
                                   int64_t dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* size_shape, GetShapePtr(val));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferSetDimensionSizeShape(
                            *input_shape, *size_shape, dimension));
    return SetDimensionSizeInternal(result_shape, operand, val, dimension);
  });
}
// Adds a kSetDimensionSize producing `shape` from (operand, val).
//
// Shortcut: if `val` is a constant and `shape` marks `dimension` as dynamic,
// the constant is read as an int32 scalar. If it equals the dimension's
// bound in `shape`, `operand` is returned as-is and no instruction is added.
StatusOr<XlaOp> XlaBuilder::SetDimensionSizeInternal(const Shape& shape,
                                                     XlaOp operand, XlaOp val,
                                                     int64_t dimension) {
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* size_proto,
                      LookUpInstruction(val));
  const bool size_is_constant =
      StringToHloOpcode(size_proto->opcode()).ValueOrDie() ==
      HloOpcode::kConstant;
  if (size_is_constant && shape.is_dynamic_dimension(dimension)) {
    TF_ASSIGN_OR_RETURN(auto size_literal,
                        Literal::CreateFromProto(size_proto->literal(), true));
    const bool matches_bound =
        size_literal.Get<int32_t>({}) == shape.dimensions(dimension);
    if (matches_bound) {
      return operand;
    }
  }

  HloInstructionProto proto;
  *proto.mutable_shape() = shape.ToProto();
  proto.add_dimensions(dimension);
  return AddInstruction(std::move(proto), HloOpcode::kSetDimensionSize,
                        {operand, val});
}
// Reports whether `operand` can be evaluated as a compile-time constant.
// Any earlier builder error, or an invalid handle, is returned as an error.
// The check itself is delegated to IsConstantVisitor.
StatusOr<bool> XlaBuilder::IsConstant(XlaOp operand) const {
  TF_RETURN_IF_ERROR(first_error_);
  // Reject handles that do not refer to an instruction in this builder.
  TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());

  absl::flat_hash_set<int64_t> visited_handles;
  bool constant_so_far = true;
  IsConstantVisitor(operand.handle(), /*depth=*/0, &visited_handles,
                    &constant_so_far);
  return constant_so_far;
}
// Builds a standalone XlaComputation that evaluates `root_op` from only the
// instructions it transitively depends on, so its value can be computed at
// compile time.
//
// Returns InvalidArgument when IsConstant(root_op) reports that `root_op` is
// not constant.
//
// Within the extracted subgraph:
//  - GetDimensionSize is replaced by an S32 constant holding the operand's
//    dimension size. The constant is -1 when that dimension is dynamic and
//    `dynamic_dimension_is_minus_one` is true.
//  - Instructions for which InstrIsSetBound() holds are replaced by a
//    constant built from their literal. When that literal is a tuple, its
//    first element is used.
//  - GetTupleElement(Tuple(...), i) is bypassed: references to it are
//    rewritten to the tuple's i-th operand.
//
// Embedded computations are copied into the resulting module only when their
// ids appear in called_computation_ids of instructions traversed with the
// default behavior. Only direct references are collected here.
StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
    XlaOp root_op, bool dynamic_dimension_is_minus_one) {
  TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op));
  if (!is_constant) {
    auto op_status = LookUpInstruction(root_op);
    std::string op_string =
        op_status.ok() ? op_status.ValueOrDie()->name() : "<unknown operation>";
    return InvalidArgument(
        "Operand to BuildConstantSubGraph depends on a parameter.\n\n"
        " op requested for constant subgraph: %s\n\n"
        "This is an internal error that typically happens when the XLA user "
        "(e.g. TensorFlow) is attempting to determine a value that must be a "
        "compile-time constant (e.g. an array dimension) but it is not capable "
        "of being evaluated at XLA compile time.\n\n"
        "Please file a usability bug with the framework being used (e.g. "
        "TensorFlow).",
        op_string);
  }
  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
                      LookUpInstruction(root_op));
  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Build constant subgraph for:\n" << OpToString(root_op);
  }
  HloComputationProto entry;
  SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
                    GetNextId());
  ProgramShapeProto* program_shape = entry.mutable_program_shape();
  // The result shape is taken from the original root, even if the root is
  // later substituted (see the GTE handling below).
  *program_shape->mutable_result() = root->shape();
  // We use std::set to keep the instruction ids in ascending order (which is
  // also a valid dependency order). The related ops will be added to the
  // subgraph in the same order.
  std::set<int64_t> related_ops;
  // Maps a bypassed GetTupleElement handle to the handle that replaces it.
  absl::flat_hash_map<int64_t, int64_t> substitutions;
  absl::flat_hash_set<int64_t> related_calls; // Related computations.
  // Breadth-first walk over operand edges, starting at the root.
  std::queue<int64_t> worklist;
  worklist.push(root->id());
  related_ops.insert(root->id());
  while (!worklist.empty()) {
    int64_t handle = worklist.front();
    worklist.pop();
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
                        LookUpInstructionByHandle(handle));
    // Default: keep the instruction, enqueue its not-yet-seen operands, and
    // remember which computations it calls.
    auto default_behavior = [&related_ops, &worklist, &related_calls,
                             instr_proto]() {
      for (int64_t id : instr_proto->operand_ids()) {
        if (related_ops.insert(id).second) {
          worklist.push(id);
        }
      }
      for (int64_t called_id : instr_proto->called_computation_ids()) {
        related_calls.insert(called_id);
      }
    };
    if (instr_proto->opcode() ==
            HloOpcodeString(HloOpcode::kGetDimensionSize) ||
        InstrIsSetBound(instr_proto)) {
      // Replace the instruction with a constant. The constant keeps the same
      // id, so consumers' operand ids still resolve. Its operands are
      // deliberately not enqueued.
      int32_t constant_value = -1;
      HloInstructionProto const_instr;
      if (instr_proto->opcode() ==
          HloOpcodeString(HloOpcode::kGetDimensionSize)) {
        // At this point, BuildConstantSubGraph should never encounter a
        // GetDimensionSize with a dynamic dimension. IsConstant check would
        // have failed at the beginning of this function.
        //
        // NOTE(review): the branch below does handle a dynamic dimension
        // (yielding -1) when dynamic_dimension_is_minus_one is set. The claim
        // above may be stale; confirm against IsConstantVisitor.
        //
        // Replace GetDimensionSize with a Constant representing the static
        // bound of the shape.
        int64_t dimension = instr_proto->dimensions(0);
        int64_t operand_handle = instr_proto->operand_ids(0);
        TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto,
                            LookUpInstructionByHandle(operand_handle));
        if (!(operand_proto->shape().is_dynamic_dimension(dimension) &&
              dynamic_dimension_is_minus_one)) {
          constant_value = static_cast<int32_t>(
              operand_proto->shape().dimensions(dimension));
        }
        Literal literal = LiteralUtil::CreateR0(constant_value);
        *const_instr.mutable_literal() = literal.ToProto();
        *const_instr.mutable_shape() = literal.shape().ToProto();
      } else {
        if (instr_proto->literal().shape().element_type() == TUPLE) {
          *const_instr.mutable_literal() =
              // First literal of SetBound contains bounds, second literal
              // contains dynamism indicators.
              instr_proto->literal().tuple_literals(0);
        } else {
          *const_instr.mutable_literal() = instr_proto->literal();
        }
        *const_instr.mutable_shape() = instr_proto->shape();
      }
      *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant);
      const_instr.set_id(handle);
      *const_instr.mutable_name() =
          GetFullName(const_instr.opcode(), kNameSeparator, const_instr.id());
      // Emitted immediately; the emission loop below skips these opcodes.
      *entry.add_instructions() =
          const_instr; // Add to the result constant graph.
    } else if (instr_proto->opcode() ==
               HloOpcodeString(HloOpcode::kGetTupleElement)) {
      // Look through GTE(Tuple(..), i).
      TF_ASSIGN_OR_RETURN(
          const HloInstructionProto* maybe_tuple_instr,
          LookUpInstructionByHandle(instr_proto->operand_ids(0)));
      if (maybe_tuple_instr->opcode() == HloOpcodeString(HloOpcode::kTuple)) {
        // Only the selected tuple operand is followed; the tuple itself is
        // not added to the subgraph.
        int64_t id = maybe_tuple_instr->operand_ids(instr_proto->tuple_index());
        // Enqueue any dependencies of `id`.
        if (related_ops.insert(id).second) {
          worklist.push(id);
        }
        substitutions[handle] = id;
      } else {
        default_behavior();
      }
    } else {
      default_behavior();
    }
  }
  // Resolve any substitutions for the root id.
  int64_t root_id = root->id();
  auto it = substitutions.find(root_id);
  while (it != substitutions.end()) {
    root_id = it->second;
    it = substitutions.find(root_id);
  }
  entry.set_root_id(root_id);
  // Add related ops to the computation.
  for (int64_t id : related_ops) {
    if (substitutions.find(id) != substitutions.end()) {
      // Skip adding this instruction; we will replace references to it with the
      // substitution instruction's id.
      continue;
    }
    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
                        LookUpInstructionByHandle(id));
    // Already emitted as constants during the traversal above.
    if (instr_src->opcode() == HloOpcodeString(HloOpcode::kGetDimensionSize) ||
        InstrIsSetBound(instr_src)) {
      continue;
    }
    HloInstructionProto* instr = entry.add_instructions();
    *instr = *instr_src;
    // Replace operands in case we have substitutions mapped.
    instr->clear_operand_ids();
    for (int64_t operand_id : instr_src->operand_ids()) {
      // Follow substitution chains (nested GTE-of-tuple) to the final id.
      auto it = substitutions.find(operand_id);
      while (it != substitutions.end()) {
        operand_id = it->second;
        it = substitutions.find(operand_id);
      }
      instr->add_operand_ids(operand_id);
    }
    // Ensures that the instruction names are unique among the graph.
    const std::string& new_name =
        StrCat(instr->name(), ".", entry.id(), ".", instr->id());
    instr->set_name(new_name);
  }
  // Wrap the entry computation in a module named and numbered after it.
  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_host_program_shape() = *program_shape;
  for (auto& e : embedded_) {
    if (related_calls.find(e.second.id()) != related_calls.end()) {
      *module->add_computations() = e.second;
    }
  }
  // `program_shape` points into `entry`, so it is read above before `entry`
  // is moved here.
  *module->add_computations() = std::move(entry);
  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Constant computation:\n" << module->DebugString();
  }
  return std::move(computation);
}
std::unique_ptr<XlaBuilder> XlaBuilder::CreateSubBuilder(
const std::string& computation_name) {
auto sub_builder = absl::make_unique<XlaBuilder>(computation_name);
sub_builder->parent_builder_ = this;
sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_;
return sub_builder;
}
// Returns convolution dimension numbers with the following layout:
//  - batch and feature dimensions on the input and output use
//    kConvBatchDimension / kConvFeatureDimension;
//  - the kernel's output/input feature dimensions use
//    kConvKernelOutputDimension / kConvKernelInputDimension;
//  - `num_spatial_dims` spatial dimensions are numbered 2, 3, ... on the
//    input, kernel and output alike.
/* static */ ConvolutionDimensionNumbers
XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(kConvBatchDimension);
  dnums.set_input_feature_dimension(kConvFeatureDimension);
  dnums.set_output_batch_dimension(kConvBatchDimension);
  dnums.set_output_feature_dimension(kConvFeatureDimension);
  dnums.set_kernel_output_feature_dimension(kConvKernelOutputDimension);
  dnums.set_kernel_input_feature_dimension(kConvKernelInputDimension);

  constexpr int kFirstSpatialDimension = 2;
  for (int dim = kFirstSpatialDimension;
       dim < kFirstSpatialDimension + num_spatial_dims; ++dim) {
    dnums.add_input_spatial_dimensions(dim);
    dnums.add_kernel_spatial_dimensions(dim);
    dnums.add_output_spatial_dimensions(dim);
  }
  return dnums;
}
// Validates `dnum` as 2-D convolution dimension numbers.
//
// Requirements:
//  - input, kernel and output each have at least two spatial dimensions;
//  - on each of them, the batch/feature dimension numbers and the first two
//    spatial dimension numbers are pairwise distinct.
//
// Returns FailedPrecondition describing the first violated requirement.
/* static */ Status XlaBuilder::Validate(
    const ConvolutionDimensionNumbers& dnum) {
  if (dnum.input_spatial_dimensions_size() < 2) {
    return FailedPrecondition("input spatial dimension < 2: %d",
                              dnum.input_spatial_dimensions_size());
  }
  if (dnum.kernel_spatial_dimensions_size() < 2) {
    return FailedPrecondition("kernel spatial dimension < 2: %d",
                              dnum.kernel_spatial_dimensions_size());
  }
  if (dnum.output_spatial_dimensions_size() < 2) {
    return FailedPrecondition("output spatial dimension < 2: %d",
                              dnum.output_spatial_dimensions_size());
  }

  // True iff the four dimension numbers are pairwise distinct.
  const auto all_distinct = [](int64_t a, int64_t b, int64_t c, int64_t d) {
    return std::set<int64_t>({a, b, c, d}).size() == 4;
  };

  if (!all_distinct(dnum.input_batch_dimension(),
                    dnum.input_feature_dimension(),
                    dnum.input_spatial_dimensions(0),
                    dnum.input_spatial_dimensions(1))) {
    return FailedPrecondition(
        "dimension numbers for the input are not unique: (%d, %d, %d, "
        "%d)",
        dnum.input_batch_dimension(), dnum.input_feature_dimension(),
        dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1));
  }
  if (!all_distinct(dnum.kernel_output_feature_dimension(),
                    dnum.kernel_input_feature_dimension(),
                    dnum.kernel_spatial_dimensions(0),
                    dnum.kernel_spatial_dimensions(1))) {
    return FailedPrecondition(
        "dimension numbers for the weight are not unique: (%d, %d, %d, "
        "%d)",
        dnum.kernel_output_feature_dimension(),
        dnum.kernel_input_feature_dimension(),
        dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1));
  }
  if (!all_distinct(dnum.output_batch_dimension(),
                    dnum.output_feature_dimension(),
                    dnum.output_spatial_dimensions(0),
                    dnum.output_spatial_dimensions(1))) {
    return FailedPrecondition(
        "dimension numbers for the output are not unique: (%d, %d, %d, "
        "%d)",
        dnum.output_batch_dimension(), dnum.output_feature_dimension(),
        dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1));
  }
  return ::tensorflow::OkStatus();
}
// Appends `instr` to this builder as an instruction with opcode `opcode` and
// operands `operands`, and returns an XlaOp referring to it.
//
// The instruction receives a fresh id, an opcode string, and a default name
// (the opcode string) when none was set. It also receives the current
// metadata, sharding and frontend attributes.
//
// Returns InvalidArgument when an operand is invalid or belongs to a
// different builder.
StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
                                           HloOpcode opcode,
                                           absl::Span<const XlaOp> operands) {
  TF_RETURN_IF_ERROR(first_error_);

  const int64_t id = GetNextId();
  instr.set_id(id);
  instr.set_opcode(HloOpcodeString(opcode));
  if (instr.name().empty()) {
    instr.set_name(instr.opcode());
  }

  // Every operand must be a live op that was created by this builder.
  for (const XlaOp& op : operands) {
    if (op.builder_ == nullptr) {
      return InvalidArgument("invalid XlaOp with handle %d", op.handle());
    }
    if (op.builder_ != this) {
      return InvalidArgument("Do not add XlaOp from builder %s to builder %s",
                             op.builder_->name(), this->name());
    }
    instr.add_operand_ids(op.handle());
  }

  // One-shot metadata wins over the builder-wide metadata and is consumed.
  if (!one_shot_metadata_.has_value()) {
    *instr.mutable_metadata() = metadata_;
  } else {
    *instr.mutable_metadata() = *one_shot_metadata_;
    one_shot_metadata_.reset();
  }
  if (sharding_) {
    *instr.mutable_sharding() = *sharding_;
  }
  *instr.mutable_frontend_attributes() = frontend_attributes_;

  // Record the index before appending so it points at the new element.
  handle_to_index_[id] = instructions_.size();
  instructions_.push_back(std::move(instr));
  instruction_shapes_.push_back(
      absl::make_unique<Shape>(instructions_.back().shape()));
  return XlaOp(id, this);
}
// Adds an instruction of `opcode` whose only pre-populated field is its
// result `shape`. AddInstruction fills in the rest.
StatusOr<XlaOp> XlaBuilder::AddOpWithShape(HloOpcode opcode, const Shape& shape,
                                           absl::Span<const XlaOp> operands) {
  HloInstructionProto proto;
  *proto.mutable_shape() = shape.ToProto();
  return AddInstruction(std::move(proto), opcode, operands);
}
// Imports every computation of `computation` into this builder's embedded
// computations. The imported entry computation's new id is appended to
// `instr`'s called_computation_ids.
//
// All computation and instruction ids are reassigned with GetNextId(). Ids
// referenced inside the imported instructions (operands, control
// predecessors, called computations) are then rewritten to the new ids.
void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
                                      HloInstructionProto* instr) {
  absl::flat_hash_map<int64_t, int64_t> remapped_ids;
  std::vector<HloComputationProto> imported_computations;
  imported_computations.reserve(computation.proto().computations_size());
  // First pass: copy each computation and give it, and each of its
  // instructions, a fresh id (keeping the base name). Record every old->new id
  // mapping in remapped_ids.
  for (const HloComputationProto& e : computation.proto().computations()) {
    HloComputationProto new_computation(e);
    int64_t computation_id = GetNextId();
    remapped_ids[new_computation.id()] = computation_id;
    SetProtoIdAndName(&new_computation,
                      GetBaseName(new_computation.name(), kNameSeparator),
                      kNameSeparator, computation_id);
    for (auto& instruction : *new_computation.mutable_instructions()) {
      int64_t instruction_id = GetNextId();
      remapped_ids[instruction.id()] = instruction_id;
      SetProtoIdAndName(&instruction,
                        GetBaseName(instruction.name(), kNameSeparator),
                        kNameSeparator, instruction_id);
    }
    // Each computation's instructions have all been remapped by now, so its
    // root id can be translated immediately.
    new_computation.set_root_id(remapped_ids.at(new_computation.root_id()));
    imported_computations.push_back(std::move(new_computation));
  }
  // Second pass: every id mapping is now known. Register the imported entry
  // computation on `instr`, then rewrite the ids referenced inside the
  // imported computations.
  instr->add_called_computation_ids(
      remapped_ids.at(computation.proto().entry_computation_id()));
  for (auto& imported_computation : imported_computations) {
    for (auto& instruction : *imported_computation.mutable_instructions()) {
      for (auto& operand_id : *instruction.mutable_operand_ids()) {
        operand_id = remapped_ids.at(operand_id);
      }
      for (auto& control_predecessor_id :
           *instruction.mutable_control_predecessor_ids()) {
        control_predecessor_id = remapped_ids.at(control_predecessor_id);
      }
      for (auto& called_computation_id :
           *instruction.mutable_called_computation_ids()) {
        called_computation_id = remapped_ids.at(called_computation_id);
      }
    }
    // Record (computation id, instruction index) for every imported
    // instruction, keyed by its new id. Presumably this backs handle lookups
    // of imported instructions; confirm against
    // LookUpInstructionByHandleInternal.
    int64_t computation_id = imported_computation.id();
    for (int64_t i = 0; i < imported_computation.instructions_size(); ++i) {
      ImportedInstruction imported_instruction;
      imported_instruction.computation_id = computation_id;
      imported_instruction.instruction_index = i;
      handle_to_imported_index_.insert(
          {imported_computation.instructions(i).id(), imported_instruction});
    }
    embedded_.insert({computation_id, std::move(imported_computation)});
  }
}
// Returns the proto for `op`. If the builder has already recorded an error,
// returns that error instead.
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
    const XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error_);
  using ConstProtoPtr = const HloInstructionProto*;
  return LookUpInstructionInternal<ConstProtoPtr>(op);
}
// Returns the proto for `handle`. Does not consult first_error_.
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
    int64_t handle) const {
  using ConstProtoPtr = const HloInstructionProto*;
  return LookUpInstructionByHandleInternal<ConstProtoPtr>(handle);
}
// Mutable counterpart of LookUpInstruction: also fails with the builder's
// first recorded error, if any.
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
    const XlaOp op) {
  TF_RETURN_IF_ERROR(first_error_);
  using MutableProtoPtr = HloInstructionProto*;
  return LookUpInstructionInternal<MutableProtoPtr>(op);
}
// Mutable counterpart of LookUpInstructionByHandle. Does not consult
// first_error_.
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
    int64_t handle) {
  using MutableProtoPtr = HloInstructionProto*;
  return LookUpInstructionByHandleInternal<MutableProtoPtr>(handle);
}
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation. Forwards to the overload below with an empty
// replicated_at_leaf_buffers vector.
XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                const Shape& shape, const std::string& name) {
  const std::vector<bool> no_replication_info;
  return Parameter(builder, parameter_number, shape, name,
                   no_replication_info);
}
// Same as above, forwarding per-leaf-buffer replication flags to the builder.
XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number,
                const Shape& shape, const std::string& name,
                const std::vector<bool>& replicated_at_leaf_buffers) {
  XlaOp param = builder->Parameter(parameter_number, shape, name,
                                   replicated_at_leaf_buffers);
  return param;
}
// Enqueues a constant holding the value of `literal` onto the computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
  XlaOp constant = builder->ConstantLiteral(literal);
  return constant;
}
// Free-function forwarders to the operand's XlaBuilder.
XlaOp Broadcast(const XlaOp operand,
                absl::Span<const int64_t> broadcast_sizes) {
  XlaBuilder* builder = operand.builder();
  return builder->Broadcast(operand, broadcast_sizes);
}
XlaOp BroadcastInDim(const XlaOp operand,
                     const absl::Span<const int64_t> out_dim_size,
                     const absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* builder = operand.builder();
  return builder->BroadcastInDim(operand, out_dim_size, broadcast_dimensions);
}
// Copy is emitted as a unary kCopy instruction.
XlaOp Copy(const XlaOp operand) {
  XlaBuilder* builder = operand.builder();
  return builder->UnaryOp(HloOpcode::kCopy, operand);
}
XlaOp Pad(const XlaOp operand, const XlaOp padding_value,
          const PaddingConfig& padding_config) {
  XlaBuilder* builder = operand.builder();
  return builder->Pad(operand, padding_value, padding_config);
}
XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno,
               int64_t pad_lo, int64_t pad_hi) {
  XlaBuilder* builder = operand.builder();
  return builder->PadInDim(operand, padding_value, dimno, pad_lo, pad_hi);
}
// Reshape-family forwarders to the operand's XlaBuilder.
XlaOp Reshape(const XlaOp operand, absl::Span<const int64_t> dimensions,
              absl::Span<const int64_t> new_sizes) {
  XlaBuilder* builder = operand.builder();
  return builder->Reshape(operand, dimensions, new_sizes);
}
XlaOp Reshape(const XlaOp operand, absl::Span<const int64_t> new_sizes) {
  XlaBuilder* builder = operand.builder();
  return builder->Reshape(operand, new_sizes);
}
XlaOp Reshape(const Shape& shape, XlaOp operand) {
  XlaBuilder* builder = operand.builder();
  return builder->Reshape(shape, operand);
}
XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                     absl::Span<const int64_t> new_size_bounds,
                     const std::vector<bool>& dims_are_dynamic) {
  XlaBuilder* builder = operand.builder();
  return builder->DynamicReshape(operand, dim_sizes, new_size_bounds,
                                 dims_are_dynamic);
}
// Uses the builder's three-argument Reshape overload taking the inferred
// dimension.
XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                   absl::Span<const int64_t> new_sizes,
                                   int64_t inferred_dimension) {
  XlaBuilder* builder = operand.builder();
  return builder->Reshape(operand, new_sizes, inferred_dimension);
}
XlaOp Collapse(const XlaOp operand, absl::Span<const int64_t> dimensions) {
  XlaBuilder* builder = operand.builder();
  return builder->Collapse(operand, dimensions);
}
// Slice-family forwarders to the operand's XlaBuilder.
XlaOp Slice(const XlaOp operand, absl::Span<const int64_t> start_indices,
            absl::Span<const int64_t> limit_indices,
            absl::Span<const int64_t> strides) {
  XlaBuilder* builder = operand.builder();
  return builder->Slice(operand, start_indices, limit_indices, strides);
}
XlaOp SliceInDim(const XlaOp operand, int64_t start_index, int64_t limit_index,
                 int64_t stride, int64_t dimno) {
  XlaBuilder* builder = operand.builder();
  return builder->SliceInDim(operand, start_index, limit_index, stride, dimno);
}
XlaOp DynamicSlice(const XlaOp operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64_t> slice_sizes) {
  XlaBuilder* builder = operand.builder();
  return builder->DynamicSlice(operand, start_indices, slice_sizes);
}
XlaOp DynamicUpdateSlice(const XlaOp operand, const XlaOp update,
                         absl::Span<const XlaOp> start_indices) {
  XlaBuilder* builder = operand.builder();
  return builder->DynamicUpdateSlice(operand, update, start_indices);
}
// Forwarders for concatenation, selection and tuple construction/access.
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64_t dimension) {
  XlaOp concatenated = builder->ConcatInDim(operands, dimension);
  return concatenated;
}
// The predicate's builder owns the resulting select.
XlaOp Select(const XlaOp pred, const XlaOp on_true, const XlaOp on_false) {
  XlaBuilder* builder = pred.builder();
  return builder->Select(pred, on_true, on_false);
}
XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements) {
  XlaOp tuple = builder->Tuple(elements);
  return tuple;
}
XlaOp GetTupleElement(const XlaOp tuple_data, int64_t index) {
  XlaBuilder* builder = tuple_data.builder();
  return builder->GetTupleElement(tuple_data, index);
}
// Elementwise `lhs == rhs`, emitted via Compare without an explicit
// comparison type.
XlaOp Eq(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kEq;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
// Emits a comparison whose comparison type depends on the lhs element type:
//  - floating-point types use Comparison::Type::kFloatTotalOrder;
//  - all other types use Comparison::DefaultComparisonType.
// Errors are reported through lhs's builder.
static XlaOp CompareTotalOrder(const XlaOp lhs, const XlaOp rhs,
                               absl::Span<const int64_t> broadcast_dimensions,
                               ComparisonDirection comparison_direction) {
  XlaBuilder* b = lhs.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Only the element type is needed, so borrow the shape (GetShapePtr)
    // instead of copying it (GetShape).
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, b->GetShapePtr(lhs));
    const PrimitiveType operand_element_type = operand_shape->element_type();
    const Comparison::Type compare_type =
        primitive_util::IsFloatingPointType(operand_element_type)
            ? Comparison::Type::kFloatTotalOrder
            : Comparison::DefaultComparisonType(operand_element_type);
    return Compare(lhs, rhs, broadcast_dimensions, comparison_direction,
                   compare_type);
  });
}
// Direction-specific comparison helpers.
//
// The plain forms (Ne, Ge, Gt, Le, Lt) call Compare without an explicit
// comparison type. The *TotalOrder forms go through CompareTotalOrder, which
// selects a total order for floating-point operands.
XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kEq;
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions, direction);
}
XlaOp Ne(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kNe;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kNe;
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions, direction);
}
XlaOp Ge(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kGe;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kGe;
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions, direction);
}
XlaOp Gt(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kGt;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kGt;
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions, direction);
}
XlaOp Le(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kLe;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kLe;
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions, direction);
}
XlaOp Lt(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kLt;
  return Compare(lhs, rhs, broadcast_dimensions, direction);
}
XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs,
                   absl::Span<const int64_t> broadcast_dimensions) {
  const ComparisonDirection direction = ComparisonDirection::kLt;
  return CompareTotalOrder(lhs, rhs, broadcast_dimensions, direction);
}
// Emits a kCompare binary op on lhs's builder, without an explicit comparison
// type.
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64_t> broadcast_dimensions,
              ComparisonDirection direction) {
  XlaBuilder* builder = lhs.builder();
  return builder->BinaryOp(HloOpcode::kCompare, lhs, rhs, broadcast_dimensions,
                           direction);
}
// Same as above, with an explicit comparison type.
XlaOp Compare(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64_t> broadcast_dimensions,
              ComparisonDirection direction, Comparison::Type compare_type) {
  XlaBuilder* builder = lhs.builder();
  return builder->BinaryOp(HloOpcode::kCompare, lhs, rhs, broadcast_dimensions,
                           direction, compare_type);
}
// Comparison with no broadcast dimensions.
XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
  return Compare(lhs, rhs, /*broadcast_dimensions=*/{}, direction);
}
// Dot-product forwarders to lhs's XlaBuilder.
XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
          const PrecisionConfig* precision_config,
          absl::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = lhs.builder();
  return builder->Dot(lhs, rhs, precision_config, preferred_element_type);
}
XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
                 const DotDimensionNumbers& dimension_numbers,
                 const PrecisionConfig* precision_config,
                 absl::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = lhs.builder();
  return builder->DotGeneral(lhs, rhs, dimension_numbers, precision_config,
                             preferred_element_type);
}
// Convolution forwarders to lhs's XlaBuilder; arguments are passed through
// unchanged.
XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
           absl::Span<const int64_t> window_strides, Padding padding,
           int64_t feature_group_count, int64_t batch_group_count,
           const PrecisionConfig* precision_config,
           absl::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = lhs.builder();
  return builder->Conv(lhs, rhs, window_strides, padding, feature_group_count,
                       batch_group_count, precision_config,
                       preferred_element_type);
}
XlaOp ConvWithGeneralPadding(
    const XlaOp lhs, const XlaOp rhs, absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = lhs.builder();
  return builder->ConvWithGeneralPadding(
      lhs, rhs, window_strides, padding, feature_group_count,
      batch_group_count, precision_config, preferred_element_type);
}
XlaOp ConvWithGeneralDimensions(
    const XlaOp lhs, const XlaOp rhs, absl::Span<const int64_t> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config,
    absl::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = lhs.builder();
  return builder->ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding, dimension_numbers,
      feature_group_count, batch_group_count, precision_config,
      preferred_element_type);
}
XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
                  absl::Span<const int64_t> window_strides,
                  absl::Span<const std::pair<int64_t, int64_t>> padding,
                  const ConvolutionDimensionNumbers& dimension_numbers,
                  int64_t feature_group_count, int64_t batch_group_count,
                  const PrecisionConfig* precision_config,
                  absl::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = lhs.builder();
  return builder->ConvGeneral(lhs, rhs, window_strides, padding,
                              dimension_numbers, feature_group_count,
                              batch_group_count, precision_config,
                              preferred_element_type);
}
XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
                         absl::Span<const int64_t> window_strides,
                         absl::Span<const std::pair<int64_t, int64_t>> padding,
                         absl::Span<const int64_t> lhs_dilation,
                         absl::Span<const int64_t> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64_t feature_group_count, int64_t batch_group_count,
                         const PrecisionConfig* precision_config,
                         absl::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = lhs.builder();
  return builder->ConvGeneralDilated(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config, preferred_element_type);
}
// Dynamic-convolution forwarders.
//
// The input-gradient and forward variants use lhs's builder; the kernel
// gradient uses the activations' builder.
XlaOp DynamicConvInputGrad(
    XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = lhs.builder();
  return builder->DynamicConvInputGrad(
      input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
      rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}
XlaOp DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding,
    absl::Span<const int64_t> lhs_dilation,
    absl::Span<const int64_t> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64_t feature_group_count, int64_t batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = activations.builder();
  return builder->DynamicConvKernelGrad(
      activations, gradients, window_strides, padding, lhs_dilation,
      rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}
XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
                         absl::Span<const int64_t> window_strides,
                         absl::Span<const std::pair<int64_t, int64_t>> padding,
                         absl::Span<const int64_t> lhs_dilation,
                         absl::Span<const int64_t> rhs_dilation,
                         const ConvolutionDimensionNumbers& dimension_numbers,
                         int64_t feature_group_count, int64_t batch_group_count,
                         const PrecisionConfig* precision_config,
                         PaddingType padding_type,
                         absl::optional<PrimitiveType> preferred_element_type) {
  XlaBuilder* builder = lhs.builder();
  return builder->DynamicConvForward(
      lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
      dimension_numbers, feature_group_count, batch_group_count,
      precision_config, padding_type, preferred_element_type);
}
// Forwards to the operand's builder to emit an FFT of kind `fft_type`.
XlaOp Fft(const XlaOp operand, FftType fft_type,
          absl::Span<const int64_t> fft_length) {
  XlaBuilder* builder = operand.builder();
  return builder->Fft(operand, fft_type, fft_length);
}
// Emits a triangular solve of `a` against `b`.
//
// The boolean flags and `transpose_a` are packed into a TriangularSolveOptions
// proto. That proto is used for shape inference and is then handed to
// TriangularSolveInternal. Errors are reported through a's builder.
XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, builder->GetShapePtr(a));
    TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, builder->GetShapePtr(b));

    xla::TriangularSolveOptions solve_options;
    solve_options.set_left_side(left_side);
    solve_options.set_lower(lower);
    solve_options.set_unit_diagonal(unit_diagonal);
    solve_options.set_transpose_a(transpose_a);

    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferTriangularSolveShape(
                            *lhs_shape, *rhs_shape, solve_options));
    return builder->TriangularSolveInternal(result_shape, a, b,
                                            std::move(solve_options));
  });
}
// Emits a Cholesky decomposition of `a`.
//
// The result shape comes from ShapeInference::InferCholeskyShape. Errors are
// reported through a's builder.
XlaOp Cholesky(XlaOp a, bool lower) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* input_shape, builder->GetShapePtr(a));
    TF_ASSIGN_OR_RETURN(Shape result_shape,
                        ShapeInference::InferCholeskyShape(*input_shape));
    return builder->CholeskyInternal(result_shape, a, lower);
  });
}
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
const std::string& config) {
return builder->Infeed(shape, config);
}
void Outfeed(const XlaOp operand, const Shape& shape_with_layout,
const std::string& outfeed_config) {
return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
}
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands) {
return builder->Call(computation, operands);
}
// Custom-call forwarders.
//
// Arguments the public wrapper does not expose are forwarded as absent
// optionals (operand layouts, window, conv dimension numbers) where noted.
XlaOp CustomCall(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    const std::string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  // No operand layouts, window or conv dimension numbers.
  return builder->CustomCall(
      call_target_name, operands, shape, opaque,
      /*operand_shapes_with_layout=*/absl::nullopt, has_side_effect,
      output_operand_aliasing, literal,
      /*window=*/absl::nullopt,
      /*dnums=*/absl::nullopt, schedule, api_version);
}
XlaOp CustomCallWithComputation(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const XlaComputation& computation,
    const Shape& shape, const std::string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  // Uses the builder overload that takes a called computation; no operand
  // layouts.
  return builder->CustomCall(call_target_name, operands, computation, shape,
                             opaque,
                             /*operand_shapes_with_layout=*/absl::nullopt,
                             has_side_effect, output_operand_aliasing, literal,
                             schedule, api_version);
}
XlaOp CustomCallWithLayout(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    absl::Span<const Shape> operand_shapes_with_layout,
    const std::string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, CustomCallSchedule schedule,
    CustomCallApiVersion api_version) {
  // Operand layouts are always forwarded (even when the span is empty); no
  // window or conv dimension numbers.
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             operand_shapes_with_layout, has_side_effect,
                             output_operand_aliasing, literal,
                             /*window=*/absl::nullopt,
                             /*dnums=*/absl::nullopt, schedule, api_version);
}
// Emits a custom call that carries a convolution window and dimension numbers.
//
// An empty `operand_shapes_with_layout` is forwarded as an absent optional
// (no operand layout constraints). A non-empty one is forwarded as-is.
XlaOp CustomCallWithConvDnums(
    XlaBuilder* builder, const std::string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    absl::Span<const Shape> operand_shapes_with_layout,
    const std::string& opaque, bool has_side_effect,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>>
        output_operand_aliasing,
    const Literal* literal, Window window, ConvolutionDimensionNumbers dnums,
    CustomCallSchedule schedule, CustomCallApiVersion api_version) {
  absl::optional<absl::Span<const Shape>> maybe_operand_shapes;
  if (!operand_shapes_with_layout.empty()) {
    maybe_operand_shapes = operand_shapes_with_layout;
  }
  // `window` and `dnums` are by-value protos that are not used afterwards.
  // Move them into the call instead of copying them a second time.
  return builder->CustomCall(call_target_name, operands, shape, opaque,
                             maybe_operand_shapes, has_side_effect,
                             output_operand_aliasing, literal,
                             std::move(window), std::move(dnums), schedule,
                             api_version);
}
// Wraps `operand` in an optimization barrier on the operand's builder.
XlaOp OptimizationBarrier(XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->OptimizationBarrier(operand);
}
// Binary HloOpcode::kComplex combining `lhs` (real part) and `rhs`
// (imaginary part); `broadcast_dimensions` is forwarded unchanged.
XlaOp Complex(const XlaOp lhs, const XlaOp rhs,
              absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kComplex, lhs, rhs, broadcast_dimensions);
}
// Complex conjugate of `operand`, built as Complex(Real(x), -Imag(x)).
// NOTE: C++ leaves the evaluation order of the two Complex(...) arguments
// unspecified, so whether the Real or the Neg(Imag) instructions are added to
// the builder first is compiler-dependent.
XlaOp Conj(const XlaOp operand) {
return Complex(Real(operand), Neg(Imag(operand)));
}
// Binary HloOpcode::kAdd of `lhs` and `rhs` on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Add(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions);
}
// Binary HloOpcode::kSubtract (lhs - rhs) on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Sub(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions);
}
// Binary HloOpcode::kMultiply of `lhs` and `rhs` on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Mul(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions);
}
// Binary HloOpcode::kDivide (lhs / rhs) on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Div(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions);
}
// Binary HloOpcode::kRemainder of `lhs` by `rhs` on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Rem(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions);
}
// Binary HloOpcode::kMaximum of `lhs` and `rhs` on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Max(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions);
}
// Binary HloOpcode::kMinimum of `lhs` and `rhs` on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Min(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions);
}
// Binary HloOpcode::kAnd of `lhs` and `rhs` on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp And(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions);
}
// Binary HloOpcode::kOr of `lhs` and `rhs` on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Or(const XlaOp lhs, const XlaOp rhs,
         absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions);
}
// Binary HloOpcode::kXor of `lhs` and `rhs` on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Xor(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions);
}
// Unary HloOpcode::kNot on `operand`, built via the operand's builder.
XlaOp Not(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kNot, operand);
}
// Unary HloOpcode::kPopulationCount on `operand`, built via its builder.
XlaOp PopulationCount(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kPopulationCount, operand);
}
// Binary HloOpcode::kShiftLeft (lhs shifted by rhs) on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp ShiftLeft(const XlaOp lhs, const XlaOp rhs,
                absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions);
}
// Binary HloOpcode::kShiftRightArithmetic on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp ShiftRightArithmetic(const XlaOp lhs, const XlaOp rhs,
                           absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
                     broadcast_dimensions);
}
// Binary HloOpcode::kShiftRightLogical on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp ShiftRightLogical(const XlaOp lhs, const XlaOp rhs,
                        absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
                     broadcast_dimensions);
}
// Reduces `operand` over `dimensions_to_reduce`, starting from `init_value`
// and combining with `computation`; built on the operand's builder.
XlaOp Reduce(const XlaOp operand, const XlaOp init_value,
             const XlaComputation& computation,
             absl::Span<const int64_t> dimensions_to_reduce) {
  XlaBuilder* const b = operand.builder();
  return b->Reduce(operand, init_value, computation, dimensions_to_reduce);
}
// Reduces several arrays at once over `dimensions_to_reduce`, using
// `computation` as the reduction operator and `init_values` as the starting
// values. The instruction is added to the explicitly supplied `builder`.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64_t> dimensions_to_reduce) {
  return builder->Reduce(/*operands=*/operands, /*init_values=*/init_values,
                         computation, dimensions_to_reduce);
}
// Reduces `operand` over all of its dimensions with `computation`, starting
// from `init_value`; built on the operand's builder.
XlaOp ReduceAll(const XlaOp operand, const XlaOp init_value,
                const XlaComputation& computation) {
  XlaBuilder* const b = operand.builder();
  return b->ReduceAll(operand, init_value, computation);
}
// Windowed reduction of a single `operand` with the given window shape,
// strides and `padding` mode; built on the operand's builder.
XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64_t> window_dimensions,
                   absl::Span<const int64_t> window_strides, Padding padding) {
  XlaBuilder* const b = operand.builder();
  return b->ReduceWindow(operand, init_value, computation, window_dimensions,
                         window_strides, padding);
}
// Windowed reduction of several `operands` at once. The builder is taken from
// the first operand, so `operands` must be non-empty (CHECK-enforced).
XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64_t> window_dimensions,
                   absl::Span<const int64_t> window_strides, Padding padding) {
  CHECK(!operands.empty());
  XlaBuilder* const b = operands.front().builder();
  return b->ReduceWindow(operands, init_values, computation, window_dimensions,
                         window_strides, padding);
}
// Single-operand form of ReduceWindowWithGeneralPadding: wraps `operand` and
// `init_value` in one-element spans and delegates to the builder's
// multi-operand method.
XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp operand, const XlaOp init_value,
    const XlaComputation& computation,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const int64_t> base_dilations,
    absl::Span<const int64_t> window_dilations,
    absl::Span<const std::pair<int64_t, int64_t>> padding) {
  const absl::Span<const XlaOp> single_operand(&operand, 1);
  const absl::Span<const XlaOp> single_init(&init_value, 1);
  XlaBuilder* const b = operand.builder();
  return b->ReduceWindowWithGeneralPadding(
      single_operand, single_init, computation, window_dimensions,
      window_strides, base_dilations, window_dilations, padding);
}
// Multi-operand windowed reduction with explicit per-dimension padding and
// dilations. The builder comes from the first operand, so `operands` must be
// non-empty (CHECK-enforced).
XlaOp ReduceWindowWithGeneralPadding(
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const int64_t> base_dilations,
    absl::Span<const int64_t> window_dilations,
    absl::Span<const std::pair<int64_t, int64_t>> padding) {
  CHECK(!operands.empty());
  XlaBuilder* const b = operands.front().builder();
  return b->ReduceWindowWithGeneralPadding(
      operands, init_values, computation, window_dimensions, window_strides,
      base_dilations, window_dilations, padding);
}
// All-gather of `operand` along `all_gather_dimension` across `shard_count`
// shards; every optional argument is forwarded unchanged to the builder.
XlaOp AllGather(const XlaOp operand, int64_t all_gather_dimension,
                int64_t shard_count,
                absl::Span<const ReplicaGroup> replica_groups,
                const absl::optional<ChannelHandle>& channel_id,
                const absl::optional<Layout>& layout,
                const absl::optional<bool> use_global_device_ids) {
  XlaBuilder* const b = operand.builder();
  return b->AllGather(operand, all_gather_dimension, shard_count,
                      replica_groups, channel_id, layout,
                      use_global_device_ids);
}
// Cross-replica sum of `operand` over `replica_groups`, on its builder.
XlaOp CrossReplicaSum(const XlaOp operand,
                      absl::Span<const ReplicaGroup> replica_groups) {
  XlaBuilder* const b = operand.builder();
  return b->CrossReplicaSum(operand, replica_groups);
}
// All-reduce of `operand` with reduction `computation` over `replica_groups`;
// `channel_id` and `shape_with_layout` are forwarded unchanged.
XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
                absl::Span<const ReplicaGroup> replica_groups,
                const absl::optional<ChannelHandle>& channel_id,
                const absl::optional<Shape>& shape_with_layout) {
  XlaBuilder* const b = operand.builder();
  return b->AllReduce(operand, computation, replica_groups, channel_id,
                      shape_with_layout);
}
// Reduce-scatter of `operand` along `scatter_dimension` into `shard_count`
// shards; optional arguments are forwarded unchanged to the builder.
XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
                    int64_t scatter_dimension, int64_t shard_count,
                    absl::Span<const ReplicaGroup> replica_groups,
                    const absl::optional<ChannelHandle>& channel_id,
                    const absl::optional<Layout>& layout,
                    const absl::optional<bool> use_global_device_ids) {
  XlaBuilder* const b = operand.builder();
  return b->ReduceScatter(operand, computation, scatter_dimension,
                          shard_count, replica_groups, channel_id, layout,
                          use_global_device_ids);
}
// All-to-all exchange of `operand`: split along `split_dimension` into
// `split_count` pieces and concatenate along `concat_dimension`.
XlaOp AllToAll(const XlaOp operand, int64_t split_dimension,
               int64_t concat_dimension, int64_t split_count,
               absl::Span<const ReplicaGroup> replica_groups,
               const absl::optional<Layout>& layout) {
  XlaBuilder* const b = operand.builder();
  return b->AllToAll(operand, split_dimension, concat_dimension, split_count,
                     replica_groups, layout);
}
// Tuple-form all-to-all over several `operands`. The builder comes from the
// first operand, so `operands` must be non-empty (CHECK-enforced).
XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
                    absl::Span<const ReplicaGroup> replica_groups,
                    const absl::optional<Layout>& layout) {
  CHECK(!operands.empty());
  XlaBuilder* const b = operands.front().builder();
  return b->AllToAllTuple(operands, replica_groups, layout);
}
// Tuple-form all-to-all for a single `operand`, split/concatenated along the
// given dimensions; built on the operand's builder.
XlaOp AllToAllTuple(const XlaOp operand, int64_t split_dimension,
                    int64_t concat_dimension, int64_t split_count,
                    absl::Span<const ReplicaGroup> replica_groups,
                    const absl::optional<Layout>& layout) {
  XlaBuilder* const b = operand.builder();
  return b->AllToAllTuple(operand, split_dimension, concat_dimension,
                          split_count, replica_groups, layout);
}
// Collective permute of `operand` according to `source_target_pairs`.
XlaOp CollectivePermute(
    const XlaOp operand,
    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs) {
  XlaBuilder* const b = operand.builder();
  return b->CollectivePermute(operand, source_target_pairs);
}
// Emits a replica-id op on the supplied `builder`.
XlaOp ReplicaId(XlaBuilder* builder) {
  return builder->ReplicaId();
}
// Select-and-scatter over windows of `operand` using the `select` and
// `scatter` computations and a `padding` mode; built on the operand's builder.
XlaOp SelectAndScatter(const XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64_t> window_dimensions,
                       absl::Span<const int64_t> window_strides,
                       Padding padding, const XlaOp source,
                       const XlaOp init_value, const XlaComputation& scatter) {
  XlaBuilder* const b = operand.builder();
  return b->SelectAndScatter(operand, select, window_dimensions,
                             window_strides, padding, source, init_value,
                             scatter);
}
// Select-and-scatter with explicit (low, high) padding pairs per dimension;
// built on the operand's builder.
XlaOp SelectAndScatterWithGeneralPadding(
    const XlaOp operand, const XlaComputation& select,
    absl::Span<const int64_t> window_dimensions,
    absl::Span<const int64_t> window_strides,
    absl::Span<const std::pair<int64_t, int64_t>> padding, const XlaOp source,
    const XlaOp init_value, const XlaComputation& scatter) {
  XlaBuilder* const b = operand.builder();
  return b->SelectAndScatterWithGeneralPadding(operand, select,
                                               window_dimensions,
                                               window_strides, padding,
                                               source, init_value, scatter);
}
// Unary HloOpcode::kAbs on `operand`, built via the operand's builder.
XlaOp Abs(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kAbs, operand);
}
// Binary HloOpcode::kAtan2 of `lhs` and `rhs` on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Atan2(const XlaOp lhs, const XlaOp rhs,
            absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kAtan2, lhs, rhs, broadcast_dimensions);
}
// Unary HloOpcode::kExp on `operand`, built via the operand's builder.
XlaOp Exp(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kExp, operand);
}
// Unary HloOpcode::kExpm1 on `operand`, built via the operand's builder.
XlaOp Expm1(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kExpm1, operand);
}
// Unary HloOpcode::kFloor on `operand`, built via the operand's builder.
XlaOp Floor(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kFloor, operand);
}
// Unary HloOpcode::kCeil on `operand`, built via the operand's builder.
XlaOp Ceil(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kCeil, operand);
}
// Rounds `operand` using HloOpcode::kRoundNearestAfz (as the opcode name
// indicates, halfway cases round away from zero).
XlaOp Round(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
}
// Unary HloOpcode::kLog on `operand`, built via the operand's builder.
XlaOp Log(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kLog, operand);
}
// Unary HloOpcode::kLog1p on `operand`, built via the operand's builder.
XlaOp Log1p(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kLog1p, operand);
}
// Unary HloOpcode::kLogistic on `operand`, built via the operand's builder.
XlaOp Logistic(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kLogistic, operand);
}
// Unary HloOpcode::kSign on `operand`, built via the operand's builder.
XlaOp Sign(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kSign, operand);
}
// Unary HloOpcode::kClz (count leading zeros) on `operand`.
XlaOp Clz(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kClz, operand);
}
// Unary HloOpcode::kCos on `operand`, built via the operand's builder.
XlaOp Cos(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kCos, operand);
}
// Unary HloOpcode::kSin on `operand`, built via the operand's builder.
XlaOp Sin(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kSin, operand);
}
// Unary HloOpcode::kTanh on `operand`, built via the operand's builder.
XlaOp Tanh(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kTanh, operand);
}
// Unary HloOpcode::kReal (real part) on `operand`.
XlaOp Real(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kReal, operand);
}
// Unary HloOpcode::kImag (imaginary part) on `operand`.
XlaOp Imag(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kImag, operand);
}
// Unary HloOpcode::kSqrt on `operand`, built via the operand's builder.
XlaOp Sqrt(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kSqrt, operand);
}
// Unary HloOpcode::kCbrt (cube root) on `operand`.
XlaOp Cbrt(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kCbrt, operand);
}
// Unary HloOpcode::kRsqrt (reciprocal square root) on `operand`.
XlaOp Rsqrt(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kRsqrt, operand);
}
// Binary HloOpcode::kPower (lhs raised to rhs) on lhs's builder;
// `broadcast_dimensions` is forwarded unchanged.
XlaOp Pow(const XlaOp lhs, const XlaOp rhs,
          absl::Span<const int64_t> broadcast_dimensions) {
  XlaBuilder* const b = lhs.builder();
  return b->BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions);
}
// Unary HloOpcode::kIsFinite on `operand`, built via the operand's builder.
XlaOp IsFinite(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kIsFinite, operand);
}
// Converts `operand` to `new_element_type` via the builder's
// ConvertElementType.
XlaOp ConvertElementType(const XlaOp operand, PrimitiveType new_element_type) {
  XlaBuilder* const b = operand.builder();
  return b->ConvertElementType(operand, new_element_type);
}
// Bitcast-converts `operand` to `new_element_type` via the builder's
// BitcastConvertType.
XlaOp BitcastConvertType(const XlaOp operand, PrimitiveType new_element_type) {
  XlaBuilder* const b = operand.builder();
  return b->BitcastConvertType(operand, new_element_type);
}
// Unary HloOpcode::kNegate on `operand`, built via the operand's builder.
XlaOp Neg(const XlaOp operand) {
  XlaBuilder* const b = operand.builder();
  return b->UnaryOp(HloOpcode::kNegate, operand);
}
// Transposes `operand` according to `permutation`.
XlaOp Transpose(const XlaOp operand, absl::Span<const int64_t> permutation) {
  XlaBuilder* const b = operand.builder();
  return b->Transpose(operand, permutation);
}
// Reverses `operand` along each of `dimensions`.
XlaOp Rev(const XlaOp operand, absl::Span<const int64_t> dimensions) {
  XlaBuilder* const b = operand.builder();
  return b->Rev(operand, dimensions);
}
// Sorts `operands` along `dimension` using `comparator`; `is_stable` is
// forwarded to the builder. The builder is taken from the first operand.
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64_t dimension, bool is_stable) {
  // Indexing operands[0] on an empty span is undefined behavior. Fail loudly
  // instead, matching the other multi-operand wrappers (ReduceWindow,
  // ReduceWindowWithGeneralPadding, AllToAllTuple).
  CHECK(!operands.empty());
  return operands[0].builder()->Sort(operands, comparator, dimension,
                                     is_stable);
}
// Clamps `operand` between `min` and `max`; built on min's builder.
XlaOp Clamp(const XlaOp min, const XlaOp operand, const XlaOp max) {
  XlaBuilder* const b = min.builder();
  return b->Clamp(min, operand, max);
}
// Applies `computation` to `operands` over `dimensions` on the supplied
// `builder`; `static_operands` is forwarded unchanged.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation,
          absl::Span<const int64_t> dimensions,
          absl::Span<const XlaOp> static_operands) {
  return builder->Map(operands, computation, dimensions,
                      /*static_operands=*/static_operands);
}
// Normal-distribution RNG op with parameters `mu` and `sigma`, producing
// `shape`; built on mu's builder.
XlaOp RngNormal(const XlaOp mu, const XlaOp sigma, const Shape& shape) {
  XlaBuilder* const b = mu.builder();
  return b->RngNormal(mu, sigma, shape);
}
// Uniform-distribution RNG op with bounds `a` and `b`, producing `shape`;
// built on a's builder.
XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
  XlaBuilder* const owner = a.builder();
  return owner->RngUniform(a, b, shape);
}
// Bit-generator RNG op using `algorithm` seeded from `initial_state`,
// producing `shape`; built on the state's builder.
XlaOp RngBitGenerator(RandomAlgorithm algorithm, const XlaOp initial_state,
                      const Shape& shape) {
  XlaBuilder* const b = initial_state.builder();
  return b->RngBitGenerator(algorithm, initial_state, shape);
}
// While loop with `condition` and `body` computations starting from `init`;
// built on init's builder.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            const XlaOp init) {
  XlaBuilder* const b = init.builder();
  return b->While(condition, body, init);
}
// Two-way conditional: selects `true_computation` or `false_computation`
// based on `predicate`; built on the predicate's builder.
XlaOp Conditional(const XlaOp predicate, const XlaOp true_operand,
                  const XlaComputation& true_computation,
                  const XlaOp false_operand,
                  const XlaComputation& false_computation) {
  XlaBuilder* const b = predicate.builder();
  return b->Conditional(predicate, true_operand, true_computation,
                        false_operand, false_computation);
}
// N-way conditional selected by `branch_index` over `branch_computations`,
// each paired with an entry of `branch_operands`; built on the index's
// builder.
XlaOp Conditional(const XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands) {
  XlaBuilder* const b = branch_index.builder();
  return b->Conditional(branch_index, branch_computations, branch_operands);
}
// Reduces the floating-point precision of `operand` to the given exponent
// and mantissa bit counts; built on the operand's builder.
XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits,
                      const int mantissa_bits) {
  XlaBuilder* const b = operand.builder();
  return b->ReducePrecision(operand, exponent_bits, mantissa_bits);
}
// Gathers slices of `input` at `start_indices` with the given dimension
// numbers and `slice_sizes`; built on input's builder.
XlaOp Gather(const XlaOp input, const XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64_t> slice_sizes, bool indices_are_sorted) {
  XlaBuilder* const b = input.builder();
  return b->Gather(input, start_indices, dimension_numbers, slice_sizes,
                   indices_are_sorted);
}
// Single-input scatter of `updates` into `input` at `scatter_indices`,
// combined with `update_computation`; built on input's builder.
XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices,
              const XlaOp updates, const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted, bool unique_indices) {
  XlaBuilder* const b = input.builder();
  return b->Scatter(input, scatter_indices, updates, update_computation,
                    dimension_numbers, indices_are_sorted, unique_indices);
}
// Multi-input scatter. The builder comes from `scatter_indices`, so an empty
// `inputs` span is not dereferenced here.
XlaOp Scatter(absl::Span<const XlaOp> inputs, XlaOp scatter_indices,
              absl::Span<const XlaOp> updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted, bool unique_indices) {
  XlaBuilder* const b = scatter_indices.builder();
  return b->Scatter(inputs, scatter_indices, updates, update_computation,
                    dimension_numbers, indices_are_sorted, unique_indices);
}
void Send(const XlaOp operand, const ChannelHandle& handle) {
return operand.builder()->Send(operand, handle);
}
// Emits a Recv of a value of `shape` over channel `handle` on `builder`.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle) {
  return builder->Recv(/*shape=*/shape, /*handle=*/handle);
}
// Token-ordered Send of `operand` over channel `handle`.
XlaOp SendWithToken(const XlaOp operand, const XlaOp token,
                    const ChannelHandle& handle) {
  XlaBuilder* const b = operand.builder();
  return b->SendWithToken(operand, token, handle);
}
// Token-ordered Recv of a value of `shape` over channel `handle`; built on
// the token's builder.
XlaOp RecvWithToken(const XlaOp token, const Shape& shape,
                    const ChannelHandle& handle) {
  XlaBuilder* const b = token.builder();
  return b->RecvWithToken(token, shape, handle);
}
// Sends `operand` to the host over `handle` with layout
// `shape_with_layout`, ordered by `token`.
XlaOp SendToHost(const XlaOp operand, const XlaOp token,
                 const Shape& shape_with_layout, const ChannelHandle& handle) {
  XlaBuilder* const b = operand.builder();
  return b->SendToHost(operand, token, shape_with_layout, handle);
}
// Receives a value of `shape` from the host over `handle`, ordered by
// `token`; built on the token's builder.
XlaOp RecvFromHost(const XlaOp token, const Shape& shape,
                   const ChannelHandle& handle) {
  XlaBuilder* const b = token.builder();
  return b->RecvFromHost(token, shape, handle);
}
// Token-ordered infeed of a value of `shape` with the given `config` string.
XlaOp InfeedWithToken(const XlaOp token, const Shape& shape,
                      const std::string& config) {
  XlaBuilder* const b = token.builder();
  return b->InfeedWithToken(token, shape, config);
}
// Token-ordered outfeed of `operand` with layout `shape_with_layout` and the
// given `outfeed_config`; built on the operand's builder.
XlaOp OutfeedWithToken(const XlaOp operand, const XlaOp token,
                       const Shape& shape_with_layout,
                       const std::string& outfeed_config) {
  XlaBuilder* const b = operand.builder();
  return b->OutfeedWithToken(operand, token, shape_with_layout,
                             outfeed_config);
}
// Creates a fresh token on the supplied `builder`.
XlaOp CreateToken(XlaBuilder* builder) {
  return builder->CreateToken();
}
// Joins `tokens` into a single token on the supplied `builder`.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens) {
  return builder->AfterAll(/*tokens=*/tokens);
}
// Training-mode batch norm of `operand` with `scale`/`offset`, `epsilon`, and
// feature axis `feature_index`; built on the operand's builder.
XlaOp BatchNormTraining(const XlaOp operand, const XlaOp scale,
                        const XlaOp offset, float epsilon,
                        int64_t feature_index) {
  XlaBuilder* const b = operand.builder();
  return b->BatchNormTraining(operand, scale, offset, epsilon, feature_index);
}
// Inference-mode batch norm using the supplied `mean` and `variance`;
// built on the operand's builder.
XlaOp BatchNormInference(const XlaOp operand, const XlaOp scale,
                         const XlaOp offset, const XlaOp mean,
                         const XlaOp variance, float epsilon,
                         int64_t feature_index) {
  XlaBuilder* const b = operand.builder();
  return b->BatchNormInference(operand, scale, offset, mean, variance, epsilon,
                               feature_index);
}
// Batch-norm gradient given batch statistics and `grad_output`; built on the
// operand's builder.
XlaOp BatchNormGrad(const XlaOp operand, const XlaOp scale,
                    const XlaOp batch_mean, const XlaOp batch_var,
                    const XlaOp grad_output, float epsilon,
                    int64_t feature_index) {
  XlaBuilder* const b = operand.builder();
  return b->BatchNormGrad(operand, scale, batch_mean, batch_var, grad_output,
                          epsilon, feature_index);
}
// Iota of element `type` with `size` elements on the supplied `builder`.
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size) {
  return builder->Iota(/*type=*/type, /*size=*/size);
}
// Iota producing `shape`, counting along `iota_dimension`, on `builder`.
XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension) {
  return builder->Iota(/*shape=*/shape, /*iota_dimension=*/iota_dimension);
}
// Queries the size of `dimension` of `operand` on its builder.
XlaOp GetDimensionSize(const XlaOp operand, int64_t dimension) {
  XlaBuilder* const b = operand.builder();
  return b->GetDimensionSize(operand, dimension);
}
// Sets the size of `dimension` of `operand` to `val` on the operand's
// builder.
XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val,
                       int64_t dimension) {
  XlaBuilder* const b = operand.builder();
  return b->SetDimensionSize(operand, val, dimension);
}
// Removes the dynamic marking of `dimension` of `operand` via the builder's
// RemoveDynamicDimension.
XlaOp RemoveDynamicDimension(const XlaOp operand, int64_t dimension) {
  XlaBuilder* const b = operand.builder();
  return b->RemoveDynamicDimension(operand, dimension);
}
} // namespace xla