/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <iterator>
#include <memory>
#include <tuple>
#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/DebugStringHelper.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
namespace mlir {
namespace TFTPU {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kNameAttr[] = "name";
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
constexpr char kBadReplicateInfoAttrMsg[] =
"requires '_replication_info' string attribute";
// Mapping from `_replication_info` attribute to TPUReplicateMetadata attributes.
using MetadataMap = llvm::SmallDenseMap<llvm::StringRef, NamedAttrList, 8>;
// A set of operations. We use a `SmallSetVector` in order to have deterministic
// traversal order (= insertion order), independent of the pointer keys.
using OpSetVector = llvm::SmallSetVector<Operation*, 8>;
// Mapping from `_replication_info` attribute to ops of a cluster.
using ClusterMap = llvm::SmallDenseMap<llvm::StringRef, OpSetVector, 8>;
struct TPUClusterFormationPass
: public TF::TPUClusterFormationPassBase<TPUClusterFormationPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<tf_device::TensorFlowDeviceDialect>();
}
void runOnOperation() override;
};
// Creates a mapping from each TPUReplicateMetadata op's `_replication_info`
// attribute to the op's remaining attributes, and removes the op. If multiple
// TPUReplicateMetadata ops have the same `_replication_info` attribute, an
// error is returned.
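//
// For example (an illustrative sketch; the attribute values are hypothetical),
// the op
//
//   "tf.TPUReplicateMetadata"() {_replication_info = "cluster_0",
//                                num_replicas = 2 : i64} : () -> ()
//
// yields the map entry "cluster_0" -> {num_replicas = 2 : i64, ...} and is
// then erased from the block.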
LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) {
// Only look at top-level operations in the block (not nested ones).
for (Operation& op : llvm::make_early_inc_range(*block)) {
auto metadata_op = dyn_cast<TF::TPUReplicateMetadataOp>(op);
if (!metadata_op) continue;
NamedAttrList attrs(metadata_op->getAttrDictionary());
// Missing or bad `_replication_info` attribute.
auto tpu_replicate_attr = attrs.get(TF::kReplicationInfoAttr);
if (!tpu_replicate_attr)
return metadata_op.emitError() << kBadReplicateInfoAttrMsg;
auto tpu_replicate_attr_str = tpu_replicate_attr.dyn_cast<StringAttr>();
if (!tpu_replicate_attr_str || tpu_replicate_attr_str.getValue().empty())
return metadata_op.emitError() << kBadReplicateInfoAttrMsg;
// Remove `name` attribute.
attrs.erase(StringAttr::get(metadata_op.getContext(), kNameAttr));
auto it = metadata_map->try_emplace(tpu_replicate_attr_str.getValue(),
std::move(attrs));
// There are multiple TPUReplicateMetadata ops with the same
// `_replication_info` attribute.
if (!it.second) {
return metadata_op.emitError()
<< "multiple TPUReplicateMetadata ops with the same '"
<< TF::kReplicationInfoAttr << "' attribute '"
<< tpu_replicate_attr_str.getValue() << "' found";
}
metadata_op.erase();
}
return success();
}
// Collects and groups ops with the same `_replication_info` attribute into
// clusters. Returns an error if an op's compilation or replication attributes
// are invalid (e.g. an empty `_replication_info`).
LogicalResult CollectAndGroupClusterOps(Block* block, ClusterMap* clusters) {
for (Operation& op : *block) {
if (op.hasAttr(TF::kReplicationInfoAttr) ||
op.hasAttr(TF::kCompileDeviceTypeAttr)) {
auto result = TF::HasValidCompilationAndReplicationAttributes(op);
if (failed(result)) return result;
auto attr = op.getAttrOfType<StringAttr>(TF::kReplicationInfoAttr);
auto it = clusters->try_emplace(attr.getValue());
it.first->getSecond().insert(&op);
}
}
return success();
}
// Returns true iff `op` has a direct control dependency from (`incoming` ==
// true) or to (`incoming` == false) any op in `cluster_ops` or
// `cluster_dependent_ops`.
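//
// For example (sketch), if a cluster op writes to a resource variable and a
// later non-cluster op reads from the same variable, side effect analysis
// reports a direct control successor edge even though there is no data edge
// between the two ops.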
bool hasOpClusterControlDependency(
Operation* op, bool incoming, const OpSetVector& cluster_ops,
const OpSetVector& cluster_dependent_ops,
const TF::SideEffectAnalysis::Info& side_effect_analysis) {
auto filter = [&](Operation* other_op) {
return cluster_ops.contains(other_op) ||
cluster_dependent_ops.contains(other_op);
};
return incoming ? !side_effect_analysis.DirectControlPredecessors(op, filter)
.empty()
: !side_effect_analysis.DirectControlSuccessors(op, filter)
.empty();
}
// Returns true iff `op` has a direct data dependency from (`incoming` == true)
// or to (`incoming` == false) any op in `cluster_ops` or
// `cluster_dependent_ops`.
bool hasOpClusterDataDependency(Operation* op, bool incoming,
const OpSetVector& cluster_ops,
const OpSetVector& cluster_dependent_ops) {
auto result = op->walk([&](Operation* inner_op) {
ValueRange values = incoming ? ValueRange(inner_op->getOperands())
: ValueRange(inner_op->getResults());
llvm::SmallVector<Operation*, 4> candidates;
for (Value value : values) {
if (incoming) {
candidates = {value.getDefiningOp()};
} else {
candidates.assign(value.getUsers().begin(), value.getUsers().end());
}
for (Operation* candidate_op : candidates) {
if (cluster_ops.contains(candidate_op) ||
cluster_dependent_ops.contains(candidate_op)) {
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
return result.wasInterrupted();
}
// Collects ops that need to be moved behind the cluster due to data or control
// dependencies.
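//
// For example (a sketch with hypothetical op names), given the block
//
//   %0 = "tf.ClusterOpA"() ...       // in cluster
//   %1 = "tf.NonClusterOp"(%0) ...   // interleaved, depends on the cluster
//   %2 = "tf.ClusterOpB"() ...       // in cluster
//
// `tf.NonClusterOp` is collected as a successor and will later be moved after
// the formed `tf_device.cluster`.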
llvm::SmallSetVector<Operation*, 8> CollectClusterSuccessorOps(
Block* block, const OpSetVector& cluster_ops,
const TF::SideEffectAnalysis::Info& side_effect_analysis) {
OpSetVector cluster_predecessor_ops;
OpSetVector cluster_successor_ops;
// Collect non-cluster ops that have a dependency to the cluster. For this
// traverse all ops from last to first cluster op and keep track of in-between
// non-cluster ops that have some outgoing (transitive) dependency to some
// cluster op (`cluster_predecessor_ops`).
auto rfront = Block::reverse_iterator(cluster_ops.front());
auto rback = Block::reverse_iterator(cluster_ops.back());
for (Operation& op : llvm::make_range(rback, rfront)) {
if (cluster_ops.contains(&op)) continue;
bool has_dependency_to_cluster =
hasOpClusterDataDependency(&op, /*incoming=*/false, cluster_ops,
cluster_predecessor_ops) ||
hasOpClusterControlDependency(&op, /*incoming=*/false, cluster_ops,
cluster_predecessor_ops,
side_effect_analysis);
if (has_dependency_to_cluster) cluster_predecessor_ops.insert(&op);
}
// Collect non-cluster ops that have a dependency from the cluster. For this
// traverse all ops from first to last cluster op and keep track of in-between
// non-cluster ops that have some incoming (transitive) dependency from some
// cluster op (`cluster_successor_ops`).
auto front = Block::iterator(cluster_ops.front());
auto back = Block::iterator(cluster_ops.back());
for (Operation& op : llvm::make_range(front, back)) {
if (cluster_ops.contains(&op)) continue;
bool has_dependency_from_cluster =
hasOpClusterDataDependency(&op, /*incoming=*/true, cluster_ops,
cluster_successor_ops) ||
hasOpClusterControlDependency(&op, /*incoming=*/true, cluster_ops,
cluster_successor_ops,
side_effect_analysis);
if (has_dependency_from_cluster) {
if (cluster_predecessor_ops.contains(&op)) {
// The op has a dependency both from and to the cluster, which is invalid.
// Instead of erroring out, we don't add the op to `cluster_successor_ops`,
// which is in line with previous behavior when certain control dependencies
// were not considered.
// TODO(b/216706460) Establish some contract here: Should we expect only
// valid clusters, or should we split clusters accordingly? The latter
// might have runtime impact for existing models.
// We should make this message an error once there is such a contract
// and once existing cases have been fixed.
VLOG(1) << "Invalid TPU cluster structure. Following op is both a "
"predecessor and a successor of a cluster: "
<< mlir::debugString(op);
} else {
cluster_successor_ops.insert(&op);
}
}
}
return cluster_successor_ops;
}
// Collects results, and their associated types, of the cluster that are used
// outside of the cluster. These results and types are used to create the
// cluster's `tf_device.cluster` op and associated terminator. Results that
// have no uses outside of the cluster (i.e. results of ops in the cluster
// that are only consumed by other ops in the cluster) are pruned.
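//
// For example (sketch), if %a is produced by a cluster op and consumed only
// by other cluster ops, it is pruned; if %b is produced by a cluster op and
// has a user outside the cluster, %b becomes a `tf_device.cluster` result.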
llvm::SmallVector<Value, 8> CollectClusterResults(
Block* block, const OpSetVector& cluster_ops) {
llvm::SmallVector<Value, 8> results;
for (Operation* op : cluster_ops) {
for (Value result : op->getResults()) {
for (Operation* user : result.getUsers()) {
// Check if user is not an op in the cluster.
if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
results.push_back(result);
break;
}
}
}
}
return results;
}
// Creates a `tf_device.cluster` to wrap cluster ops.
tf_device::ClusterOp CreateClusterOp(
Block* block, const OpSetVector& cluster_ops, llvm::ArrayRef<Value> results,
llvm::ArrayRef<Operation*> cluster_successor_ops) {
// The `tf_device.cluster` will be placed where the last op of the cluster is.
Operation* last_cluster_op = cluster_ops.back();
OpBuilder builder(last_cluster_op);
llvm::SmallVector<Type, 8> result_types;
for (Value result : results) result_types.push_back(result.getType());
auto cluster = builder.create<tf_device::ClusterOp>(last_cluster_op->getLoc(),
result_types);
Block* body = new Block;
cluster.body().push_back(body);
// Move cluster ops to the cluster body. Also remove the `_replication_info`
// and `device` attributes from ops in the cluster when that information is
// redundant with the `tf_device.cluster`. Do this for all ops, including
// nested ops.
for (Operation* cluster_op : cluster_ops) {
cluster_op->moveBefore(body, body->end());
cluster_op->walk([&](Operation* inner_op) {
inner_op->removeAttr(TF::kReplicationInfoAttr);
inner_op->removeAttr(TF::kCompileDeviceTypeAttr);
if (auto attr = inner_op->getAttrOfType<StringAttr>(kDeviceAttr)) {
// Preserve device attribute if the op is placed on a replicated core
// device. Device attribute is used to infer the appropriate sharding
// within TPUs for this op.
// TODO(b/183598857): Use explicit sharding ops from the front-end.
// For example, dequeue ops generated by
// tensorflow/python/tpu/tpu_feed.py
if (!tensorflow::IsTPUReplicatedCore(attr.getValue())) {
inner_op->removeAttr(kDeviceAttr);
}
}
});
}
// Add terminator.
builder.setInsertionPointToEnd(body);
builder.create<tf_device::ReturnOp>(last_cluster_op->getLoc(), results);
// Replace uses of cluster op results outside of the cluster with the
// associated `tf_device.cluster` results.
for (auto ret_vals : llvm::zip(results, cluster.getResults())) {
Value old_ret = std::get<0>(ret_vals);
Value new_ret = std::get<1>(ret_vals);
for (auto& use : llvm::make_early_inc_range(old_ret.getUses())) {
Operation* user = use.getOwner();
if (!body->findAncestorOpInBlock(*user)) use.set(new_ret);
}
}
// Move ops that depend on something in the cluster behind the cluster.
Operation* op_after_cluster = cluster.getOperation()->getNextNode();
for (Operation* op : cluster_successor_ops) op->moveBefore(op_after_cluster);
return cluster;
}
// Sorts `tf.TPUReplicatedInput` ops by their `index` attribute. Ops with an
// `index` of -1 are always placed after ops with a non-negative `index`, in
// an arbitrary relative order since nothing depends on it. If there are
// multiple `tf.TPUReplicatedInput` ops with the same non-negative index, or
// if any index is less than -1, an error is returned.
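//
// For example, inputs with `index` values [1, -1, 0] are reordered to
// [0, 1, -1]: non-negative indices in ascending order, -1 entries at the end.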
LogicalResult SortTPUReplicatedInputsByIndex(
llvm::ArrayRef<Operation*> inputs,
llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
llvm::SmallDenseSet<int64_t, 8> unique_indices;
for (Operation* input : inputs) {
int64_t index = llvm::cast<TF::TPUReplicatedInputOp>(input).index();
if (index < -1)
return input->emitOpError()
<< "requires index to be at least -1, but got " << index;
if (index == -1) continue;
if (!unique_indices.insert(index).second)
return input->emitOpError()
<< "requires indices to be unique, but found multiple '"
<< input->getName() << "' ops with index " << index;
}
// Sort all TPUReplicatedInput ops by their `index` attribute so that inputs
// with indices are added to the `tf_device.replicate` op deterministically.
// Ops whose `index` attribute is -1 are instead moved to the end.
sorted_inputs->assign(inputs.begin(), inputs.end());
std::stable_sort(
sorted_inputs->begin(), sorted_inputs->end(),
[](Operation* l, Operation* r) {
int64_t l_index = llvm::cast<TF::TPUReplicatedInputOp>(l).index();
int64_t r_index = llvm::cast<TF::TPUReplicatedInputOp>(r).index();
if (l_index == -1 && r_index != -1) return false;
if (r_index == -1 && l_index != -1) return true;
return l_index < r_index;
});
return success();
}
// Creates a `tf_device.replicate` to represent replication for the cluster, if
// necessary.
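//
// For example (a rough sketch with types and bodies abbreviated), with
// `num_replicas` = 2 a cluster fed by
//
//   %in = "tf.TPUReplicatedInput"(%a, %b) : (...) -> tensor<i32>
//
// is wrapped roughly as
//
//   %result:2 = tf_device.replicate([%a, %b] as %arg0: tensor<i32>)
//       {n = 2 : i32} {
//     %out = "tf_device.cluster"() ({ /* body uses %arg0 */ }) : () -> tensor<i32>
//     tf_device.return %out : tensor<i32>
//   }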
LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas,
int num_cores_per_replica) {
// No need to replicate.
if (num_replicas == 1) return success();
if (num_replicas < 1)
return cluster.emitError() << "requires '" << kNumReplicasAttr
<< "' int attribute to be at least 1";
LogicalResult status = success();
// Collect all used TPUReplicatedInput ops and sort by `index`.
OpSetVector unique_replicated_input_ops;
mlir::visitUsedValuesDefinedAbove(
cluster.body(), cluster.body(), [&](mlir::OpOperand* operand) {
Operation* def = operand->get().getDefiningOp();
if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(def))
unique_replicated_input_ops.insert(def);
// When model parallelism is used in conjunction with data parallelism
// for resource inputs, we need to collect the per-replica resource
// inputs from the inputs to `tf.TPUPartitionedInput` ops.
if (auto pi = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(def)) {
if (pi->getNumOperands() != num_cores_per_replica)
status = pi.emitOpError()
<< "requires " << num_cores_per_replica
<< " operands but found " << pi->getNumOperands();
for (auto operand : pi.inputs()) {
if (llvm::isa_and_nonnull<TF::TPUReplicatedInputOp>(
operand.getDefiningOp()))
unique_replicated_input_ops.insert(operand.getDefiningOp());
}
}
});
if (failed(status)) return failure();
llvm::SmallVector<Operation*, 8> replicated_input_ops;
if (failed(SortTPUReplicatedInputsByIndex(
unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
return failure();
// Index attribute values stored on TPUReplicatedInput ops. These will be
// used later by the dynamic padder.
llvm::SmallVector<int64_t, 8> replicated_input_indices;
llvm::SmallVector<int64_t, 8> packed_input_indices;
bool has_replicated_input_index = false;
// Indices of the replicate op's arguments that are mirrored variables.
llvm::SmallVector<int64_t, 8> mirrored_variable_indices;
// Check that the number of operands of each used TPUReplicatedInput op
// matches `num_replicas` (or 1 if the input is packed). Collect all of their
// operands and associated types for creating the replicate op.
llvm::SmallVector<std::pair<ValueRange, Type>, 8> replicated_inputs;
llvm::SmallVector<Value, 8> packed_inputs;
llvm::SmallVector<Operation*, 8> replicated_ops;
llvm::SmallVector<Operation*, 8> packed_ops;
for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) {
auto input = pos_and_input.value();
bool is_packed = llvm::cast<TF::TPUReplicatedInputOp>(input).is_packed();
const int num_operands = input->getNumOperands();
int num_inputs = is_packed ? 1 : num_replicas;
if (num_operands != num_inputs)
return input->emitOpError() << "requires " << num_inputs << " operands";
auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
int64_t tpu_replicated_input_index = tpu_replicated_input.index();
if (is_packed) {
packed_inputs.push_back(input->getOperand(0));
packed_input_indices.push_back(tpu_replicated_input_index);
packed_ops.push_back(input);
} else {
replicated_inputs.push_back(
{input->getOperands(), input->getOperand(0).getType()});
replicated_input_indices.push_back(tpu_replicated_input_index);
replicated_ops.push_back(input);
}
if (tpu_replicated_input_index != -1) has_replicated_input_index = true;
if (tpu_replicated_input.is_mirrored_variable())
mirrored_variable_indices.push_back(pos_and_input.index());
}
replicated_input_indices.append(packed_input_indices.begin(),
packed_input_indices.end());
// Create replicate op.
OpBuilder builder(cluster);
auto replicate_op = builder.create<tf_device::ReplicateOp>(
cluster.getLoc(), num_replicas,
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
replicated_inputs, packed_inputs, cluster.getResultTypes());
if (has_replicated_input_index)
replicate_op->setAttr(kReplicatedInputIndicesAttr,
builder.getI64ArrayAttr(replicated_input_indices));
if (!mirrored_variable_indices.empty())
replicate_op->setAttr(kMirroredVariableIndicesAttr,
builder.getI64ArrayAttr(mirrored_variable_indices));
// Replace replicated cluster results with replicate op results.
for (auto result_and_idx : llvm::enumerate(cluster.getResults())) {
Value result = result_and_idx.value();
int idx = result_and_idx.index();
auto replicate_outputs = llvm::make_range(
std::next(replicate_op.result_begin(), idx * num_replicas),
std::next(replicate_op.result_begin(), (idx + 1) * num_replicas));
for (auto& use : llvm::make_early_inc_range(result.getUses())) {
Operation* def = use.getOwner();
if (!llvm::isa<TF::TPUReplicatedOutputOp>(def)) {
// If user is not a `tf.TPUReplicatedOutput`, simply forward the first
// replica output. Certain Graphs under V1 create `tf.Identity` users of
// replicated ops to pin the TPU computation for execution.
use.set(*replicate_outputs.begin());
continue;
}
const int def_num_results = def->getNumResults();
if (def_num_results != num_replicas)
return def->emitOpError() << "requires " << num_replicas << " results";
def->replaceAllUsesWith(replicate_outputs);
}
}
// Collect all `tf.TPUPartitionedInput` ops to be moved inside the
// `tf_device.replicate` later.
llvm::SmallSet<Operation*, 4> partitioned_inputs;
// Update replicated inputs with replicate op block arguments.
auto ordered_tpu_replicate_inputs =
llvm::concat<Operation*>(replicated_ops, packed_ops);
for (auto input_and_block_arg :
llvm::zip(ordered_tpu_replicate_inputs,
replicate_op.GetBody().getArguments())) {
Operation* input = std::get<0>(input_and_block_arg);
Value block_arg = std::get<1>(input_and_block_arg);
mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg,
cluster.body());
// Update replicated input use in tf.TPUPartitionedInput op.
for (auto& use : input->getUses()) {
auto pi = llvm::dyn_cast<TF::TPUPartitionedInputOp>(use.getOwner());
if (pi) {
pi.setOperand(use.getOperandNumber(), block_arg);
partitioned_inputs.insert(pi.getOperation());
}
}
}
// Create terminator for replicate op and move `tf_device.cluster` and
// `tf.TPUPartitionedInput`(s) into replicate body.
builder.setInsertionPointToEnd(&replicate_op.GetBody());
auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(),
cluster.getResults());
for (auto pi : partitioned_inputs) pi->moveBefore(return_op);
cluster.getOperation()->moveBefore(return_op);
return success();
}
// Forms clusters with ops of the same `_replication_info` attribute under a
// block.
//
// For a given block, clusters are formed via grouping ops by
// `_replication_info` attributes. For every cluster formed:
// 1. Find associated TPUReplicateMetadata attributes with the same
// `_replication_info` attribute.
// 2. Find users not in cluster that are interleaved between cluster ops.
// 3. Find external uses of cluster ops.
// 4. Create `tf_device.cluster` with results consisting of the external uses
// of cluster ops determined at 3.
// 5. Move cluster ops to `tf_device.cluster` body.
// 6. Replace external uses of cluster ops with `tf_device.cluster` results.
// 7. Move users from 2 to after the `tf_device.cluster`.
// 8. Wrap cluster (`tf_device.cluster`) in a `tf_device.replicate` if
// attribute `num_replicas` is greater than 1.
// 9. Copy over TPUReplicateMetadata attributes to `tf_device.cluster`.
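//
// For example (a rough sketch; op names and attribute values are
// hypothetical), the block
//
//   "tf.TPUReplicateMetadata"() {_replication_info = "c",
//                                num_replicas = 1 : i64} : () -> ()
//   %a = "tf.OpA"() {_replication_info = "c"} : () -> tensor<i32>
//   %b = "tf.OpB"(%a) {_replication_info = "c"} : (tensor<i32>) -> tensor<i32>
//
// is rewritten to
//
//   %cluster = "tf_device.cluster"() ({
//     %a = "tf.OpA"() : () -> tensor<i32>
//     %b = "tf.OpB"(%a) : (tensor<i32>) -> tensor<i32>
//     tf_device.return %b : tensor<i32>
//   }) {_replication_info = "c"} : () -> tensor<i32>
//
// with the remaining metadata attributes (minus `num_replicas`) copied onto
// the `tf_device.cluster`.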
LogicalResult FormClustersInBlock(
Block* block, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
MetadataMap metadata_map;
LogicalResult result = CollectMetadata(block, &metadata_map);
if (failed(result)) return result;
// If there is no TPUReplicateMetadata op in this block, process blocks in
// regions attached to the ops in the block.
if (metadata_map.empty()) {
for (Operation& op : *block) {
for (Region& region : op.getRegions()) {
if (!llvm::hasSingleElement(region))
return op.emitOpError("Expected single block region");
if (failed(FormClustersInBlock(&region.front(), side_effect_analysis)))
return failure();
}
}
return success();
}
ClusterMap clusters;
result = CollectAndGroupClusterOps(block, &clusters);
if (failed(result)) return result;
for (const auto& cluster_metadata_and_ops : clusters) {
const auto& cluster_ops = cluster_metadata_and_ops.getSecond();
auto cluster_metadata =
metadata_map.find(cluster_metadata_and_ops.getFirst());
// No TPUReplicateMetadata for a `_replication_info` attribute.
if (cluster_metadata == metadata_map.end()) {
cluster_ops.front()->emitWarning()
<< "TPUReplicateMetadata for associated '" << TF::kReplicationInfoAttr
<< "' attribute '" << cluster_metadata_and_ops.getFirst()
<< "' is missing";
continue;
}
OpSetVector cluster_successor_ops =
CollectClusterSuccessorOps(block, cluster_ops, side_effect_analysis);
llvm::SmallVector<Value, 8> results =
CollectClusterResults(block, cluster_ops);
tf_device::ClusterOp cluster = CreateClusterOp(
block, cluster_ops, results, cluster_successor_ops.getArrayRef());
auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr);
if (!num_replicas || !num_replicas.isa<mlir::IntegerAttr>())
return cluster.emitError()
<< "requires '" << kNumReplicasAttr << "' int attribute";
int num_cores_per_replica = 1;
auto num_cores_per_replica_attr =
cluster_metadata->getSecond()
.get(kNumCoresPerReplicaAttr)
.dyn_cast_or_null<mlir::IntegerAttr>();
if (num_cores_per_replica_attr)
num_cores_per_replica = num_cores_per_replica_attr.getInt();
if (failed(ReplicateCluster(cluster,
num_replicas.cast<mlir::IntegerAttr>().getInt(),
num_cores_per_replica)))
return failure();
// Copy TPUReplicateMetadata attributes to `tf_device.cluster`.
cluster->setAttrs(
cluster_metadata->second.getDictionary(cluster.getContext()));
// Exclude `num_replicas`, as the cluster has already been replicated if
// necessary.
cluster->removeAttr(kNumReplicasAttr);
}
return success();
}
LogicalResult FormClustersInFunction(
func::FuncOp func,
const TF::SideEffectAnalysis::Info& side_effect_analysis) {
if (!llvm::hasSingleElement(func))
return func.emitOpError("Expecting a single block function");
if (failed(FormClustersInBlock(&func.front(), side_effect_analysis)))
return failure();
// Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
auto remove_result = func.walk([&](Operation* op) {
if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
return WalkResult::advance();
// Forward the operand to the result. When the `num_replicas` attribute is 1,
// no `tf_device.replicate` is created and the single replicated
// operand/result is left untouched.
if (op->getNumOperands() == 1 && op->getNumResults() == 1)
op->getResult(0).replaceAllUsesWith(op->getOperand(0));
// Leftover TPUReplicatedInput/TPUReplicatedOutput ops that do not correspond
// to the `num_replicas` == 1 case must have no remaining uses.
if (!op->use_empty()) {
op->emitOpError() << "is expected to have no uses, but it is operand#"
<< op->use_begin()->getOperandNumber() << " of "
<< *op->use_begin()->getOwner();
return WalkResult::interrupt();
}
op->erase();
return WalkResult::advance();
});
return failure(remove_result.wasInterrupted());
}
void TPUClusterFormationPass::runOnOperation() {
auto& side_effect_analysis = getAnalysis<TF::SideEffectAnalysis>();
for (auto func : getOperation().getOps<func::FuncOp>())
if (!func.isExternal() &&
failed(FormClustersInFunction(
func, side_effect_analysis.GetAnalysisForFunc(func))))
return signalPassFailure();
}
} // anonymous namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateTPUClusterFormationPass() {
return std::make_unique<TPUClusterFormationPass>();
}
} // namespace TFTPU
} // namespace mlir