blob: 99af3797e2fa4fa733569b374e201e02ed08a2db [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/pjrt/utils.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
// Computes the per-shard shape of `shape` under `sharding`.
//
// A non-TUPLE sharding is converted to an HloSharding and used to tile `shape`
// directly. A TUPLE sharding must be paired with a tuple shape whose element
// count equals the number of tuple shardings. Each element is then sharded
// recursively with its own sharding, and the results are reassembled into a
// tuple shape.
//
// Returns InvalidArgument if a TUPLE sharding is given for a non-tuple shape
// or if the tuple sizes disagree. Errors from HloSharding::FromProto are
// propagated.
StatusOr<Shape> GetShardedShape(const Shape& shape,
                                const OpSharding& sharding) {
  if (sharding.type() != OpSharding::TUPLE) {
    TF_ASSIGN_OR_RETURN(HloSharding tile_sharding,
                        HloSharding::FromProto(sharding));
    return tile_sharding.TileShape(shape);
  }
  if (!shape.IsTuple()) {
    return InvalidArgument(
        "Got tuple OpSharding (%s) for non-tuple shape (%s)",
        sharding.DebugString(), shape.ToString());
  }
  if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) {
    return InvalidArgument(
        "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)."
        " (OpSharding: %s, shape: %s)",
        sharding.tuple_shardings_size(), shape.tuple_shapes_size(),
        sharding.DebugString(), shape.ToString());
  }
  const int num_elements = shape.tuple_shapes_size();
  std::vector<Shape> element_shapes;
  element_shapes.reserve(num_elements);
  for (int element = 0; element < num_elements; ++element) {
    TF_ASSIGN_OR_RETURN(Shape sharded_element,
                        GetShardedShape(shape.tuple_shapes(element),
                                        sharding.tuple_shardings(element)));
    element_shapes.push_back(std::move(sharded_element));
  }
  return ShapeUtil::MakeTupleShape(element_shapes);
}
// Returns the per-shard shape of `instr`'s output with its layout cleared.
//
// When `instr` carries no sharding, its full (unsharded) shape is used as-is,
// apart from the layout being cleared.
StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) {
  const Shape full_shape(instr.shape());
  Shape per_shard_shape = full_shape;
  if (instr.has_sharding()) {
    TF_ASSIGN_OR_RETURN(per_shard_shape,
                        GetShardedShape(full_shape, instr.sharding()));
  }
  LayoutUtil::ClearLayout(&per_shard_shape);
  return per_shard_shape;
}
// Returns sharded (argument shapes, result shape) without layouts.
//
// Scans the entry computation of `computation` for its parameter instructions
// and its root instruction, and computes each one's per-shard shape with the
// layout cleared (see GetShardedShape). `program_shape` supplies the expected
// number of parameters.
//
// Returns InvalidArgument if a parameter number is out of range or appears
// more than once, if more than one root instruction is found, or if any
// parameter or the root cannot be found.
StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
    const XlaComputation& computation, const ProgramShape& program_shape) {
  const int num_parameters = program_shape.parameters_size();
  // A default-constructed Shape has element type PRIMITIVE_TYPE_INVALID. That
  // value is used below as the "not found yet" marker.
  std::vector<Shape> arg_shapes(num_parameters);
  Shape result_shape;
  // Loop-invariant: compute the parameter opcode string only once.
  const auto parameter_opcode = HloOpcodeString(HloOpcode::kParameter);
  const int64_t entry_id = computation.proto().entry_computation_id();
  for (const HloComputationProto& comp : computation.proto().computations()) {
    if (comp.id() != entry_id) {
      continue;
    }
    for (const HloInstructionProto& instr : comp.instructions()) {
      if (instr.opcode() == parameter_opcode) {
        const int64_t param = instr.parameter_number();
        // Negative numbers must be rejected as well; they would index
        // `arg_shapes` out of bounds.
        if (param < 0 || param >= num_parameters) {
          return InvalidArgument(
              "Got invalid parameter number %d, expected %d parameters", param,
              num_parameters);
        }
        if (arg_shapes[param].element_type() != PRIMITIVE_TYPE_INVALID) {
          return InvalidArgument(
              "Found multiple parameter instructions with number %d", param);
        }
        TF_ASSIGN_OR_RETURN(arg_shapes[param], GetShardedShape(instr));
      }
      if (instr.id() == comp.root_id()) {
        if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) {
          return InvalidArgument("Found multiple root instructions");
        }
        TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr));
      }
    }
  }
  for (int i = 0; i < num_parameters; ++i) {
    if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) {
      return InvalidArgument("Couldn't find parameter %d", i);
    }
  }
  if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) {
    return InvalidArgument("Couldn't find root instruction");
  }
  return std::make_pair(std::move(arg_shapes), std::move(result_shape));
}
} // namespace
// Resolves the replica count, partition count and device assignment to use
// for compilation.
//
// Portable executable:
//   - `build_options` must not carry a device assignment; otherwise
//     InvalidArgument is returned.
//   - The counts are set to one replica and one partition.
//   - `*device_assignment` is left untouched.
//
// Otherwise:
//   - A missing device assignment is obtained from
//     `GetDefaultDeviceAssignmentFunction(num_replicas, num_partitions)` and
//     stored back into `build_options`.
//   - The replica/computation counts and a shared copy of the resulting
//     assignment are written to the output parameters.
Status ParseDeviceAssignmentCompileOptions(
    bool compile_portable_executable, ExecutableBuildOptions* build_options,
    std::function<StatusOr<DeviceAssignment>(int, int)>
        GetDefaultDeviceAssignmentFunction,
    int* num_replicas, int* num_partitions,
    std::shared_ptr<DeviceAssignment>* device_assignment) {
  if (!compile_portable_executable) {
    if (!build_options->has_device_assignment()) {
      VLOG(2) << "Compile using default device_assignment.";
      TF_ASSIGN_OR_RETURN(
          DeviceAssignment default_assignment,
          GetDefaultDeviceAssignmentFunction(build_options->num_replicas(),
                                             build_options->num_partitions()));
      build_options->set_device_assignment(default_assignment);
    }
    const DeviceAssignment& chosen_assignment =
        build_options->device_assignment();
    VLOG(2) << "Compile device_assignment:\n" << chosen_assignment.ToString();
    *num_replicas = chosen_assignment.replica_count();
    *num_partitions = chosen_assignment.computation_count();
    *device_assignment =
        std::make_shared<DeviceAssignment>(chosen_assignment);
    return Status::OK();
  }
  if (build_options->has_device_assignment()) {
    return InvalidArgument(
        "CompileOptions requests portable executable but "
        "ExecutableBuildOptions includes a device assignment");
  }
  *num_replicas = 1;
  *num_partitions = 1;
  return Status::OK();
}
// Fills in default layouts for the arguments and the result of `computation`.
//
// Argument layouts:
//   - If `argument_layouts` is empty, it is populated from the computation's
//     parameter shapes with their layouts cleared.
//   - Otherwise its size must equal the number of parameters.
//
// Result layout:
//   - Taken from `build_options` when present.
//   - Otherwise taken from the program's result shape with its layout cleared.
//
// Every array subshape that still lacks a layout is assigned the layout that
// `choose_compact_layout_for_shape_function` picks for the corresponding
// *sharded* subshape. Pointers into `*argument_layouts` are appended to
// `argument_layout_pointers`, and the final result layout is stored back into
// `build_options`.
//
// Returns InvalidArgument if the number of argument layouts does not match the
// number of parameters. It also returns InvalidArgument if a layout shape has a
// missing-layout array subshape at an index that does not exist in the
// corresponding sharded shape.
Status DetermineArgumentLayoutsFromCompileOptions(
    const XlaComputation& computation,
    std::function<StatusOr<Shape>(Shape)>
        choose_compact_layout_for_shape_function,
    std::optional<std::vector<Shape>>& argument_layouts,
    ExecutableBuildOptions* build_options,
    std::vector<const Shape*>* argument_layout_pointers) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation.GetProgramShape());
  if (!argument_layouts) {
    argument_layouts.emplace(program_shape.parameters());
    for (Shape& shape : *argument_layouts) {
      LayoutUtil::ClearLayout(&shape);
    }
  } else if (argument_layouts->size() != program_shape.parameters_size()) {
    return InvalidArgument(
        "CompileOptions specify %d argument layouts, but computation has %d "
        "arguments",
        argument_layouts->size(), program_shape.parameters_size());
  }
  argument_layout_pointers->reserve(argument_layouts->size());

  // Assign a default layout based on `sharded_shape` to any array subshapes in
  // `dst_shape` that are missing layouts. Caller-supplied layouts may not match
  // the computation's tuple structure. That case is reported as an error
  // instead of CHECK-failing the process.
  auto assign_layouts = [&choose_compact_layout_for_shape_function](
                            const Shape& sharded_shape,
                            Shape* dst_shape) -> Status {
    return ShapeUtil::ForEachMutableSubshapeWithStatus(
        dst_shape, [&](Shape* subshape, const ShapeIndex& idx) -> Status {
          if (!subshape->IsArray() || subshape->has_layout()) {
            return Status::OK();
          }
          if (!ShapeUtil::IndexIsValid(sharded_shape, idx)) {
            return InvalidArgument(
                "Subshape %s at index %s has no corresponding subshape in "
                "sharded shape %s",
                subshape->ToString(), idx.ToString(),
                sharded_shape.ToString());
          }
          const Shape& sharded_subshape =
              ShapeUtil::GetSubshape(sharded_shape, idx);
          LayoutUtil::SetToDefaultLayout(subshape);
          TF_ASSIGN_OR_RETURN(
              Shape shape_with_layout,
              choose_compact_layout_for_shape_function(sharded_subshape));
          *subshape->mutable_layout() = shape_with_layout.layout();
          return Status::OK();
        });
  };
  TF_ASSIGN_OR_RETURN(auto sharded_shapes,
                      GetShardedProgramShapes(computation, program_shape));
  // GetShardedProgramShapes yields one shape per program parameter. The size
  // handling above gives `argument_layouts` the same count, so this is an
  // internal invariant rather than a user error.
  CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size());
  for (size_t i = 0; i < argument_layouts->size(); ++i) {
    Shape* layout = &(*argument_layouts)[i];
    argument_layout_pointers->push_back(layout);
    TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout));
  }
  Shape result_layout;
  if (build_options->result_layout()) {
    result_layout = *build_options->result_layout();
  } else {
    result_layout = program_shape.result();
    LayoutUtil::ClearLayout(&result_layout);
  }
  TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout));
  build_options->set_result_layout(result_layout);
  return Status::OK();
}
// Returns the sorted, duplicate-free list of parameters that must be donated
// because the module's input/output alias config aliases (a buffer of) them
// with an output.
//
// With `tuple_inputs`:
//   - The entry computation must take a single tuple parameter (CHECKed).
//   - The reported "parameters" are indices into that tuple, taken from the
//     first element of each alias's parameter index.
//   - Aliases with an empty parameter index are skipped.
//
// Without `tuple_inputs`, the reported numbers are the entry computation's
// parameter numbers.
//
// Returns InvalidArgument for an alias that refers to an out-of-range
// parameter, or, with tupled inputs, to any parameter number other than 0.
StatusOr<std::vector<int>> ComputeParametersThatMustBeDonated(
    const HloModule& module, bool tuple_inputs) {
  HloComputation* computation = module.entry_computation();
  int number_of_parameters = [&]() -> int {
    if (tuple_inputs) {
      CHECK_EQ(computation->num_parameters(), 1);
      const Shape& input_tuple_shape =
          computation->parameter_instruction(0)->shape();
      CHECK(input_tuple_shape.IsTuple());
      return input_tuple_shape.tuple_shapes_size();
    } else {
      return computation->num_parameters();
    }
  }();
  // If any buffer in a parameter is aliased we will donate the entire input
  // parameter.
  std::vector<int> parameters_to_donate;
  parameters_to_donate.reserve(number_of_parameters);
  const HloInputOutputAliasConfig& config = module.input_output_alias_config();
  TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) -> Status {
        if (tuple_inputs) {
          if (alias.parameter_number != 0) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config with tupled "
                "inputs",
                alias.parameter_number);
          }
          const ShapeIndex& index = alias.parameter_index;
          if (!index.empty()) {
            // Range-check the full-width value before narrowing it to int.
            const int64_t this_parameter = index.data()[0];
            if (this_parameter < 0 || this_parameter >= number_of_parameters) {
              return InvalidArgument(
                  "Unexpected parameter index %s in alias config with tupled "
                  "inputs and %d parameters",
                  index.ToString(), number_of_parameters);
            }
            parameters_to_donate.push_back(static_cast<int>(this_parameter));
          }
        } else {
          const int64_t this_parameter = alias.parameter_number;
          if (this_parameter < 0 || this_parameter >= number_of_parameters) {
            return InvalidArgument(
                "Unexpected parameter number %d in alias config without tupled "
                "inputs and %d parameters",
                this_parameter, number_of_parameters);
          }
          parameters_to_donate.push_back(static_cast<int>(this_parameter));
        }
        return Status::OK();
      }));
  // One parameter can be aliased several times when more than one of its
  // buffers is aliased. The whole parameter is donated, so list it only once.
  absl::c_sort(parameters_to_donate);
  parameters_to_donate.erase(
      std::unique(parameters_to_donate.begin(), parameters_to_donate.end()),
      parameters_to_donate.end());
  return parameters_to_donate;
}
// Returns the thread count to use for a default thread pool.
//
// If the NPROC environment variable is set and parses as an integer, that value
// is returned, with negative values clamped to 0. Otherwise the result of
// tensorflow::port::MaxParallelism() is returned.
int DefaultThreadPoolSize() {
  // Google's CI system exposes an environment variable NPROC that describes
  // a CPU reservation for tests.
  // TODO(phawkins): expose a better thought-out set of knobs to control
  // parallelism.
  if (const char* env_value = std::getenv("NPROC"); env_value != nullptr) {
    int parsed = 0;
    if (absl::SimpleAtoi(env_value, &parsed)) {
      return parsed < 0 ? 0 : parsed;
    }
  }
  return tensorflow::port::MaxParallelism();
}
// Returns true if `byte_strides` match a dense major-to-minor layout for an
// array of element type `type` with dimensions `dims`.
//
// That layout means the last dimension is contiguous and each earlier
// dimension's stride is the product of the later sizes times the element
// width. An array with any zero-sized dimension is always accepted. The stride
// of a size-1 dimension is ignored. CHECK-fails if `dims` and `byte_strides`
// differ in length.
bool HasMajorToMinorLayout(PrimitiveType type, absl::Span<int64_t const> dims,
                           absl::Span<int64_t const> byte_strides) {
  CHECK_EQ(dims.size(), byte_strides.size());
  // If the array is size 0, the strides are irrelevant.
  if (absl::c_linear_search(dims, 0)) {
    return true;
  }
  // Walk from the minor-most dimension outward. Track the stride that a dense
  // major-to-minor layout would require at each step.
  int64_t expected_stride = primitive_util::ByteWidth(type);
  for (size_t remaining = dims.size(); remaining > 0; --remaining) {
    const size_t dim = remaining - 1;
    // If a dimension is of size 1, its stride is irrelevant.
    if (dims[dim] == 1) {
      continue;
    }
    if (byte_strides[dim] != expected_stride) {
      return false;
    }
    expected_stride *= dims[dim];
  }
  return true;
}
} // namespace xla