| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/pjrt/utils.h" |
| |
| #include "absl/container/flat_hash_set.h" |
| #include "tensorflow/compiler/xla/client/executable_build_options.h" |
| #include "tensorflow/compiler/xla/client/xla_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo.pb.h" |
| #include "tensorflow/compiler/xla/service/hlo_opcode.h" |
| #include "tensorflow/compiler/xla/service/hlo_sharding.h" |
| #include "tensorflow/compiler/xla/shape.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| |
| namespace xla { |
| |
| namespace { |
| StatusOr<Shape> GetShardedShape(const Shape& shape, |
| const OpSharding& sharding) { |
| if (sharding.type() == OpSharding::TUPLE) { |
| if (!shape.IsTuple()) { |
| return InvalidArgument( |
| "Got tuple OpSharding (%s) for non-tuple shape (%s)", |
| sharding.DebugString(), shape.ToString()); |
| } |
| if (sharding.tuple_shardings_size() != shape.tuple_shapes_size()) { |
| return InvalidArgument( |
| "Got mismatched OpSharding tuple size (%d) and shape tuple size (%d)." |
| " (OpSharding: %s, shape: %s)", |
| sharding.tuple_shardings_size(), shape.tuple_shapes_size(), |
| sharding.DebugString(), shape.ToString()); |
| } |
| std::vector<Shape> sharded_subshapes; |
| const int tuple_shapes_size = shape.tuple_shapes_size(); |
| sharded_subshapes.reserve(tuple_shapes_size); |
| for (int i = 0; i < tuple_shapes_size; ++i) { |
| TF_ASSIGN_OR_RETURN( |
| Shape sharded_subshape, |
| GetShardedShape(shape.tuple_shapes(i), sharding.tuple_shardings(i))); |
| sharded_subshapes.emplace_back(std::move(sharded_subshape)); |
| } |
| return ShapeUtil::MakeTupleShape(sharded_subshapes); |
| } |
| TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding, |
| HloSharding::FromProto(sharding)); |
| return hlo_sharding.TileShape(shape); |
| } |
| |
| StatusOr<Shape> GetShardedShape(const HloInstructionProto& instr) { |
| const Shape unsharded_shape(instr.shape()); |
| Shape sharded_shape; |
| if (instr.has_sharding()) { |
| TF_ASSIGN_OR_RETURN(sharded_shape, |
| GetShardedShape(unsharded_shape, instr.sharding())); |
| } else { |
| sharded_shape = unsharded_shape; |
| } |
| LayoutUtil::ClearLayout(&sharded_shape); |
| return sharded_shape; |
| } |
| |
| // Returns sharded (argument shapes, result shape) without layouts. |
| StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes( |
| const XlaComputation& computation, const ProgramShape& program_shape) { |
| std::vector<Shape> arg_shapes; |
| arg_shapes.resize(program_shape.parameters_size()); |
| Shape result_shape; |
| for (const HloComputationProto& comp : computation.proto().computations()) { |
| if (comp.id() != computation.proto().entry_computation_id()) { |
| continue; |
| } |
| for (const HloInstructionProto& instr : comp.instructions()) { |
| if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) { |
| if (instr.parameter_number() >= program_shape.parameters_size()) { |
| return InvalidArgument( |
| "Got invalid parameter number %d, expected %d parameters", |
| instr.parameter_number(), program_shape.parameters_size()); |
| } |
| TF_ASSIGN_OR_RETURN(arg_shapes[instr.parameter_number()], |
| GetShardedShape(instr)); |
| } |
| if (instr.id() == comp.root_id()) { |
| if (result_shape.element_type() != PRIMITIVE_TYPE_INVALID) { |
| return InvalidArgument("Found multiple root instructions"); |
| } |
| TF_ASSIGN_OR_RETURN(result_shape, GetShardedShape(instr)); |
| } |
| } |
| } |
| for (int i = 0; i < arg_shapes.size(); ++i) { |
| if (arg_shapes[i].element_type() == PRIMITIVE_TYPE_INVALID) { |
| return InvalidArgument("Couldn't find parameter %d", i); |
| } |
| } |
| if (result_shape.element_type() == PRIMITIVE_TYPE_INVALID) { |
| return InvalidArgument("Couldn't find root instruction"); |
| } |
| return std::make_pair(arg_shapes, result_shape); |
| } |
| } // namespace |
| |
| Status ParseDeviceAssignmentCompileOptions( |
| bool compile_portable_executable, ExecutableBuildOptions* build_options, |
| std::function<StatusOr<DeviceAssignment>(int, int)> |
| GetDefaultDeviceAssignmentFunction, |
| int* num_replicas, int* num_partitions, |
| std::shared_ptr<DeviceAssignment>* device_assignment) { |
| if (compile_portable_executable) { |
| if (build_options->has_device_assignment()) { |
| return InvalidArgument( |
| "CompileOptions requests portable executable but " |
| "ExecutableBuildOptions includes a device assignment"); |
| } |
| *num_replicas = 1; |
| *num_partitions = 1; |
| } else { |
| if (!build_options->has_device_assignment()) { |
| VLOG(2) << "Compile using default device_assignment."; |
| TF_ASSIGN_OR_RETURN( |
| DeviceAssignment device_assignment, |
| GetDefaultDeviceAssignmentFunction(build_options->num_replicas(), |
| build_options->num_partitions())); |
| build_options->set_device_assignment(device_assignment); |
| } |
| VLOG(2) << "Compile device_assignment:\n" |
| << build_options->device_assignment().ToString(); |
| *num_replicas = build_options->device_assignment().replica_count(); |
| *num_partitions = build_options->device_assignment().computation_count(); |
| *device_assignment = |
| std::make_shared<DeviceAssignment>(build_options->device_assignment()); |
| } |
| return Status::OK(); |
| } |
| |
| Status DetermineArgumentLayoutsFromCompileOptions( |
| const XlaComputation& computation, |
| std::function<StatusOr<Shape>(Shape)> |
| choose_compact_layout_for_shape_function, |
| std::optional<std::vector<Shape>>& argument_layouts, |
| ExecutableBuildOptions* build_options, |
| std::vector<const Shape*>* argument_layout_pointers) { |
| TF_ASSIGN_OR_RETURN(ProgramShape program_shape, |
| computation.GetProgramShape()); |
| if (!argument_layouts) { |
| argument_layouts.emplace(program_shape.parameters()); |
| for (Shape& shape : *argument_layouts) { |
| LayoutUtil::ClearLayout(&shape); |
| } |
| } else if (argument_layouts->size() != program_shape.parameters_size()) { |
| return InvalidArgument( |
| "CompileOptions specify %d argument layouts, but computation has %d " |
| "arguments", |
| argument_layouts->size(), program_shape.parameters_size()); |
| } |
| argument_layout_pointers->reserve(argument_layouts->size()); |
| |
| // Assign a default layout based on `sharded_shape` to any array subshapes in |
| // `dst_shape` that are missing layouts. |
| auto assign_layouts = [&choose_compact_layout_for_shape_function]( |
| const Shape& sharded_shape, Shape* dst_shape) { |
| return ShapeUtil::ForEachMutableSubshapeWithStatus( |
| dst_shape, [&](Shape* subshape, const ShapeIndex& idx) { |
| if (subshape->IsArray() && !subshape->has_layout()) { |
| CHECK(ShapeUtil::IndexIsValid(sharded_shape, idx)); |
| const Shape& sharded_subshape = |
| ShapeUtil::GetSubshape(sharded_shape, idx); |
| LayoutUtil::SetToDefaultLayout(subshape); |
| TF_ASSIGN_OR_RETURN( |
| Shape layout, |
| choose_compact_layout_for_shape_function(sharded_subshape)); |
| *subshape->mutable_layout() = layout.layout(); |
| } |
| return Status::OK(); |
| }); |
| }; |
| TF_ASSIGN_OR_RETURN(auto sharded_shapes, |
| GetShardedProgramShapes(computation, program_shape)); |
| |
| CHECK_EQ(sharded_shapes.first.size(), argument_layouts->size()); |
| for (int i = 0; i < argument_layouts->size(); ++i) { |
| Shape* layout = &(*argument_layouts)[i]; |
| argument_layout_pointers->push_back(layout); |
| TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.first[i], layout)); |
| } |
| |
| Shape result_layout; |
| if (build_options->result_layout()) { |
| result_layout = *build_options->result_layout(); |
| } else { |
| result_layout = program_shape.result(); |
| LayoutUtil::ClearLayout(&result_layout); |
| } |
| TF_RETURN_IF_ERROR(assign_layouts(sharded_shapes.second, &result_layout)); |
| build_options->set_result_layout(result_layout); |
| return Status::OK(); |
| } |
| |
| StatusOr<std::vector<int>> ComputeParametersThatMustBeDonated( |
| const HloModule& module, bool tuple_inputs) { |
| HloComputation* computation = module.entry_computation(); |
| int number_of_parameters = [&]() -> int { |
| if (tuple_inputs) { |
| CHECK_EQ(computation->num_parameters(), 1); |
| const Shape& input_tuple_shape = |
| computation->parameter_instruction(0)->shape(); |
| CHECK(input_tuple_shape.IsTuple()); |
| return input_tuple_shape.tuple_shapes_size(); |
| } else { |
| return computation->num_parameters(); |
| } |
| }(); |
| // If any buffer in a parameter is aliased we will donate the entire input |
| // parameter. |
| std::vector<int> parameters_to_donate; |
| parameters_to_donate.reserve(computation->num_parameters()); |
| const HloInputOutputAliasConfig& config = module.input_output_alias_config(); |
| TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus( |
| [&](const ShapeIndex& output_index, |
| const HloInputOutputAliasConfig::Alias& alias) { |
| if (tuple_inputs) { |
| if (alias.parameter_number != 0) { |
| return InvalidArgument( |
| "Unexpected parameter number %d in alias config with tupled " |
| "inputs", |
| alias.parameter_number); |
| } |
| const ShapeIndex& index = alias.parameter_index; |
| if (!index.empty()) { |
| int this_parameter = index.data()[0]; |
| if (this_parameter >= number_of_parameters) { |
| return InvalidArgument( |
| "Unexpected parameter index %s in alias config with tupled " |
| "inputs and %d parameters", |
| index.ToString(), number_of_parameters); |
| } |
| parameters_to_donate.push_back(this_parameter); |
| } |
| } else { |
| int this_parameter = alias.parameter_number; |
| if (this_parameter >= number_of_parameters) { |
| return InvalidArgument( |
| "Unexpected parameter number %d in alias config without tupled " |
| "inputs and %d parameters", |
| this_parameter, number_of_parameters); |
| } |
| parameters_to_donate.push_back(this_parameter); |
| } |
| return Status::OK(); |
| })); |
| absl::c_sort(parameters_to_donate); |
| return parameters_to_donate; |
| } |
| |
| int DefaultThreadPoolSize() { |
| // Google's CI system exposes an environment variable NPROC that describes |
| // a CPU reservation for tests. |
| // TODO(phawkins): expose a better thought-out set of knobs to control |
| // parallelism. |
| const char* nproc_str = std::getenv("NPROC"); |
| int nproc = 0; |
| if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) { |
| return std::max(0, nproc); |
| } |
| return tensorflow::port::MaxParallelism(); |
| } |
| |
| bool HasMajorToMinorLayout(PrimitiveType type, absl::Span<int64_t const> dims, |
| absl::Span<int64_t const> byte_strides) { |
| CHECK_EQ(dims.size(), byte_strides.size()); |
| // If the array is size 0, the strides are irrelevant. |
| if (absl::c_find(dims, 0) != dims.end()) { |
| return true; |
| } |
| int64_t stride = primitive_util::ByteWidth(type); |
| for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) { |
| // If a dimension is of size 1, its stride is irrelevant. |
| if (dims[i] != 1) { |
| if (byte_strides[i] != stride) { |
| return false; |
| } |
| stride *= dims[i]; |
| } |
| } |
| return true; |
| } |
| |
| } // namespace xla |