blob: ce18e844e66050cdbf38b6a1571fb931610cdca1 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
#include <string>
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/tpu/compilation_result.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_unloader.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_op_util.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
namespace tensorflow {
namespace tpu {
namespace {
// Op type names of the function-argument and return-value nodes that
// AssignDevicesToArgsAndRetvals annotates with device/sharding information.
// (Already internal linkage inside the anonymous namespace; no `static`.)
constexpr char kArgOp[] = "_Arg";
constexpr char kRetvalOp[] = "_Retval";
// Returns the device name of replicated TPU core `core`, formed as
// "/device:" + DEVICE_TPU_REPLICATED_CORE + ":" + core.
std::string CoreDevice(int core) {
  std::string device_name =
      strings::StrCat("/device:", DEVICE_TPU_REPLICATED_CORE, ":", core);
  return device_name;
}
void ConvertGraphShapeInfoToShapeMap(
const Graph& graph, const GraphShapeInfo& graph_shape_info,
std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map) {
// Builds a map from node name to Node* for `graph`.
std::unordered_map<string, Node*> index;
for (Node* node : graph.nodes()) {
index[node->name()] = node;
}
// Discards the resource handle shape info while converting to the correct map
// form.
for (const auto& node_shape_info : graph_shape_info) {
const string& node_name = node_shape_info.first;
const std::vector<InferredShape>& output_shapes = node_shape_info.second;
// Gets the vector of partial shapes, first converting node name to Node*
// using index. graph is the subgraph of the original graph assigned to a
// particular core, and we only add entries to shape_map for nodes in
// graph_shape_info that are in the subgraph.
const auto& node_iter = index.find(node_name);
if (node_iter != index.end()) {
auto& partial_shapes = (*shape_map)[node_name];
for (const auto& inferred_shape : output_shapes) {
partial_shapes.push_back(inferred_shape.shape);
}
}
}
}
// Records the sharding of argument `arg_index` in `(*arg_core_mapping)` and
// appends the argument's XLA shape to the shape list of every core that
// receives (part of) it:
//   - MAXIMAL: the full shape goes to core `tile_assignment_devices(0)`.
//   - OTHER: each core listed in `tile_assignment_devices` gets the shape
//     returned by GetPerDeviceShape for that core.
//   - REPLICATED: every core gets the full shape.
// For each core that receives a shape, the position of that shape within
// `(*per_core_arg_shapes)[core]` is appended to
// `(*arg_core_mapping)[arg_index].indices`.
// When the argument has an unrestricted layout, the layout is cleared from
// `*xla_arg_shape` and from every per-core shape recorded here.
//
// Returns an error (rather than indexing out of bounds) if `arg_index` is not
// a valid index into `*arg_core_mapping`, if a MAXIMAL sharding names no
// device, if the sharding refers to a core outside
// [0, per_core_arg_shapes->size()), or if the sharding type is unsupported.
Status SetPerCoreArgShapes(
    const tpu::TPUCompileMetadataProto::Arg& proto_arg, const int arg_index,
    xla::Shape* xla_arg_shape,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  TF_RET_CHECK(arg_index >= 0 &&
               static_cast<size_t>(arg_index) < arg_core_mapping->size())
      << "Argument index " << arg_index << " is out of range for "
      << arg_core_mapping->size() << " argument core mappings.";
  const int64 num_cores = per_core_arg_shapes->size();
  const xla::OpSharding& sharding = proto_arg.sharding();
  tpu::ShardingAndIndex& core_mapping = (*arg_core_mapping)[arg_index];

  if (proto_arg.unrestricted_layout()) {
    xla_arg_shape->clear_layout();
  }
  core_mapping.sharding = sharding;

  if (sharding.type() == xla::OpSharding::MAXIMAL) {
    TF_RET_CHECK(sharding.tile_assignment_devices_size() > 0)
        << "MAXIMAL argument sharding without a device: "
        << proto_arg.DebugString();
    // Kept as int64 so an out-of-range device id cannot be truncated into
    // the valid range before the bounds check.
    const int64 core = sharding.tile_assignment_devices(0);
    TF_RET_CHECK(0 <= core && core < num_cores)
        << "Argument " << arg_index << " is assigned to core " << core
        << " but there are only " << num_cores << " cores.";
    std::vector<xla::Shape>& core_shapes = (*per_core_arg_shapes)[core];
    core_mapping.indices.push_back(core_shapes.size());
    core_shapes.push_back(*xla_arg_shape);
  } else if (sharding.type() == xla::OpSharding::OTHER) {
    TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding,
                        xla::HloSharding::FromProto(sharding));
    for (int64 core : sharding.tile_assignment_devices()) {
      // Tiled shardings list arbitrary device ids; validate each one before
      // using it as an index.
      TF_RET_CHECK(0 <= core && core < num_cores)
          << "Argument " << arg_index << " has a tile on core " << core
          << " but there are only " << num_cores << " cores.";
      std::vector<xla::Shape>& core_shapes = (*per_core_arg_shapes)[core];
      core_mapping.indices.push_back(core_shapes.size());
      xla::Shape per_core_shape =
          GetPerDeviceShape(*xla_arg_shape, hlo_sharding, core);
      if (proto_arg.unrestricted_layout()) {
        per_core_shape.clear_layout();
      }
      core_shapes.push_back(std::move(per_core_shape));
    }
  } else {
    TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED)
        << "Unsupported argument sharding: "
        << " proto_arg=" << proto_arg.DebugString();
    for (std::vector<xla::Shape>& core_shapes : *per_core_arg_shapes) {
      core_mapping.indices.push_back(core_shapes.size());
      core_shapes.push_back(*xla_arg_shape);
    }
  }
  return Status::OK();
}
} // namespace
// The single registered factory; stays null until Register() is called.
CompileOpImplFactory* CompileOpImplFactory::factory_ = nullptr;

/* static */
CompileOpImplFactory* CompileOpImplFactory::Get() {
  return factory_;
}

/* static */
void CompileOpImplFactory::Register(CompileOpImplFactory* factory) {
  // Registering a second factory is a programming error and crashes.
  CHECK_EQ(factory_, nullptr)
      << "CompileOpImplFactory can only be registered "
         "once and there can only be one factory active and used.";
  factory_ = factory;
}
// Fills `(*retval_core_mapping)[i]` for every return value `i` in `metadata_`:
// the sharding is copied and, for each core that produces (part of) the
// return value, that core's next output slot number is appended to `indices`.
//   - MAXIMAL: only core `tile_assignment_devices(0)`.
//   - OTHER: every core listed in `tile_assignment_devices`.
//   - REPLICATED: every core in [0, num_cores_per_replica).
// Slot numbers are counted independently per core, in return-value order.
//
// `retval_core_mapping` must already hold at least `metadata_.retvals_size()`
// entries. Returns an error for unsupported sharding types, a MAXIMAL sharding
// with no device, or a core id outside [0, num_cores_per_replica).
Status TpuCompileOpKernelCommon::AssignReturnValueToCore(
    std::vector<tpu::ShardingAndIndex>* retval_core_mapping) {
  TF_RET_CHECK(retval_core_mapping->size() >=
               static_cast<size_t>(metadata_.retvals_size()))
      << "retval_core_mapping has " << retval_core_mapping->size()
      << " entries but metadata has " << metadata_.retvals_size()
      << " return values.";
  std::vector<int> per_core_retval_counts(metadata_.num_cores_per_replica(), 0);
  const int64 num_cores = per_core_retval_counts.size();
  for (int i = 0; i < metadata_.retvals_size(); ++i) {
    const xla::OpSharding& sharding = metadata_.retvals(i).sharding();
    tpu::ShardingAndIndex& mapping = (*retval_core_mapping)[i];
    mapping.sharding = sharding;
    if (sharding.type() == xla::OpSharding::MAXIMAL) {
      TF_RET_CHECK(sharding.tile_assignment_devices_size() > 0)
          << "MAXIMAL return value sharding without a device: "
          << sharding.DebugString();
      const int64 core = sharding.tile_assignment_devices(0);
      TF_RET_CHECK(0 <= core && core < num_cores)
          << "Return value " << i << " is assigned to core " << core
          << " but there are only " << num_cores << " cores.";
      mapping.indices.push_back(per_core_retval_counts[core]++);
    } else if (sharding.type() == xla::OpSharding::OTHER) {
      for (int64 core : sharding.tile_assignment_devices()) {
        // Validate before indexing: a bad device id would otherwise write
        // outside `per_core_retval_counts`.
        TF_RET_CHECK(0 <= core && core < num_cores)
            << "Return value " << i << " has a tile on core " << core
            << " but there are only " << num_cores << " cores.";
        mapping.indices.push_back(per_core_retval_counts[core]++);
      }
    } else {
      TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED)
          << "Unsupported return value sharding: " << sharding.DebugString();
      for (int64 core = 0; core < num_cores; ++core) {
        mapping.indices.push_back(per_core_retval_counts[core]++);
      }
    }
  }
  return Status::OK();
}
// Appends one XlaCompiler::Argument to `*args` for every argument in
// `metadata_`, using `arg_shapes[i]` as the shape of argument `i`:
//   - PARAMETER -> kParameter.
//   - VARIABLE -> an initialized kResource variable (honoring `fast_mem`).
//   - GUARANTEED_CONSTANT -> kConstant whose value is the next tensor from
//     `guaranteed_constants` (either TensorProtos or an OpInputList).
// For every non-constant argument the XLA shape is computed with `compiler`
// and distributed across cores via SetPerCoreArgShapes, which fills
// `arg_core_mapping` and `per_core_arg_shapes`.
//
// Returns an error if an argument kind is invalid, if there are fewer
// `arg_shapes` than arguments, if a constant TensorProto cannot be parsed, or
// if the number of GUARANTEED_CONSTANT arguments differs from the number of
// supplied constant tensors.
Status TpuCompileOpKernelCommon::BuildComputationArgumentDescriptions(
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const XlaCompiler& compiler,
    std::vector<XlaCompiler::Argument>* args,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  TF_RET_CHECK(arg_shapes.size() >= static_cast<size_t>(metadata_.args_size()))
      << "Got " << arg_shapes.size() << " argument shapes for "
      << metadata_.args_size() << " arguments.";

  // Number of constant tensors supplied by the caller. Computed up front (not
  // only when a GUARANTEED_CONSTANT argument is seen) so that the final
  // consumption check also catches constants supplied when the metadata
  // declares none.
  const size_t guaranteed_constants_size =
      guaranteed_constants.index() == 0
          ? absl::get<0>(guaranteed_constants).size()
          : (absl::get<1>(guaranteed_constants) == nullptr
                 ? size_t{0}
                 : static_cast<size_t>(
                       absl::get<1>(guaranteed_constants)->size()));

  // Builds a description of the computation's arguments.
  int constant_count = 0;
  for (int i = 0; i < metadata_.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata_.args(i);
    args->push_back(XlaCompiler::Argument());
    XlaCompiler::Argument& arg = args->back();
    arg.type = proto_arg.dtype();
    arg.shape = arg_shapes[i];
    arg.node_name = proto_arg.name();
    switch (proto_arg.kind()) {
      case tpu::TPUCompileMetadataProto::Arg::PARAMETER:
        arg.kind = XlaCompiler::Argument::kParameter;
        break;
      case tpu::TPUCompileMetadataProto::Arg::VARIABLE:
        arg.kind = XlaCompiler::Argument::kResource;
        arg.resource_kind = XlaResource::kVariable;
        arg.initialized = true;
        arg.fast_mem = proto_arg.fast_mem();
        break;
      case tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT: {
        arg.kind = XlaCompiler::Argument::kConstant;
        TF_RET_CHECK(static_cast<size_t>(constant_count) <
                     guaranteed_constants_size)
            << "More constant args in TPUCompileMetadataProto than constant "
               "tensors.";
        if (guaranteed_constants.index() == 0) {
          // `guaranteed_constants` is of type `absl::Span<const TensorProto*
          // const>`. A malformed proto is reported as an error instead of
          // aborting the process.
          Tensor tensor;
          if (!tensor.FromProto(
                  *absl::get<0>(guaranteed_constants)[constant_count])) {
            return errors::InvalidArgument(
                "Failed to deserialize invalid `TensorProto` into `Tensor`.");
          }
          arg.constant_value = std::move(tensor);
        } else {
          // `guaranteed_constants` is of type `const OpInputList* const`.
          arg.constant_value =
              (*absl::get<1>(guaranteed_constants))[constant_count];
        }
        ++constant_count;
        break;
      }
      case tpu::TPUCompileMetadataProto::Arg::INVALID:
      default:
        break;
    }
    arg.is_same_data_across_replicas = proto_arg.is_same_data_across_replicas();
    if (arg.kind == XlaCompiler::Argument::kInvalid) {
      return errors::InvalidArgument("Invalid argument kind");
    }
    if (arg.kind == XlaCompiler::Argument::kConstant) {
      // Constants are folded into the computation; they get no core mapping.
      continue;
    }
    // Assign each argument a sharding.
    xla::Shape xla_arg_shape;
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_RETURN_IF_ERROR(compiler.XLAShapeForArgument(
        arg, /*is_entry_computation=*/true, arg_sharding, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  TF_RET_CHECK(static_cast<size_t>(constant_count) == guaranteed_constants_size)
      << "Not all of the constant tensors were consumed.";
  return Status::OK();
}
// For every argument in `metadata_` (constants included), computes its XLA
// shape from `arg_shapes[i]` with `shape_representation_fn` (fast memory off),
// rewrites the layout for the argument's sharding, and distributes the result
// across cores via SetPerCoreArgShapes, filling `arg_core_mapping` and
// `per_core_arg_shapes`.
//
// Returns an error if fewer `arg_shapes` than arguments are supplied, or if
// any sharding/shape computation fails.
Status TpuCompileOpKernelCommon::GetShardingInfo(
    absl::Span<const TensorShape> arg_shapes,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes) {
  const int num_inputs = metadata_.args_size();
  // absl::Span::operator[] is not bounds-checked; reject short inputs here.
  TF_RET_CHECK(arg_shapes.size() >= static_cast<size_t>(num_inputs))
      << "Got " << arg_shapes.size() << " argument shapes for " << num_inputs
      << " arguments.";
  for (int i = 0; i < num_inputs; ++i) {
    const auto& proto_arg = metadata_.args(i);
    TF_ASSIGN_OR_RETURN(auto arg_sharding,
                        xla::HloSharding::FromProto(proto_arg.sharding()));
    TF_ASSIGN_OR_RETURN(
        auto xla_arg_shape,
        shape_representation_fn(arg_shapes[i], proto_arg.dtype(),
                                /*use_fast_memory=*/false));
    TF_RETURN_IF_ERROR(
        RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
                                      shape_representation_fn, &xla_arg_shape));
    TF_RETURN_IF_ERROR(SetPerCoreArgShapes(
        proto_arg, i, &xla_arg_shape, arg_core_mapping, per_core_arg_shapes));
  }
  return Status::OK();
}
// Compiles the TensorFlow function `function` (looked up in `flib_def`) to HLO
// using an XlaCompiler that targets DEVICE_TPU_XLA_JIT:
//   1. Builds the XLA argument descriptions and per-core argument shapes
//      (BuildComputationArgumentDescriptions), and assigns return values to
//      cores (AssignReturnValueToCore).
//   2. Instantiates `function` and copies its body into a fresh Graph.
//   3. Applies `metadata_.padding_maps()`: dimension `shape_index` of
//      parameter `arg_index` is marked dynamic (-1 in the partial shape), with
//      its size supplied by parameter `padding_arg_index`, which is flagged as
//      a pad argument. Both indices count kParameter arguments only.
//   4. Annotates _Arg/_Retval nodes with devices/shardings, optimizes the
//      graph, and compiles it into `*compilation_result`.
//
// Returns an error if any step fails, including padding maps whose indices
// are out of range.
Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
    const FunctionLibraryDefinition& flib_def, int graph_def_version,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    const std::vector<TensorShape>& arg_shapes,
    const GuaranteedConsts& guaranteed_constants, const NameAttrList& function,
    std::function<Status(ResourceMgr*)> populate_resource_manager_fn,
    xla::CompileOnlyClient* client,
    std::vector<tpu::ShardingAndIndex>* arg_core_mapping,
    std::vector<std::vector<xla::Shape>>* per_core_arg_shapes,
    XlaCompiler::CompilationResult* compilation_result) {
  XlaCompiler::Options compiler_options;
  compiler_options.device_type = DeviceType(DEVICE_TPU_XLA_JIT);
  compiler_options.client = client;
  compiler_options.flib_def = &flib_def;
  compiler_options.allow_cpu_custom_calls = false;
  compiler_options.populate_resource_manager = &populate_resource_manager_fn;
  compiler_options.graph_def_version = graph_def_version;
  compiler_options.shape_representation_fn = shape_representation_fn;
  auto compiler = absl::make_unique<XlaCompiler>(compiler_options);

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(BuildComputationArgumentDescriptions(
      arg_shapes, guaranteed_constants, *compiler, &args, arg_core_mapping,
      per_core_arg_shapes));

  // Assign each return value to a core.
  std::vector<tpu::ShardingAndIndex> retval_core_mapping(
      metadata_.retvals_size());
  TF_RETURN_IF_ERROR(
      TpuCompileOpKernelCommon::AssignReturnValueToCore(&retval_core_mapping));

  LOG(INFO) << "Instantiating function:" << function.name();
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(compiler->flib_runtime()->Instantiate(
      function.name(), AttrSlice(&function.attr()), &handle));
  const FunctionBody* fbody = compiler->flib_runtime()->GetFunctionBody(handle);
  TF_RET_CHECK(fbody != nullptr)
      << "No function body for instantiated function " << function.name();
  const string function_id =
      Canonicalize(function.name(), AttrSlice(&function.attr()));

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  CopyGraph(*fbody->graph, graph.get());

  VLOG(2) << "metadata: " << metadata_.DebugString();
  // parameter_arg_mapping[k] is the index in `args` of the k-th kParameter
  // argument; padding maps refer to parameters by this k.
  std::vector<int> parameter_arg_mapping;
  for (int i = 0; i < static_cast<int>(args.size()); ++i) {
    if (args[i].kind == XlaCompiler::Argument::kParameter) {
      parameter_arg_mapping.push_back(i);
    }
  }
  TF_RET_CHECK(fbody->arg_nodes.size() == args.size());
  for (size_t i = 0; i < fbody->arg_nodes.size(); i++) {
    args[i].node_name = fbody->arg_nodes[i]->name();
  }

  std::vector<gtl::InlinedVector<int64, 4>> arg_shape_dims;
  arg_shape_dims.reserve(arg_shapes.size());
  std::vector<PartialTensorShape> partial_arg_shapes(arg_shapes.size());
  for (const TensorShape& shape : arg_shapes) {
    arg_shape_dims.push_back(shape.dim_sizes());
  }

  const int64 num_parameters = parameter_arg_mapping.size();
  for (const auto& padding_mapping : metadata_.padding_maps()) {
    const int64 padding_param_index = padding_mapping.padding_arg_index();
    if (padding_param_index < 0 || padding_param_index >= num_parameters) {
      return errors::Internal(absl::StrCat(
          "TPUCompileMetadataProto `padding_maps` has `padding_arg_index` ",
          padding_mapping.padding_arg_index(),
          " which exceeds`parameter_arg_mapping` array bounds ",
          parameter_arg_mapping.size(),
          ". this usually indicates there are dynamic shape inputs fed into "
          "TPUs from outside compilation head extraction, which is not "
          "supported"));
    }
    const int64 arg_param_index = padding_mapping.arg_index();
    if (arg_param_index < 0 || arg_param_index >= num_parameters) {
      return errors::Internal(
          "TPUCompileMetadataProto `padding_maps` has `arg_index` ",
          arg_param_index, " which exceeds `parameter_arg_mapping` array bounds ",
          parameter_arg_mapping.size());
    }
    const int padding_arg_index = parameter_arg_mapping[padding_param_index];
    const int arg_index = parameter_arg_mapping[arg_param_index];
    TF_RET_CHECK(static_cast<size_t>(arg_index) < arg_shape_dims.size());
    gtl::InlinedVector<int64, 4>& dims = arg_shape_dims[arg_index];
    // Unchecked, an out-of-range shape_index would write past `dims`.
    const int shape_index = padding_mapping.shape_index();
    if (shape_index < 0 || shape_index >= static_cast<int>(dims.size())) {
      return errors::Internal(
          "TPUCompileMetadataProto `padding_maps` has `shape_index` ",
          shape_index, " which is out of range for argument ", arg_index,
          " of rank ", dims.size());
    }
    args[arg_index].dynamic_dim_to_arg_num_map[shape_index] = padding_arg_index;
    dims[shape_index] = -1;
    args[padding_arg_index].is_pad_arg = true;
  }
  for (size_t i = 0; i < arg_shape_dims.size(); ++i) {
    auto& dims = arg_shape_dims[i];
    TF_RETURN_IF_ERROR(PartialTensorShape::MakePartialShape(
        dims.data(), dims.size(), &partial_arg_shapes[i]));
  }

  // Adds device assignments to _Arg and _Retval nodes.
  TF_RETURN_IF_ERROR(AssignDevicesToArgsAndRetvals(
      absl::MakeSpan(*arg_core_mapping), absl::MakeSpan(retval_core_mapping),
      graph.get()));

  VLOG(1) << "Optimizing TensorFlow graph";
  FunctionLibraryDefinition flib_definition(flib_def);
  TF_RETURN_IF_ERROR(OptimizeGraph(metadata_, partial_arg_shapes, &graph,
                                   compiler->flib_runtime(), &flib_definition));

  VLOG(1) << "Compiling TensorFlow graph to HLO";
  XlaCompiler::CompileOptions compile_options;
  compile_options.return_updated_values_for_all_resources = false;
  compile_options.use_tuple_arg = true;
  compile_options.is_entry_computation = true;
  compile_options.alias_resource_update = true;
  return compiler->CompileGraph(compile_options, function_id, std::move(graph),
                                args, compilation_result);
}
// Invoked after the op has been cancelled. Waits 300 seconds and then, unless
// `*done` has been set by the finishing TpuCompileOp in the meantime,
// terminates the process with exit code 42.
/* static */ void TpuCompileOpKernelCommon::ExitCountdown(
    Env* env, std::shared_ptr<std::atomic<bool>> done) {
  constexpr int kSleepSeconds = 300;
  constexpr int kMicrosPerSecond = 1000000;
  LOG(INFO) << "TpuCompileOp was cancelled. Sleeping for " << kSleepSeconds
            << " seconds to give time for TPUCompileOp to finished.";
  env->SleepForMicroseconds(kSleepSeconds * kMicrosPerSecond);

  const bool compile_op_finished = done->load();
  if (!compile_op_finished) {
    // The op is still running after the grace period.
    LOG(ERROR) << "Aborting process due to cancelled TpuCompileOp. This "
               << "termination is to ensure a consistent state.";
    std::exit(42);
  }
  // Otherwise the op finished in time; return quietly.
}
// Reads the "dynamic_shapes" input list of `ctx` and converts each shape
// tensor into the corresponding TensorShape in `*shapes` (resized to match).
// Returns the first error from reading the input list or converting a shape.
/* static */ Status TpuCompileOpKernelCommon::GetDynamicShapes(
    OpKernelContext* ctx, std::vector<TensorShape>* shapes) {
  OpInputList shape_tensors;
  TF_RETURN_IF_ERROR(ctx->input_list("dynamic_shapes", &shape_tensors));
  const int num_shapes = shape_tensors.size();
  shapes->resize(num_shapes);
  for (int idx = 0; idx < num_shapes; ++idx) {
    TensorShape* out_shape = &(*shapes)[idx];
    TF_RETURN_IF_ERROR(
        tpu::ShapeTensorToTensorShape(shape_tensors[idx], out_shape));
  }
  return Status::OK();
}
// Function arguments and return values lose their device assignments, so they
// are recreated here from the shardings in `arg_core_mapping` and
// `retval_core_mapping`.
//
// For every _Arg / _Retval op node in `graph`, its "index" attribute selects
// the mapping entry:
//   - MAXIMAL sharding: the node's assigned and requested device are set to
//     the replicated-core device of `tile_assignment_devices(0)`.
//   - REPLICATED / OTHER sharding: the device is left unchanged.
// Every such node also gets an "_XlaSharding" attribute holding the
// serialized OpSharding.
//
// Returns an error if an index is out of range, a MAXIMAL sharding names no
// device, or the sharding type is unsupported.
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(
    absl::Span<const tpu::ShardingAndIndex> arg_core_mapping,
    absl::Span<const tpu::ShardingAndIndex> retval_core_mapping, Graph* graph) {
  auto assign = [&](Node* node, const xla::OpSharding& sharding) -> Status {
    if (sharding.type() == xla::OpSharding::MAXIMAL) {
      TF_RET_CHECK(sharding.tile_assignment_devices_size() > 0)
          << "MAXIMAL sharding without a device on parameter/retval: "
          << sharding.DebugString();
      const string device = CoreDevice(sharding.tile_assignment_devices(0));
      node->set_assigned_device_name(device);
      node->set_requested_device(device);
    } else {
      TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED ||
                   sharding.type() == xla::OpSharding::OTHER)
          << "Unsupported sharding on parameter/retval: "
          << sharding.DebugString();
    }
    node->AddAttr("_XlaSharding", sharding.SerializeAsString());
    return Status::OK();
  };
  for (Node* node : graph->op_nodes()) {
    const bool is_arg = node->type_string() == kArgOp;
    if (!is_arg && node->type_string() != kRetvalOp) {
      continue;
    }
    const absl::Span<const tpu::ShardingAndIndex> mapping =
        is_arg ? arg_core_mapping : retval_core_mapping;
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
    TF_RET_CHECK(index >= 0 && static_cast<size_t>(index) < mapping.size())
        << (is_arg ? "Argument" : "Return value") << " index " << index
        << " is out of range for " << mapping.size() << " entries.";
    TF_RETURN_IF_ERROR(assign(node, mapping[index].sharding));
  }
  return Status::OK();
}
// Performs shape inference on `graph`, writing the results to `*shape_info`.
// Argument shapes, keyed by argument index, come from `arg_shapes` (one entry
// per `metadata.args()`, in order):
//   - VARIABLE arguments: the handle itself is a scalar; `arg_shapes[i]` and
//     the argument dtype describe the variable's value (handle shape/type).
//   - All other arguments use `arg_shapes[i]` directly.
// `flr` may be null, in which case no function library is passed to
// InferShapes.
//
// Returns an error if `arg_shapes` and `metadata.args()` differ in length.
/* static */ Status TpuCompileOpKernelCommon::RunShapeInferenceOnComputation(
    const tpu::TPUCompileMetadataProto& metadata,
    const std::vector<PartialTensorShape>& arg_shapes, Graph* graph,
    FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info) {
  const int num_args = arg_shapes.size();
  // Reported as an error rather than CHECK-failing, so malformed metadata
  // fails this compilation instead of aborting the process.
  TF_RET_CHECK(num_args == metadata.args_size())
      << "Got " << num_args << " argument shapes for " << metadata.args_size()
      << " arguments.";
  std::map<int, InferredShape> arg_shapes_for_inference;
  for (int i = 0; i < num_args; ++i) {
    const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
    InferredShape& inferred = arg_shapes_for_inference[i];
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      // For resource variables, arg_shapes[] contains the shape of the
      // variable's value; the variable handle itself is always a scalar.
      inferred.handle_type = arg.dtype();
      inferred.handle_shape = arg_shapes[i];
      inferred.shape = TensorShape();
      continue;
    }
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
      VLOG(1) << "PromisedConstant shape: " << arg_shapes[i].DebugString();
    }
    inferred.shape = arg_shapes[i];
  }
  return InferShapes(
      graph, arg_shapes_for_inference,
      flr != nullptr ? flr->GetFunctionLibraryDefinition() : nullptr,
      shape_info);
}
// Optimizes `*graph` in place before XLA compilation:
//   1. Runs the graph optimizer with function inlining, plus constant folding
//      unless the `tf_xla_disable_constant_folding` flag is set. Inlining
//      first lets shape inference see inside function bodies.
//   2. Runs shape inference seeded with `arg_shapes`.
//   3. Re-runs the optimizer with the inferred shapes so that Shape ops can
//      be constant-propagated.
//   4. Calls RewriteTensorListWithConstElement on the result, using `fld`.
// Common-subexpression elimination is disabled because XLA runs its own CSE
// passes later in compilation.
//
// `flr` is used for its env and device; an error is returned if it is null.
Status TpuCompileOpKernelCommon::OptimizeGraph(
    const tpu::TPUCompileMetadataProto& metadata,
    const std::vector<PartialTensorShape>& arg_shapes,
    std::unique_ptr<Graph>* graph, FunctionLibraryRuntime* flr,
    FunctionLibraryDefinition* fld) {
  // RunShapeInferenceOnComputation tolerates a null `flr`, but the optimizer
  // calls below dereference it; fail cleanly instead of crashing.
  TF_RET_CHECK(flr != nullptr)
      << "OptimizeGraph requires a FunctionLibraryRuntime.";

  auto flags = GetBuildXlaOpsPassFlags();
  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_common_subexpression_elimination(false);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
  GraphOptimizer optimizer(opts);

  // Performs a first function inlining pass before shape inference, since
  // otherwise shape inference can't see inside functions and a comprehensive
  // shape_map, including function ops, is needed to constant-propagate Shape
  // Ops below.
  GraphOptimizer::Options optimizer_opts;
  optimizer_opts.inline_multi_device_functions = true;
  optimizer_opts.inline_impl_selection_group_functions = true;
  optimizer_opts.inline_with_single_device_body_placer = true;
  optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);

  // Infer shapes for each node in the computation.
  GraphShapeInfo shape_info;
  TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
      metadata, arg_shapes, graph->get(), flr, &shape_info));

  // Converts the GraphShapeInfo into the form needed by the constant-folding
  // pass of the optimizer.
  std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
  ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
  optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map);

  TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld));
  return Status::OK();
}
// OpKernel entry point. Registers a cancellation callback, then runs
// ComputeInternal. If the op is cancelled while compiling, and cancellation is
// not configured to be ignored, the callback schedules ExitCountdown on
// another thread. That thread may terminate the process if compilation does
// not finish in time. If the RPC was already cancelled before registration,
// the op fails with Cancelled and nothing is compiled.
void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
  VLOG(1) << "Cloud TPU: TpuCompileOpKernelCommon::Compute";
  // Shared with the exit-countdown thread; set once this kernel is done.
  auto compile_done = std::make_shared<std::atomic<bool>>(false);
  CancellationManager* cancellation_manager = ctx->cancellation_manager();
  const CancellationToken cancel_token =
      cancellation_manager->get_cancellation_token();

  auto on_cancel = [ctx, compile_done]() {
    if (UtilApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
      return;
    }
    // Sleep and exit in another thread so the cancellation manager can
    // continue running callbacks.
    Env* env = ctx->env();
    env->SchedClosure(
        [env, compile_done]() { ExitCountdown(env, compile_done); });
  };
  const bool callback_registered =
      cancellation_manager->RegisterCallback(cancel_token, on_cancel);
  // Registration fails only if cancellation already started; skip compiling.
  OP_REQUIRES(ctx, callback_registered,
              errors::Cancelled("RPC cancelled, not compiling TPU program"));

  // Deregister on every exit path so that the process is only aborted if a
  // cancellation actually occurs during compilation. Callbacks that were
  // already started by the CancellationManager still run.
  auto cancellation_cleanup = xla::MakeCleanup(
      [cancellation_manager, cancel_token, compile_done] {
        cancellation_manager->DeregisterCallback(cancel_token);
        compile_done->store(true);
      });
  OP_REQUIRES_OK(ctx, ComputeInternal(ctx));
}
// Compiles this op's program into `tpu_program_group`. The program is either
// the MLIR module or the TF function with its guaranteed constants, depending
// on `use_mlir_`. Argument shapes are derived from `metadata_` and
// `dynamic_shapes`. Afterwards, logs the compile time and program memory
// summary, and updates compilation metrics.
//
// Returns the compilation status. If compilation succeeded but logging the
// compilation statistics failed, the logging error is returned instead. A
// compilation failure takes precedence, so callers see the real compiler
// error rather than a secondary logging error.
Status TpuCompileOpKernelCommon::CompileLocallyAndFillHostCache(
    FunctionLibraryRuntime* flib_runtime,
    const SessionMetadata* session_metadata,
    const TpuMeshStateInterface* mesh_state,
    const std::vector<TensorShape>& dynamic_shapes,
    const OpInputList& guaranteed_constants, const TpuCompilationCacheKey& key,
    TpuProgramGroupInterface* tpu_program_group) {
  const absl::Time start_time = absl::Now();
  std::vector<TensorShape> arg_shapes;
  TF_RETURN_IF_ERROR(
      ComputeArgumentShapes(metadata_, dynamic_shapes, &arg_shapes));
  Status compile_status;
  if (use_mlir_) {
    compile_status = Compile(MlirToHloArgs{mlir_module_}, mesh_state->data(),
                             arg_shapes, tpu_program_group);
  } else {
    compile_status =
        Compile(FunctionToHloArgs{&function_,
                                  flib_runtime->GetFunctionLibraryDefinition(),
                                  flib_runtime->graph_def_version(),
                                  {&guaranteed_constants}},
                mesh_state->data(), arg_shapes, tpu_program_group);
  }
  const absl::Duration duration = absl::Now() - start_time;
  const std::string session_name = SessionNameFromMetadata(session_metadata);
  LOG(INFO) << "Compilation of " << key.prefix << " with session name "
            << session_name << " took " << duration;
  tpu_program_group->LogProgramMemorySummary();
  metrics::UpdateXlaCompilationTime(absl::ToInt64Microseconds(duration));
  TpuCompilationMetrics::IncrementCompilationCount(session_name);

  const Status log_status =
      tpu_program_group->LogCompilationStats(key, duration);
  if (!compile_status.ok()) {
    return compile_status;
  }
  return log_status;
}
// Body of Compute(): obtains the TPU program for this op from the
// compilation cache, compiling it on a cache miss, and sets the op outputs.
//
// Outputs written below:
//   - output 0: a serialized tpu::CompilationResultProto with the status code
//     and, on failure, an error message. When `return_hlo_protos_` is set, it
//     also carries the HloProtos returned by the cache.
//   - outputs 1..K: per core, a 2-element string vector holding the program
//     key and the rendezvous key base. On failure these hold placeholder
//     strings instead.
//   - non-MLIR op only: the outputs after those are bool scalars reporting
//     whether any compiled program may modify a variable.
//
// Errors while looking up resources (mesh state, cache, ref holder, unloader)
// or reading inputs are returned directly. Compilation failures are NOT
// returned; they are reported through output 0, and this function returns OK.
Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
VLOG(1) << "Retrieving mesh state";
// Retrieve the topology from the resource manager
ResourceMgr* rm = GetTPUConfigResourceMgr();
TpuMeshStateInterface* mesh_state;
TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
kTpuMeshStateInterfaceResourceName,
&mesh_state));
core::ScopedUnref mesh_state_unref(mesh_state);
std::vector<TensorShape> dynamic_shapes;
TF_RETURN_IF_ERROR(GetDynamicShapes(ctx, &dynamic_shapes));
OpInputList guaranteed_constants;
// TODO(ycao): Decide whether/how to support guaranteed constants in
// MLIR-based TF-Compiler Bridge.
if (!use_mlir_) {
TF_RETURN_IF_ERROR(
ctx->input_list("guaranteed_constants", &guaranteed_constants));
}
// The cache key covers the function/MLIR module, the function library
// fingerprint, guaranteed constants, dynamic shapes, metadata, and mesh state.
const TpuCompilationCacheKey key = CreateCompilationCacheKey(
function_.name(), metadata_.function_library_fingerprint(), mlir_module_,
guaranteed_constants, dynamic_shapes, metadata_, *mesh_state);
// Process-wide cache of TPU executables.
TpuCompilationCacheInterface* cache;
TF_RETURN_IF_ERROR(rm->Lookup<TpuCompilationCacheInterface>(
rm->default_container(), kCompilationCacheResourceName, &cache));
core::ScopedUnref cache_unref(cache);
// Per-step object that ensures that compilation cache entries aren't
// evicted until the step completes. This mechanism ensures that the
// downstream TPUExecute Ops in this step will be able to look up the
// compiled executable even if it is marked for eviction before the step
// ends.
//
// We can't use GetTPUConfigResourceMgr here because it may return the
// global ResourceMgr, which is not associated with any device, and
// GraphMgr's ScopedStepContainer only searches ResourceMgrs associated
// with devices when deleting resources at step boundaries.
CompilationRefHolder* ref_holder;
if (ctx->step_container() == nullptr) {
return errors::FailedPrecondition(
"TPUCompileOp requires a step container.");
}
TF_RETURN_IF_ERROR(
ctx->step_container()->LookupOrCreate<CompilationRefHolder>(
ctx->resource_manager(), "ref_holder", &ref_holder,
[cache](CompilationRefHolder** h) {
*h = cache->MakePerStepRefHolder();
return Status::OK();
}));
core::ScopedUnref ref_holder_unref(ref_holder);
// Out-parameters filled by CompileIfKeyAbsent.
// NOTE(review): `uid` is not initialized here and is read by the unloader
// path below regardless of `status`. Confirm that CompileIfKeyAbsent always
// assigns it, even when compilation fails.
int64 uid;
std::vector<std::string> proto_key;
std::vector<bool> may_modify_variables;
absl::Span<const xla::HloProto* const> hlo_metadatas;
Status status = cache->CompileIfKeyAbsent(
key, ctx->session_metadata(), ref_holder, &uid, &proto_key,
&may_modify_variables, &hlo_metadatas,
[&](TpuProgramGroupInterface* tpu_program_group) {
VLOG(1) << "Cloud TPU: Compiling TPU program";
// When this compile function is invoked, we know that host-memory
// cache TpuCompilationCache saw a cache miss. There are two codepaths:
// 1. If persistent cache is disabled, compile locally and populate
// host-memory cache.
// 2. If persistent cache is enabled, we do an additional lookup on
// the persistent cache.
// - If persistent cache also sees a cache miss, trigger
// compilation. Then, populate both persistent cache and
// host-memory cache.
// - If persistent cache sees a cache hit, retrieve cache entry from
// persistent cache to populate host-memory cache without
// recompilation. If retrieval failed, compile locally as a
// fallback and use the local compilation result to populate
// host-memory cache.
if (persistent_cache_ == nullptr) {
VLOG(1) << "Persistent compilation cache not enabled. Compiling "
"TPU executable locally and populating host-memory cache.";
return CompileLocallyAndFillHostCache(
ctx->function_library(), ctx->session_metadata(), mesh_state,
dynamic_shapes, guaranteed_constants, key, tpu_program_group);
}
return LookupPersistentCompilationCacheAndFillCaches(
ctx->function_library(), ctx->session_metadata(), mesh_state,
dynamic_shapes, guaranteed_constants, persistent_cache_.get(), key,
tpu_program_group);
});
// `ref_holder` is provided to CompileIfKeyAbsent to ensure that the cache
// entry does not get evicted before TpuExecuteOp runs it and discards
// `ref_holder`. When the TpuCompilationCacheEntryUnloader gets destroyed
// because the user closes the session while there are in-flight program
// executions, it discards the cache's reference to the cache entry but
// does not remove the entry until `ref_holder` discards the last
// reference to it. This ensures that the guarantees of `ref_holder` are
// not violated when this flag is true.
if (unload_cache_entry_on_session_close_) {
// Place `unloader` in TPU_SYSTEM device resource manager. Note that
// - TPUConfigResourceMgr returned by GetTPUConfigResourceMgr() is a special
// process-global ResourceMgr. There is only one TPUConfigResourceMgr, and
// it is never destroyed.
// - TPU_SYSTEM device resource manager is a normal device ResourceMgr for
// TPU_SYSTEM device. If DirectSession or isolate_session_state are used,
// there's one TPU_SYSTEM ResourceMgr for each session, and the
// ResourceMgrs will be destroyed when their corresponding session is
// closed. Otherwise there's one TPU_SYSTEM ResourceMgr that's only
// destroyed when the master-session is destroyed, not when the worker
// sessions are destroyed
TpuCompilationCacheEntryUnloader* unloader;
TF_RETURN_IF_ERROR(
ctx->resource_manager()
->LookupOrCreate<TpuCompilationCacheEntryUnloader>(
ctx->resource_manager()->default_container(),
kCompilationCacheUnloaderResourceName, &unloader,
[cache](TpuCompilationCacheEntryUnloader** new_unloader) {
*new_unloader = new TpuCompilationCacheEntryUnloader(cache);
return Status::OK();
}));
// Note that LookupOrCreate puts two refcounts on unloader.
core::ScopedUnref unloader_unref(unloader);
unloader->AddCacheEntryUid(uid);
}
int64 num_cores_with_compiled_programs = proto_key.size();
if (proto_key.size() == 1) {
// SPMD produces 1 program for all cores.
num_cores_with_compiled_programs = metadata_.num_cores_per_replica();
}
// Sanity-check that the results fit the op's declared outputs: one per core,
// plus (non-MLIR only) one per may-modify-variables entry, plus output 0.
// A mismatch is reported through `status`, like a compilation failure.
if (status.ok() &&
num_cores_with_compiled_programs +
(may_modify_variables.size() * static_cast<int>(!use_mlir_)) !=
ctx->num_outputs() - 1) {
status = errors::Internal(
"Number of cores with compiled programs (",
num_cores_with_compiled_programs, ") + variable states (",
may_modify_variables.size() * static_cast<int>(!use_mlir_),
") + compilation status output != number of compile op outputs (",
ctx->num_outputs(), ")");
}
// TODO(jpienaar): status is not just due to the compilation. At this
// point we should be failing the execution of the op in some cases and
// returning a compilation error in others. For now, uniformly return an
// error and fail in _TPUExecute if status failed here.
// TODO(misard) the frame id will be wrong if this is ever called from
// within a function. Consider whether to use the same hack as is
// present in the rendezvous manager where the function call frame is
// cast to a uint64, or do something better all around.
std::string rendezvous_key_base = strings::StrCat(
"host_compute_rendezvous:", ctx->op_kernel().name(), ":",
ctx->frame_iter().frame_id, ":", ctx->frame_iter().iter_id, ":");
// Return compilation status.
{
Tensor output(DT_STRING, TensorShape({}));
tpu::CompilationResultProto proto;
proto.set_status_code(status.code());
if (!status.ok()) {
proto.set_status_error_message(
absl::StrCat("Compilation failure: ", status.error_message()));
}
if (return_hlo_protos_) {
// Return the HloProtos as part of compilation status.
for (const xla::HloProto* hlo_metadata : hlo_metadatas) {
xla::HloProto* hlo_proto = proto.add_hlo_protos();
*hlo_proto = *hlo_metadata;
}
}
SerializeToTString(proto, &output.scalar<tstring>()());
ctx->set_output(0, output);
}
if (status.ok()) {
// Outputs 1..num_cores_with_compiled_programs: [program key, rendezvous
// key base] per core. With a single (SPMD) program, every core gets the
// same key.
for (int i = 0; i < num_cores_with_compiled_programs; ++i) {
Tensor output(DT_STRING, TensorShape({2}));
if (proto_key.size() == 1) {
output.vec<tstring>()(0) = proto_key[0];
} else {
output.vec<tstring>()(0) = proto_key[i];
}
output.vec<tstring>()(1) = rendezvous_key_base;
ctx->set_output(i + 1, output);
}
if (!use_mlir_) {
// If any of the programs may modify a variable, then return that all
// do as the only current state being tracked here is if a model is
// read-only or not.
bool may_modify = false;
for (bool m : may_modify_variables) {
may_modify = may_modify || m;
}
for (int i = 0; i < may_modify_variables.size(); ++i) {
Tensor output(DT_BOOL, TensorShape({}));
output.scalar<bool>()() = may_modify;
ctx->set_output(i + num_cores_with_compiled_programs + 1, output);
}
}
VLOG(1) << "Cloud TPU: Compilation succeeded";
} else {
// Return error in the invalid case.
// Placeholder outputs are sized by `num_computations_` rather than by the
// (unavailable) compilation results.
for (int i = 0; i < num_computations_; ++i) {
Tensor output(DT_STRING, TensorShape({2}));
output.vec<tstring>()(0) = "<<NO PROGRAM AS COMPILATION FAILED>>";
output.vec<tstring>()(1) = "<<NO RENDEZVOUS KEY AS COMPILATION FAILED>>";
ctx->set_output(i + 1, output);
}
if (!use_mlir_) {
// The TPUCompileMLIR op does not have MayModifyVariable output
for (int i = 0; i < num_computations_; ++i) {
Tensor output(false);
ctx->set_output(i + num_computations_ + 1, output);
}
}
}
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow