| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <memory> |
#include <string>

#include "absl/algorithm/container.h"
#include "llvm/ADT/ArrayRef.h"
| #include "llvm/ADT/Optional.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/Support/Casting.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/Block.h" // from @llvm-project |
| #include "mlir/IR/Builders.h" // from @llvm-project |
| #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/Operation.h" // from @llvm-project |
| #include "mlir/IR/Value.h" // from @llvm-project |
| #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Pass/PassRegistry.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h" |
| #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" |
| #include "tensorflow/compiler/xla/client/sharding_builder.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| |
| namespace mlir { |
| namespace TFTPU { |
| namespace { |
| |
| constexpr char kReplicateSharding[] = ""; |
| constexpr char kShardingAttr[] = "mhlo.sharding"; |
| constexpr char kUseSpmdAttr[] = "use_spmd_for_xla_partitioning"; |
| constexpr char kAliasingAttr[] = "tf.aliasing_output"; |
| |
| struct TPUShardingIdentificationPass |
| : public TF::TPUShardingIdentificationPassBase< |
| TPUShardingIdentificationPass> { |
| void runOnOperation() final; |
| }; |
| |
// Returns the XLA sharding from a TPUPartitionedInput op connected to a
// `tf_device.cluster_func` operand value. If the value is a resource type,
// the TPUPartitionedInput op will be connected to a ReadVariableOp that feeds
// into the `tf_device.cluster_func`.
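//
// For example, in an illustrative IR sketch like the following (types and
// attribute values elided, names hypothetical):
//
//   %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "..."}
//   %1 = "tf_device.cluster_func"(%0) {func = @computation}
//
// the operand %0 yields the `_XlaSharding` attribute of its defining
// TPUPartitionedInput op.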
| llvm::Optional<llvm::StringRef> GetXlaShardingFromOperand(Value value) { |
| Value value_to_visit = value; |
| if (auto read_var = value_to_visit.getDefiningOp<TF::ReadVariableOp>()) |
| value_to_visit = read_var.resource(); |
| |
| if (auto partitioned_input = |
| value_to_visit.getDefiningOp<TF::TPUPartitionedInputOp>()) |
| return partitioned_input._XlaSharding(); |
| |
| return llvm::None; |
| } |
| |
// Given a `tf_device.cluster_func` operand value, returns true iff it is a
// device variable that should default to MAXIMAL sharding. Device variables
// that are per-replica or distributed default to MAXIMAL sharding, which
// corresponds to arguments of the `tf_device.replicate`. Otherwise the
// variable is broadcast, which corresponds to edges that are implicitly
// captured by the `replicate`.
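//
// For instance, in a hypothetical sketch such as
//
//   tf_device.replicate(...) {
//     %v = "tf.ReadVariableOp"(%resource)
//     "tf_device.cluster_func"(%v) ...
//   }
//
// the ReadVariableOp lives inside a `tf_device.replicate` region, so the
// variable is per-replica and defaults to MAXIMAL sharding.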
| bool IsMaximalVariable(Value value) { |
| auto read_var = value.getDefiningOp<TF::ReadVariableOp>(); |
| return read_var && read_var->getParentOfType<tf_device::ReplicateOp>(); |
| } |
| |
// Verifies whether the given sharding can be applied to the given (tensor)
// type. (A bad sharding might mean failing tf.Split ops if the graph later
// executes on CPU.)
// Returns failure if the sharding is incorrect; returns success if it is
// valid or if it cannot be verified.
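//
// As a hypothetical example: an OTHER sharding with two
// tile_assignment_dimensions (say a [2, 1] tiling) cannot be applied to a
// rank-1 tensor such as tensor<8xf32>, so verification fails.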
| LogicalResult VerifySharding(Type type, StringRef sharding_string) { |
| xla::OpSharding sharding; |
| if (!sharding.ParseFromString(sharding_string.str())) { |
| // Some test cases use \01\02\03 as sharding, to test propagation. Treat |
| // a non-proto sharding as valid, and don't verify further. |
| return success(); |
| } |
| if (sharding.type() != xla::OpSharding::OTHER) { |
| // We currently only verify shardings that actually break a tensor apart. |
| return success(); |
| } |
| if (RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>()) { |
| if (ranked_type.getRank() < sharding.tile_assignment_dimensions_size()) { |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| |
// Verifies sharding for all arguments and return values.
| LogicalResult VerifyShardings( |
| mlir::func::FuncOp func, |
| const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args, |
| const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_rets) { |
| Block& function_block = func.front(); |
| for (auto sharding_and_arg : |
| llvm::zip(sharding_for_args, function_block.getArguments())) { |
| StringRef sharding = std::get<0>(sharding_and_arg); |
| BlockArgument arg = std::get<1>(sharding_and_arg); |
| if (failed(VerifySharding(arg.getType(), sharding))) return failure(); |
| } |
| Operation* terminator = function_block.getTerminator(); |
| for (auto sharding_and_retval : |
| llvm::zip(sharding_for_rets, terminator->getOpOperands())) { |
| StringRef sharding = std::get<0>(sharding_and_retval); |
| OpOperand& retval = std::get<1>(sharding_and_retval); |
| if (failed(VerifySharding(retval.get().getType(), sharding))) |
| return failure(); |
| } |
| return success(); |
| } |
| |
// Returns the XLA sharding from an XlaSharding op connected to an argument
// value. If the value is a resource type, the XlaSharding op will be connected
// to a ReadVariableOp. The XlaSharding op may be a direct user of the input,
// but the input may also first pass through an Identity op and, when the
// bfloat16 type is used, a Cast op may be added right after the input.
| // |
| // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If, |
| // Case, While) ops and Caller return values. |
| // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded |
| // inputs. |
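//
// For example, in a hypothetical traversal sketch (types elided)
//
//   func @computation(%arg0: tensor<8xf32>) {
//     %0 = "tf.Identity"(%arg0)
//     %1 = "tf.XlaSharding"(%0) {_XlaSharding = "..."}
//     ...
//
// walking the uses of %arg0 through the Identity op reaches the XlaSharding
// op and returns its `_XlaSharding` attribute.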
| llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(Value value) { |
| llvm::SmallPtrSet<Value, 4> visited_values; |
| llvm::SmallVector<Value, 4> values_to_visit{value}; |
| while (!values_to_visit.empty()) { |
| llvm::SmallVector<Value, 4> next_values_to_visit; |
| for (Value value_to_visit : values_to_visit) { |
| if (!visited_values.insert(value_to_visit).second) continue; |
| |
| for (auto& use : value_to_visit.getUses()) { |
| Operation* owner = use.getOwner(); |
| if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner)) |
| return sharding._XlaSharding(); |
| |
| if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(owner)) { |
| next_values_to_visit.push_back(use.getOwner()->getResult(0)); |
| continue; |
| } |
| |
| if (auto call_op = llvm::dyn_cast<CallOpInterface>(owner)) { |
| func::FuncOp func = |
| llvm::dyn_cast<func::FuncOp>(call_op.resolveCallable()); |
| if (!func) continue; |
| next_values_to_visit.push_back( |
| func.getArgument(use.getOperandNumber())); |
| } |
| } |
| } |
| |
| values_to_visit.swap(next_values_to_visit); |
| } |
| |
| return llvm::None; |
| } |
| |
// Extracts sharding configurations for all inputs by parsing the XlaSharding/
// TPUPartitionedInput ops connected to the operands/arguments. If an argument
// to the `cluster_func` directly feeds into another function call op, then
// recursively walks the function definition to find the connected XlaSharding
// op.
| void IdentifyXlaShardingForComputationInputs( |
| StringRef logical_core_0_sharding, bool use_spmd, |
| bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, |
| func::FuncOp func, Builder* builder, |
| llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args) { |
| // Look up function definition from module. |
| Block& function_block = func.front(); |
| |
| sharding_for_args.reserve(function_block.getNumArguments()); |
| |
  // Iterate through operands of `cluster_func`.
  // The computation operand can either be:
  //   1) a TPUPartitionedInput op if the input has a non-resource type;
  //   2) a ReadVariableOp otherwise.
  //
  // Replicate sharding is used if `use_spmd` is set.
  //
  // Iterate through the input arguments to the entry block of
  // tf_device.ClusterFunc. For input ops, look for XlaSharding ops.
  // XlaSharding ops can:
  //   1) Directly follow the input argument if the input argument has a
  //      non-resource type.
  //   2) Follow a ReadVariableOp if the input type is a resource type.
  //   3) Follow an IdentityOp or CastOp after cases (1) and (2) above.
  //
  // Sharding configurations are added to the tf_device.ClusterFunc as an
  // attribute and to the function as an argument attribute.
| for (auto operand_and_arg : |
| llvm::zip(cluster_func.operands(), function_block.getArguments())) { |
| Value operand = std::get<0>(operand_and_arg); |
| BlockArgument arg = std::get<1>(operand_and_arg); |
| |
| if (auto operand_sharding = GetXlaShardingFromOperand(operand)) { |
| sharding_for_args.push_back(operand_sharding.getValue()); |
| continue; |
| } |
| |
| if (infer_from_computation) { |
| auto arg_sharding = GetXlaShardingFromArg(arg); |
| if (arg_sharding) { |
| sharding_for_args.push_back(arg_sharding.getValue()); |
| continue; |
| } |
| } |
| |
| if (use_spmd && !IsMaximalVariable(operand)) { |
      // If XLA SPMD is enabled, host variables or non-variable per-replica
      // inputs should take on replicate sharding, so that every device gets
      // the whole tensor(s) (and can slice them up later). Exclude device
      // variables, which should always take maximal sharding.
| sharding_for_args.push_back(kReplicateSharding); |
| continue; |
| } |
| |
| // Otherwise, default to maximal sharding core 0. |
| sharding_for_args.push_back(logical_core_0_sharding); |
| } |
| } |
| |
// Returns the XLA sharding from a TPUPartitionedOutput op, or from a
// TPUPartitionedInput op (via an AssignVariableOp/resource write), connected
// to a `tf_device.cluster_func` result value.
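//
// For example, in an illustrative sketch (names hypothetical)
//
//   %0 = "tf_device.cluster_func"(...) {func = @computation}
//   "tf.TPUPartitionedOutput"(%0) {_XlaSharding = "..."}
//
// the result %0 yields the `_XlaSharding` attribute of its single
// TPUPartitionedOutput user.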
| llvm::Optional<llvm::StringRef> GetXlaShardingFromResult(Value value) { |
| if (!value.hasOneUse()) return llvm::None; |
| |
| Operation* user = *value.getUsers().begin(); |
| if (auto partitioned_output = |
| llvm::dyn_cast<TF::TPUPartitionedOutputOp>(user)) |
| return partitioned_output._XlaSharding(); |
| |
| if (auto assign_var = llvm::dyn_cast<TF::AssignVariableOp>(user)) |
| if (auto partitioned_input = |
| assign_var.resource().getDefiningOp<TF::TPUPartitionedInputOp>()) |
| return partitioned_input._XlaSharding(); |
| |
| return llvm::None; |
| } |
| |
// Looks up the arg->retval alias for every argument and builds the reverse
// map (return value index -> argument index).
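//
// For example, a function argument declared as (hypothetical attribute value)
//
//   %arg2: tensor<...> {tf.aliasing_output = 1 : i64}
//
// aliases return value #1, so `aliases[1]` is set to 2.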
| void ExtractAliases(func::FuncOp func, llvm::SmallVectorImpl<int>& aliases) { |
| aliases.resize(func.getNumResults(), -1); |
| for (int i = 0; i < func.getNumArguments(); i++) { |
| if (auto v = func.getArgAttrOfType<mlir::IntegerAttr>(i, kAliasingAttr)) { |
| int retval_index = v.getInt(); |
| if (retval_index >= 0 && retval_index < aliases.size()) { |
| aliases[retval_index] = i; |
| } |
| } |
| } |
| } |
| |
| // Returns XLA sharding from argument connected via tf.aliasing_output. |
| llvm::Optional<StringRef> GetXlaShardingFromAlias( |
| Value value, llvm::SmallVectorImpl<int>& aliases, |
| const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args) { |
| int retval_index = value.cast<OpResult>().getResultNumber(); |
| if (retval_index >= 0 && retval_index < aliases.size()) { |
| int arg_index = aliases[retval_index]; |
| if (arg_index >= 0 && arg_index < sharding_for_args.size()) { |
| return sharding_for_args[arg_index]; |
| } |
| } |
| return llvm::None; |
| } |
| |
// Returns the XLA sharding from an XlaSharding op connected to a result value.
// The XlaSharding op may be directly connected to the output, but it may also
// be followed by Identity or simple arithmetic ops. When the bfloat16 type is
// used, we might also see a Cast op.
| // |
| // TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If, |
| // Case, While) ops and Caller argument values. |
| // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded |
| // inputs. |
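//
// For example, in a hypothetical sketch (types elided)
//
//   %0 = "tf.XlaSharding"(%x) {_XlaSharding = "..."}
//   %1 = "tf.Identity"(%0)
//   return %1
//
// walking backwards from the returned value through the Identity op reaches
// the XlaSharding op and returns its `_XlaSharding` attribute.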
| llvm::Optional<StringRef> GetXlaShardingFromRetval(Value value) { |
| llvm::SmallPtrSet<Value, 4> visited_values; |
| llvm::SmallVector<Value, 4> values_to_visit; |
| values_to_visit.push_back(value); |
| |
| while (!values_to_visit.empty()) { |
| Value value_to_visit = values_to_visit.pop_back_val(); |
| |
| if (!visited_values.insert(value_to_visit).second) { |
| continue; |
| } |
| |
| Operation* def = value_to_visit.getDefiningOp(); |
| if (!def) { |
| continue; |
| } |
| |
| if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def)) |
| return sharding._XlaSharding(); |
| |
| if ( // Cast, real/imag, etc. |
| def->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() || |
| // Exp, ceil, etc. |
| def->hasTrait<mlir::OpTrait::SameOperandsAndResultType>() || |
| // Identity |
| def->hasTrait<mlir::OpTrait::TF::OperandsSameAsResultsTypeOrRef>() || |
| // AddV2, Sub, etc. |
| (def->hasTrait< |
| mlir::OpTrait::TF::SameOperandsAndResultElementTypeResolveRef>() && |
| def->hasTrait<mlir::OpTrait::TF::CwiseBinary>())) { |
| for (auto operand : def->getOperands()) { |
| values_to_visit.push_back(operand); |
| } |
| continue; |
| } |
| |
| if (auto call_op = llvm::dyn_cast_or_null<CallOpInterface>(def)) { |
| func::FuncOp func = |
| llvm::dyn_cast<func::FuncOp>(call_op.resolveCallable()); |
| if (!func) continue; |
| value_to_visit = func.front().getTerminator()->getOperand( |
| value_to_visit.cast<OpResult>().getResultNumber()); |
| values_to_visit.push_back(value_to_visit); |
| continue; |
| } |
| } |
| |
| return llvm::None; |
| } |
| |
// Extracts sharding configurations for all outputs by parsing the XlaSharding/
// TPUPartitionedOutput ops connected to the retvals/results.
| void IdentifyXlaShardingForComputationOutputs( |
| StringRef logical_core_0_sharding, bool use_spmd, |
| bool infer_from_computation, tf_device::ClusterFuncOp cluster_func, |
| func::FuncOp func, Builder* builder, |
| const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args, |
| llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_rets) { |
| Block& function_block = func.front(); |
| Operation* terminator = function_block.getTerminator(); |
| sharding_for_rets.reserve(terminator->getNumOperands()); |
| |
| llvm::SmallVector<int, 8> aliases; // maps return value index to arg index |
| ExtractAliases(func, aliases); |
| |
  // Iterate through the results of `cluster_func`. For output ops, look for
  // TPUPartitionedOutput ops.
  //
  // Replicate sharding is used if `use_spmd` is set.
  //
  // Iterate through the operands of the terminator. If the preceding op is an
  // XlaShardingOp, then the provided sharding configuration is added to the
  // tf_device.ClusterFunc as an attribute and to the function as a result
  // attribute.
| for (auto result_and_retval : |
| llvm::zip(cluster_func.results(), terminator->getOpOperands())) { |
| Value result = std::get<0>(result_and_retval); |
| OpOperand& retval = std::get<1>(result_and_retval); |
| |
| if (auto result_sharding = GetXlaShardingFromResult(result)) { |
| sharding_for_rets.push_back(result_sharding.getValue()); |
| continue; |
| } |
| |
| if (auto from_alias = |
| GetXlaShardingFromAlias(result, aliases, sharding_for_args)) { |
| sharding_for_rets.push_back(from_alias.getValue()); |
| continue; |
| } |
| |
| if (infer_from_computation) { |
| if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) { |
| sharding_for_rets.push_back(retval_sharding.getValue()); |
| continue; |
| } |
| } |
| |
| if (use_spmd) { |
| // If XLA SPMD is enabled, we default to replicate sharding. This way, |
| // all devices get the whole tensor(s), but if there's an XlaSharding op |
| // deeper in the function, they can use dynamic-slice to slice off their |
| // part of the computation. |
| sharding_for_rets.push_back(kReplicateSharding); |
| continue; |
| } |
| |
| // Otherwise, default to maximal sharding core 0. |
| sharding_for_rets.push_back(logical_core_0_sharding); |
| } |
| } |
| |
| // Extracts input/output sharding configuration of `cluster_func` by parsing |
| // XlaSharding ops inside the `cluster_func`. |
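//
// As an illustrative before/after sketch (attribute values abbreviated), a
// cluster like
//
//   %0 = "tf_device.cluster_func"(%arg0) {func = @computation}
//
// is annotated with the identified shardings under the
// tensorflow::kInputShardingAttr / tensorflow::kOutputShardingAttr attribute
// names, and @computation's arguments/results receive matching
// `mhlo.sharding` attributes.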
| void IdentifyXlaShardingForTPUComputation( |
| Builder* builder, tf_device::ClusterFuncOp cluster_func) { |
| // Look up function definition from module. |
| func::FuncOp func = |
| cluster_func->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>( |
| cluster_func.func()); |
| |
  // By default, inputs/outputs have maximal sharding and are assigned to
  // logical core 0 if no sharding is defined.
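  // (xla::sharding_builder::AssignDevice(0) serializes an xla::OpSharding
  // proto of type MAXIMAL whose tile_assignment_devices is [0], i.e. the
  // whole tensor is placed on logical device 0.)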
| const std::string logical_core_0_sharding = |
| xla::sharding_builder::AssignDevice(0).SerializeAsString(); |
| |
| bool use_spmd = false; |
| if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(kUseSpmdAttr)) |
| use_spmd = use_spmd_attr.getValue(); |
| |
| llvm::SmallVector<llvm::StringRef, 8> sharding_for_args; |
| IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, use_spmd, |
| /*infer_from_computation=*/true, |
| cluster_func, func, builder, |
| sharding_for_args); |
| |
| llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets; |
| IdentifyXlaShardingForComputationOutputs( |
| logical_core_0_sharding, use_spmd, /*infer_from_computation=*/true, |
| cluster_func, func, builder, sharding_for_args, sharding_for_rets); |
| |
| auto has_maximal_sharding = [](llvm::StringRef sharding_string) -> bool { |
| xla::OpSharding sharding; |
| sharding.ParseFromString(sharding_string.str()); |
| return sharding.type() == xla::OpSharding::MAXIMAL; |
| }; |
| |
  // XLA SPMD only supports cases where all inputs/outputs exist on every
  // partition (sharded or replicated). If any of the inputs/outputs have
  // maximal sharding, then fall back to MPMD. Also fall back if any of the
  // shardings aren't compatible with the rank of their tensor.
| if ((use_spmd && (absl::c_any_of(sharding_for_args, has_maximal_sharding) || |
| absl::c_any_of(sharding_for_rets, has_maximal_sharding))) || |
| failed(VerifyShardings(func, sharding_for_args, sharding_for_rets))) { |
    LOG(WARNING) << "XLA SPMD only supports cases where all inputs/outputs "
                    "exist on every partition (sharded or replicated). If any "
                    "of the inputs/outputs have maximal sharding, we fall "
                    "back to MPMD.";
| sharding_for_args.clear(); |
| sharding_for_rets.clear(); |
| cluster_func->setAttr(kUseSpmdAttr, builder->getBoolAttr(false)); |
| |
| IdentifyXlaShardingForComputationInputs( |
| logical_core_0_sharding, |
| /*use_spmd=*/false, /*infer_from_computation=*/false, cluster_func, |
| func, builder, sharding_for_args); |
| IdentifyXlaShardingForComputationOutputs( |
| logical_core_0_sharding, |
| /*use_spmd=*/false, /*infer_from_computation=*/false, cluster_func, |
| func, builder, sharding_for_args, sharding_for_rets); |
| } |
| |
| // Update sharding on function arguments and returns. |
| Block& function_block = func.front(); |
| for (auto sharding_and_arg : |
| llvm::zip(sharding_for_args, function_block.getArguments())) { |
| StringRef sharding = std::get<0>(sharding_and_arg); |
| BlockArgument arg = std::get<1>(sharding_and_arg); |
| func.setArgAttr(arg.getArgNumber(), kShardingAttr, |
| builder->getStringAttr(sharding)); |
| } |
| |
| Operation* terminator = function_block.getTerminator(); |
| for (auto sharding_and_retval : |
| llvm::zip(sharding_for_rets, terminator->getOpOperands())) { |
| StringRef sharding = std::get<0>(sharding_and_retval); |
| OpOperand& retval = std::get<1>(sharding_and_retval); |
| func.setResultAttr(retval.getOperandNumber(), kShardingAttr, |
| builder->getStringAttr(sharding)); |
| } |
| |
| // Update input/output sharding attributes on tf_device.cluster_func op. |
| cluster_func->setAttr(tensorflow::kInputShardingAttr, |
| builder->getStrArrayAttr(sharding_for_args)); |
| cluster_func->setAttr(tensorflow::kOutputShardingAttr, |
| builder->getStrArrayAttr(sharding_for_rets)); |
| } |
| |
| void TPUShardingIdentificationPass::runOnOperation() { |
| Builder builder(getOperation().getContext()); |
| |
| getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) { |
| IdentifyXlaShardingForTPUComputation(&builder, cluster_func); |
| }); |
| } |
| |
| } // anonymous namespace |
| |
| std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass() { |
| return std::make_unique<TPUShardingIdentificationPass>(); |
| } |
| |
| } // namespace TFTPU |
| } // namespace mlir |