blob: 14f03b156c8778bb7a791b02b48696477d138e4e [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h"
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace {
// Checks that, in the argument list of the while body named by
// `body_name_attr`, no non-resource argument follows a DT_RESOURCE argument,
// i.e. all resource arguments form a suffix of the list.
//
// Returns InvalidArgument naming the first offending argument, or OK.
Status VerifyResourceArgsGroupedAtEnd(XlaOpKernelContext* ctx,
                                      const NameAttrList& body_name_attr) {
  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(
      ctx->compiler()->FindFunctionBody(body_name_attr, &fbody));
  const int num_args = fbody->arg_types.size();
  bool seen_resource = false;
  for (int arg_idx = 0; arg_idx < num_args; ++arg_idx) {
    const bool is_resource = fbody->arg_types[arg_idx] == DT_RESOURCE;
    // Once a resource has been seen, every later argument must be one too.
    if (seen_resource && !is_resource) {
      return errors::InvalidArgument(
          "Expect input resources are grouped in the end of while body ",
          body_name_attr.name(), ", but the ", arg_idx, "-th argument ",
          fbody->arg_nodes[arg_idx]->name(), " is not a resource.");
    }
    seen_resource = seen_resource || is_resource;
  }
  return Status::OK();
}
// Fills `args` with one XlaCompiler::Argument per input of the op in `ctx`.
//
// Resource inputs are described by the resource's current type, shape and
// initialization state. All other inputs become kParameter arguments whose
// shape is queried from the builder; for TensorList inputs the argument's
// `initialized` field records whether the list was built with a known
// element shape.
//
// On return the flags report whether any resource is uninitialized, whether
// any resource is a TensorArray, and whether any TensorList input is
// uninitialized.
Status MakeXlaCompilerArgumentsFromInputs(
    XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args,
    bool* has_uninitialized_vars, bool* has_tensor_arrays,
    bool* has_uninitialized_tensor_lists) {
  const int num_inputs = ctx->num_inputs();
  VLOG(2) << "Num inputs " << num_inputs;
  args->resize(num_inputs);
  *has_uninitialized_vars = false;
  *has_tensor_arrays = false;
  *has_uninitialized_tensor_lists = false;
  for (int idx = 0; idx < num_inputs; ++idx) {
    const DataType input_type = ctx->input_type(idx);
    VLOG(2) << " Input " << idx << " type: " << DataTypeString(input_type)
            << " shape: " << ctx->InputShape(idx).DebugString();
    XlaCompiler::Argument& argument = (*args)[idx];

    if (input_type != DT_RESOURCE) {
      // Plain tensor input: passed to the computation as a parameter.
      argument.kind = XlaCompiler::Argument::kParameter;
      argument.type = input_type;
      TF_ASSIGN_OR_RETURN(argument.shape,
                          ctx->builder()->GetShape(ctx->Input(idx)));
      if (!IsTensorListInput(ctx, idx)) {
        continue;
      }
      // An uninitialized list means its element_shape was unknown when the
      // list was built, so an empty list was created. The caller then runs
      // the While body once to infer the list shape before building the loop.
      TF_RETURN_IF_ERROR(
          IsTensorListInitialized(ctx->Input(idx), &argument.initialized));
      if (!argument.initialized) {
        *has_uninitialized_tensor_lists = true;
      }
      continue;
    }

    // Resource input: describe it by the type and shape of its current value.
    XlaResource* resource;
    TF_RETURN_IF_ERROR(ctx->GetResourceInput(idx, &resource));
    XlaCompiler::PopulateArgumentFromResource(*resource, &argument);
    if (argument.resource_kind == XlaResource::kTensorArray) {
      *has_tensor_arrays = true;
    }
    if (!argument.initialized) {
      *has_uninitialized_vars = true;
    }
    VLOG(2) << " resource " << resource->name()
            << " type: " << DataTypeString(argument.type)
            << " shape: " << argument.ShapeHumanString()
            << " initialized: " << argument.initialized;
  }
  return Status::OK();
}
// Sets `(*loop_invariants)[i]` to whether output `i` of the while body named
// by `body_name_attr` is loop invariant, as determined by IsLoopInvariant.
//
// Returns a non-OK status if the body cannot be found, if the body has more
// outputs than `loop_invariants` has entries, or if the invariance analysis
// fails for any output. Failures are returned (not only recorded on `ctx`)
// so that callers stop instead of continuing with partial results.
Status GetLoopInvariants(XlaOpKernelContext* ctx,
                         const NameAttrList& body_name_attr,
                         std::vector<bool>* const loop_invariants) {
  const FunctionBody* body;
  TF_RETURN_IF_ERROR(
      ctx->compiler()->FindFunctionBody(body_name_attr, &body));
  const tensorflow::FunctionLibraryDefinition* fld =
      ctx->compiler()->flib_runtime()->GetFunctionLibraryDefinition();
  // Guard the indexed writes below against a body with more outputs than the
  // While op has inputs.
  if (body->ret_nodes.size() > loop_invariants->size()) {
    return errors::InvalidArgument(
        "While body ", body_name_attr.name(), " has ", body->ret_nodes.size(),
        " outputs but only ", loop_invariants->size(),
        " loop variables are available.");
  }
  for (int i = 0; i < body->ret_nodes.size(); i++) {
    StatusOr<bool> is_loop_invariant = IsLoopInvariant(body, i, fld);
    TF_RETURN_IF_ERROR(is_loop_invariant.status());
    (*loop_invariants)[i] = *is_loop_invariant;
    VLOG(2) << "Arg " << i << " of " << body_name_attr.name() << " is "
            << ((*loop_invariants)[i] ? "" : "not ") << "loop invariant";
  }
  return Status::OK();
}

// Converts entries of `args` to compile-time constants when all of these hold:
//  - the argument is not a resource,
//  - it is loop invariant in the body, and
//  - it must be constant in the body or in the condition for those functions
//    to be compilable.
// Only inputs that actually are compile-time constants are converted (see
// ConvertCompileTimeConstArgumentsToConst). Converted constants can then be
// propagated into the loop body.
//
// For each converted index, sets `compile_time_const_arg_indices[idx]` to true
// and increments `*num_compile_time_const_args`. `client` is currently unused.
Status ConvertLoopInvariantsToConst(
    XlaOpKernelContext* ctx, const NameAttrList& body_name_attr,
    const NameAttrList& cond_name_attr,
    std::vector<XlaCompiler::Argument>* args,
    std::vector<bool>* compile_time_const_arg_indices,
    int* num_compile_time_const_args, xla::Client* client) {
  std::vector<bool> loop_invariants(ctx->num_inputs());
  TF_RETURN_IF_ERROR(GetLoopInvariants(ctx, body_name_attr, &loop_invariants));
  std::vector<bool> body_must_be_const_nodes;
  const FunctionBody* body;
  std::vector<bool> cond_must_be_const_nodes;
  const FunctionBody* cond;
  TF_RETURN_IF_ERROR(FindMustBeConstNodes(ctx, body_name_attr,
                                          &body_must_be_const_nodes, &body));
  TF_RETURN_IF_ERROR(FindMustBeConstNodes(ctx, cond_name_attr,
                                          &cond_must_be_const_nodes, &cond));

  auto should_convert_to_const = [&](int arg_idx) {
    XlaCompiler::Argument& arg = (*args)[arg_idx];
    return arg.kind != XlaCompiler::Argument::kResource &&
           loop_invariants[arg_idx] &&
           (body_must_be_const_nodes[body->arg_nodes[arg_idx]->id()] ||
            cond_must_be_const_nodes[cond->arg_nodes[arg_idx]->id()]);
  };
  absl::InlinedVector<int, 5> converted_constants =
      ConvertCompileTimeConstArgumentsToConst(ctx, args,
                                              /*xla_expression_offset=*/0,
                                              should_convert_to_const);
  VLOG(2) << "Converted args to constants: {"
          << absl::StrJoin(converted_constants, ",") << "}";
  for (int arg_idx : converted_constants) {
    compile_time_const_arg_indices->at(arg_idx) = true;
    (*num_compile_time_const_args)++;
  }
  return Status::OK();
}
// Checks that the loop body `body` produces a tuple compatible with its input
// tuple once outputs are dropped for loop vars propagated as compile-time
// constants (entries set in `compile_time_const_arg_indices`). When
// `has_token_input_output` is true, the body's token output (at index
// ctx->num_inputs() of its output tuple) is appended as the last element.
// These are the same outputs BuildWrappedBody keeps.
//
// Returns a non-OK status, instead of indexing out of range, if `body` does
// not have exactly one input shape, if its output is not a tuple large
// enough for the indices read here, or if the shapes are incompatible.
Status VerifyBodyInputAndOutputShapeMatch(
    XlaOpKernelContext* ctx,
    const std::vector<bool>& compile_time_const_arg_indices,
    const XlaCompiler::CompilationResult& body, bool has_token_input_output) {
  if (body.xla_input_shapes.size() != 1) {
    return errors::FailedPrecondition("Expected one input shape");
  }
  const xla::Shape& body_input_shape = body.xla_input_shapes[0];
  const xla::Shape& full_output_shape = body.xla_output_shape;

  const int num_outputs = ctx->num_outputs();
  if (compile_time_const_arg_indices.size() <
      static_cast<size_t>(num_outputs)) {
    return errors::Internal(
        "compile_time_const_arg_indices has ",
        compile_time_const_arg_indices.size(), " entries but the op has ",
        num_outputs, " outputs.");
  }
  // Smallest output-tuple size for which every tuple index read below is
  // valid.
  int required_tuple_size = num_outputs;
  if (has_token_input_output && ctx->num_inputs() + 1 > required_tuple_size) {
    required_tuple_size = ctx->num_inputs() + 1;
  }
  if (!full_output_shape.IsTuple() ||
      full_output_shape.tuple_shapes_size() < required_tuple_size) {
    return errors::InvalidArgument(
        "Output of loop body must be a tuple with at least ",
        required_tuple_size, " elements, got: ",
        xla::ShapeUtil::HumanString(full_output_shape));
  }

  xla::Shape body_output_shape;
  body_output_shape.set_element_type(xla::TUPLE);
  for (int i = 0; i < num_outputs; i++) {
    if (!compile_time_const_arg_indices[i]) {
      *(body_output_shape.add_tuple_shapes()) =
          full_output_shape.tuple_shapes(i);
    }
  }
  // If `body` has a token output, append its shape to `body_output_shape`.
  if (has_token_input_output) {
    *(body_output_shape.add_tuple_shapes()) =
        full_output_shape.tuple_shapes(ctx->num_inputs());
  }
  if (!xla::ShapeUtil::Compatible(body_input_shape, body_output_shape)) {
    return errors::InvalidArgument(
        "Input and output shapes of loop body do not match: ",
        xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
        xla::ShapeUtil::HumanString(body_output_shape));
  }
  return Status::OK();
}
// Wraps the compiled loop condition `cond`, which returns a tuple, in a
// sub-computation named "cond_wrapper". The wrapper takes the same tuple
// parameter, calls `cond`, and yields only element 0 of the result.
xla::StatusOr<xla::XlaComputation> BuildWrappedCond(
    XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& cond) {
  std::unique_ptr<xla::XlaBuilder> wrapper_builder =
      ctx->builder()->CreateSubBuilder("cond_wrapper");
  xla::XlaBuilder* const b = wrapper_builder.get();
  const xla::Shape& param_shape = cond.xla_input_shapes[0];
  // Ops are added in the order Parameter, Call, GetTupleElement; the
  // GetTupleElement is the last op added and becomes the root.
  xla::GetTupleElement(
      xla::Call(b, *cond.computation,
                {xla::Parameter(b, 0, param_shape, "inputs")}),
      0);
  return wrapper_builder->Build();
}
// Builds the body computation handed to xla::While.
//
// If no loop var was propagated as a compile-time constant
// (`num_compile_time_const_args <= 0`), the compiled `body` is returned
// unchanged, as a copy built from its proto.
//
// Otherwise the compiled body still emits an output for every loop var,
// including the ones removed from its input tuple. A wrapper named
// "body_wrapper" then calls `body` and re-packs only the outputs whose
// `compile_time_const_arg_indices` entry is false. The token output, at
// index ctx->num_outputs() of the body result, is appended when
// `has_token_input_output` is true. This keeps the wrapper's input and
// output shapes matched, as While requires.
xla::StatusOr<xla::XlaComputation> BuildWrappedBody(
    XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& body,
    const std::vector<bool>& compile_time_const_arg_indices,
    int num_compile_time_const_args, bool has_token_input_output) {
  if (num_compile_time_const_args <= 0) {
    return xla::XlaComputation(body.computation->proto());
  }
  std::unique_ptr<xla::XlaBuilder> cb =
      ctx->builder()->CreateSubBuilder("body_wrapper");
  xla::Shape body_input_shape = body.xla_input_shapes[0];
  auto inputs = xla::Parameter(cb.get(), 0, body_input_shape, "inputs");
  // Call the original body function, which has mismatched inputs and outputs,
  // and strip the compile time consts from the list of outputs. While
  // requires the inputs and outputs of its body function to match.
  auto outputs = xla::Call(cb.get(), *body.computation, {inputs});
  const int num_loop_vars =
      static_cast<int>(compile_time_const_arg_indices.size());
  std::vector<xla::XlaOp> non_compile_time_const_outputs;
  non_compile_time_const_outputs.reserve(num_loop_vars + 1);
  for (int i = 0; i < num_loop_vars; i++) {
    if (!compile_time_const_arg_indices[i]) {
      non_compile_time_const_outputs.push_back(
          xla::GetTupleElement(outputs, i));
    }
  }
  // If `body` has a token output, append it to
  // `non_compile_time_const_outputs`.
  if (has_token_input_output) {
    non_compile_time_const_outputs.push_back(
        xla::GetTupleElement(outputs, ctx->num_outputs()));
  }
  // The Tuple is the last op added and becomes the wrapper's root.
  xla::Tuple(cb.get(), non_compile_time_const_outputs);
  return cb->Build();
}
// Emits xla::While(wrapped_cond, wrapped_body, initial_values) and returns a
// tuple with one element per op output.
//
// Loop vars flagged in `compile_time_const_arg_indices` are not carried by
// the loop; their slot is filled with the op's original input. Each carried
// value k of the loop result is placed at output `input_mapping[k]`. When
// `has_token_input_output` is true, the loop's token, at position
// num_inputs - num_compile_time_const_args of the loop result, is appended
// as a final element.
xla::XlaOp BuildWhile(XlaOpKernelContext* ctx,
                      const xla::XlaComputation& wrapped_cond,
                      const xla::XlaComputation& wrapped_body,
                      const xla::XlaOp& initial_values,
                      const std::vector<int>& input_mapping,
                      const std::vector<bool>& compile_time_const_arg_indices,
                      int num_compile_time_const_args,
                      bool has_token_input_output) {
  xla::XlaOp loop = xla::While(wrapped_cond, wrapped_body, initial_values);
  std::vector<xla::XlaOp> results(ctx->num_outputs());
  // Index of the next element of the loop result to consume.
  int carried = 0;
  for (int input_idx = 0; input_idx < ctx->num_inputs(); ++input_idx) {
    if (compile_time_const_arg_indices[input_idx]) {
      results[input_idx] = ctx->Input(input_idx);
      continue;
    }
    results[input_mapping[carried]] = xla::GetTupleElement(loop, carried);
    ++carried;
  }
  if (has_token_input_output) {
    const int token_position =
        ctx->num_inputs() - num_compile_time_const_args;
    results.push_back(xla::GetTupleElement(loop, token_position));
  }
  return xla::Tuple(ctx->builder(), results);
}
} // anonymous namespace
// Reads the While op's attributes:
//  - "cond" / "body": required function attributes, stored in
//    cond_name_attr_ / body_name_attr_.
//  - kXlaTokenInputNodesAttrName: optional. The op has a token input and
//    output only when the attribute is present and non-empty.
//  - kPropagateCompileTimeConsts: optional, read only when present.
//  - kXlaOriginalOutsideCompilationNodeName: optional. Defaults to this
//    kernel's own name when absent.
XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
  const NameAttrList* fn_attr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &fn_attr));
  cond_name_attr_ = *fn_attr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &fn_attr));
  body_name_attr_ = *fn_attr;

  const bool has_token_attr =
      ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok();
  has_token_input_output_ = has_token_attr && !token_input_nodes_.empty();

  if (ctx->HasAttr(kPropagateCompileTimeConsts)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts,
                                     &propagate_compile_time_consts_));
  }

  const bool has_original_node_name =
      ctx->GetAttr(kXlaOriginalOutsideCompilationNodeName,
                   &original_node_name_)
          .ok();
  if (!has_original_node_name) {
    original_node_name_ = name();
  }
}
// Lowers this While op to an xla::While.
//
// Outline:
//  1. Describe the op inputs as XlaCompiler arguments and, if enabled, turn
//     loop-invariant compile-time-constant inputs into constants.
//  2. Compile the body. If uninitialized resources, TensorArrays or
//     uninitialized TensorLists are involved, the first compilation is used
//     only to infer their shapes and gradients, and the body is compiled again.
//  3. Compile the condition and validate the body/condition shapes.
//  4. Build the initial loop tuple. This packs resources, creates zero-filled
//     TensorLists whose shape was inferred, and joins the token inputs.
//  5. Emit the loop, then set the op outputs, resource updates and the
//     output token.
void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
  VLOG(1) << "WhileOp::Compile";

  // Input resources need to be grouped in the end of the body function
  // according to the convention of the XLA bridge.
  OP_REQUIRES_OK(ctx, VerifyResourceArgsGroupedAtEnd(ctx, body_name_attr_));

  std::vector<XlaCompiler::Argument> arguments;
  bool has_uninitialized_vars;
  bool has_tensor_arrays;
  bool has_uninitialized_tensor_lists;
  OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs(
                          ctx, &arguments, &has_uninitialized_vars,
                          &has_tensor_arrays, &has_uninitialized_tensor_lists));

  xla::XlaBuilder* builder = ctx->builder();
  XlaCompiler* compiler = ctx->compiler();

  // Indices of loop vars which satisfy the following conditions:
  // 1. They are loop invariants.
  // 2. The op inputs at these indices are compile time constants.
  //
  // These compile time consts do not appear as _Args in the cond/body
  // functions and are replaced by kConstant nodes instead. As a result, the
  // compiled body function does not have matching input and output shapes.
  // We fix this by rewriting the body computation (see body_wrapper below) to
  // output just the non compile-time-const values, and later pad the while
  // output with the const args.
  std::vector<bool> compile_time_const_arg_indices(ctx->num_inputs());
  int num_compile_time_const_args = 0;
  if (propagate_compile_time_consts_) {
    OP_REQUIRES_OK(ctx, ConvertLoopInvariantsToConst(
                            ctx, body_name_attr_, cond_name_attr_, &arguments,
                            &compile_time_const_arg_indices,
                            &num_compile_time_const_args, compiler->client()));
  }

  VLOG(1) << "Compiling body";

  // All resources that are inputs to the loop's body must also be present as
  // loop body outputs; the signatures of the loop's input and output must
  // match. We ensure this by asking the compiler to include the current
  // values of all resources, even if they haven't been updated by the
  // computation. We must also ask the compiler to keep compile-time constant
  // outputs as part of the generated computation, for the same reason.
  // TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
  // operator.
  XlaCompiler::CompileOptions body_options;
  body_options.use_tuple_arg = true;
  body_options.return_updated_values_for_all_resources = true;
  body_options.is_entry_computation = false;
  body_options.add_token_input_output = has_token_input_output_;
  XlaCompiler::CompilationResult body;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                arguments, &body));

  // We must use a static shape for parameters to an XLA compilation. However,
  // we may not know the shape of a resource if it is first written inside
  // the loop. Furthermore, we do not know ahead of time which gradient
  // TensorArrays will be created by the TensorArrayGradV3 operator.
  //
  // Ideally we would change TensorFlow to always provide static shapes, but
  // this is not easy to do. So if uninitialized resources or TensorArrays
  // are used by the loop body, we compile the body function twice:
  // 1) once with uninitialized resource inputs and no TensorArray gradient
  //    inputs. We then discard the computation, but we assume resource shapes
  //    and the set of gradients read or written will reach a fixpoint after
  //    one iteration.
  //    Hence we can use the output shapes and TensorArray gradients of each
  //    resource as the "true" shapes.
  // 2) again with the "correct" resource information determined by (1).
  if (has_uninitialized_vars || has_tensor_arrays ||
      has_uninitialized_tensor_lists) {
    VLOG(2) << "Recompiling loop body: has_uninitialized_vars: "
            << has_uninitialized_vars
            << " has_tensor_arrays: " << has_tensor_arrays
            << " has_uninitialized_tensor_lists: "
            << has_uninitialized_tensor_lists;
    // Initializes any uninitialized resource with zero values of the
    // shape determined by the first compilation.
    for (int i = 0; i < body.resource_updates.size(); ++i) {
      const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));

      XlaCompiler::Argument& arg = arguments[update.input_index];
      if (!arg.initialized) {
        VLOG(2) << "Update shape for argument " << update.input_index << " "
                << update.shape.DebugString();
        arg.initialized = true;

        arg.shape = update.shape;
        OP_REQUIRES_OK(ctx,
                       resource->SetTypeAndShape(update.type, update.shape));

        OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
      }

      // Add any TensorArray gradients touched by the body to the enclosing
      // graph.
      for (const string& grad_source : update.tensor_array_gradients_accessed) {
        VLOG(4) << "TensorArray " << resource->name() << " accessed gradient "
                << grad_source;
        XlaResource* gradient;
        OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
                                grad_source, builder, &gradient));
      }

      // Add all of the TensorArray gradients to the argument. For simplicity,
      // we always pass all known gradients.
      for (const auto& gradient : resource->tensor_array_gradients()) {
        arg.tensor_array_gradients.insert(gradient.first);
      }
    }

    // Set the shape of any uninitialized TensorLists to the shape determined
    // by the first compilation. Note that, unlike resources, we do not
    // initialize the input list with zeros here; that is done later.
    xla::Shape body_output_shape = body.xla_output_shape;
    OP_REQUIRES(ctx, body_output_shape.IsTuple(),
                errors::FailedPrecondition(
                    "xla_output_shape of while body must be a tuple."));
    for (int i = 0; i < arguments.size(); i++) {
      XlaCompiler::Argument& arg = arguments[i];
      if (arg.initialized || !IsTensorListInput(ctx, i)) {
        continue;
      }
      arg.shape = body_output_shape.tuple_shapes(i);
      arg.initialized = true;
    }

    // Recompile the body with the "correct" resource shapes.
    VLOG(1) << "Recompiling body with corrected resource shapes";
    body = {};
    OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                  arguments, &body));
  }

  VLOG(1) << "Compiling condition";

  XlaCompiler::CompileOptions cond_options;
  cond_options.use_tuple_arg = true;
  cond_options.is_entry_computation = false;
  cond_options.add_token_input_output = has_token_input_output_;
  XlaCompiler::CompilationResult cond;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
                                                arguments, &cond));

  OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1,
              errors::FailedPrecondition("Expected one input shape"));
  xla::Shape body_input_shape = body.xla_input_shapes[0];
  OP_REQUIRES(ctx, body_input_shape.IsTuple(),
              errors::FailedPrecondition("Expected tuple shape"));
  OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1,
              errors::FailedPrecondition("Expected one input shape"));
  xla::Shape cond_input_shape = cond.xla_input_shapes[0];
  OP_REQUIRES(ctx, cond_input_shape.IsTuple(),
              errors::FailedPrecondition("Expected tuple shape"));

  VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape);
  VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape);

  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(body_input_shape, cond_input_shape),
              errors::InvalidArgument(
                  "Input shapes of loop body and condition do not match: ",
                  xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
                  xla::ShapeUtil::HumanString(cond_input_shape)));

  // Check that the shape of the body outputs excluding the compile time const
  // args (which are pruned from the body outputs in body_wrapper) matches the
  // shape of the inputs.
  OP_REQUIRES_OK(ctx, VerifyBodyInputAndOutputShapeMatch(
                          ctx, compile_time_const_arg_indices, body,
                          has_token_input_output_));

  xla::Shape expected_cond_output_shape_without_side_effect =
      xla::ShapeUtil::MakeTupleShape(
          {xla::ShapeUtil::MakeShape(xla::PRED, {})});
  xla::Shape expected_cond_output_shape_with_side_effect =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::PRED, {}),
                                      xla::ShapeUtil::MakeTokenShape()});
  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(
                  cond.xla_output_shape,
                  expected_cond_output_shape_without_side_effect) ||
                  xla::ShapeUtil::Compatible(
                      cond.xla_output_shape,
                      expected_cond_output_shape_with_side_effect),
              errors::InvalidArgument(
                  "Output shape of loop condition should be (pred[]) or "
                  "(pred[], token[]), got: ",
                  xla::ShapeUtil::HumanString(cond.xla_output_shape)));

  // Build the initial value of every element of the loop tuple.
  // body.input_mapping[i] is the op input that feeds loop-tuple element i.
  int num_inputs = body.input_mapping.size();
  std::vector<xla::XlaOp> inputs(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    int input_num = body.input_mapping[i];
    if (has_token_input_output_ && i == num_inputs - 1) {
      // Set token input for this "while" op.
      std::vector<xla::XlaOp> token_inputs;
      for (const string& node_name : token_input_nodes_) {
        auto token_or = compiler->GetNodeToken(node_name);
        OP_REQUIRES_OK(ctx, token_or.status());
        token_inputs.push_back(token_or.ValueOrDie());
      }
      inputs[i] = xla::AfterAll(builder, token_inputs);
    } else if (ctx->input_type(input_num) == DT_RESOURCE) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
      OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
    } else if (IsTensorListInput(ctx, input_num)) {
      xla::XlaOp input = ctx->Input(input_num);
      auto input_shape_or = ctx->builder()->GetShape(input);
      OP_REQUIRES_OK(ctx, input_shape_or.status());
      xla::Shape input_shape = input_shape_or.ValueOrDie();
      const xla::Shape& list_shape = body_input_shape.tuple_shapes(i);
      // Shape/datatype of the input list may differ from shape/datatype of
      // the body/cond input if the list's shape/datatype was inferred after
      // the first compilation and the body/cond was recompiled with the
      // updated shape/datatype of the list.
      if (input_shape != list_shape) {
        // Prepare dynamic dimensions for element shapes. Every tuple entry of
        // `list_shape` except the last is given dynamic dims here. NOTE: the
        // last entry is presumably the list's size/push index; confirm in
        // tensor_list_utils. `buffer_idx` is deliberately distinct from the
        // outer loop index `i`, which is used again after this loop.
        std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
        for (int64 buffer_idx = 0;
             buffer_idx < list_shape.tuple_shapes_size() - 1; ++buffer_idx) {
          std::vector<xla::XlaOp> dynamic_dims;

          const xla::Shape& shape = list_shape.tuple_shapes(buffer_idx);

          // We already have the dynamic size of the leading dimension outside
          // of the while loop, without initializing the TensorList inside the
          // while loop.
          if (shape.is_dynamic_dimension(0)) {
            xla::XlaOp leading_dim_size = xla::GetDimensionSize(input, 0);
            dynamic_dims.push_back(leading_dim_size);
          } else {
            int32 dim_size = shape.dimensions(0);
            dynamic_dims.push_back(
                xla::ConstantR0<int32>(ctx->builder(), dim_size));
          }

          // Set dynamic dimension size to 0 for element value. Inside the
          // while loop, TensorlistSetItem will properly set the element
          // shape's dynamic dimension.
          for (int64 dim = 1; dim < shape.dimensions_size(); ++dim) {
            int32 dim_size = shape.dimensions(dim);
            if (shape.is_dynamic_dimension(dim)) {
              dim_size = 0;
            }
            dynamic_dims.push_back(
                xla::ConstantR0<int32>(ctx->builder(), dim_size));
          }
          list_dynamic_dims.push_back(dynamic_dims);
        }
        OP_REQUIRES_OK(
            ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
                                                list_dynamic_dims, &inputs[i]));
      } else {
        inputs[i] = ctx->Input(input_num);
      }
    } else {
      inputs[i] = ctx->Input(input_num);
    }
  }

  xla::XlaOp init = xla::Tuple(builder, inputs);

  VLOG(1) << "Building while loop";

  // Wraps the condition in a computation that unpacks the output tuple.
  xla::StatusOr<xla::XlaComputation> cond_result = BuildWrappedCond(ctx, cond);
  OP_REQUIRES_OK(ctx, cond_result.status());
  xla::XlaComputation wrapped_cond = std::move(cond_result.ValueOrDie());

  // Remove compile time const args from the list of body outputs.
  xla::StatusOr<xla::XlaComputation> body_result =
      BuildWrappedBody(ctx, body, compile_time_const_arg_indices,
                       num_compile_time_const_args, has_token_input_output_);
  OP_REQUIRES_OK(ctx, body_result.status());
  xla::XlaComputation wrapped_body = std::move(body_result.ValueOrDie());

  // Builds the While op and pads its output with the compile time const args.
  xla::XlaOp while_result =
      BuildWhile(ctx, wrapped_cond, wrapped_body, init, body.input_mapping,
                 compile_time_const_arg_indices, num_compile_time_const_args,
                 has_token_input_output_);

  // Sets non-variable outputs and determines where resource variables start.
  int resource_index = 0;
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    if (ctx->input_type(i) != DT_RESOURCE) {
      if (IsTensorListInput(ctx, i)) {
        ctx->SetTensorListOutput(i, xla::GetTupleElement(while_result, i));
      } else {
        ctx->SetOutput(i, xla::GetTupleElement(while_result, i));
      }
      ++resource_index;
    } else {
      break;
    }
  }
  if (has_token_input_output_) {
    // Set token output for this "while" op.
    xla::XlaOp token_output =
        xla::GetTupleElement(while_result, ctx->num_outputs());
    auto shape_or = builder->GetShape(token_output);
    OP_REQUIRES_OK(ctx, shape_or.status());
    OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(),
                errors::FailedPrecondition(
                    "Token output is not token type: ",
                    xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
    OP_REQUIRES_OK(ctx,
                   compiler->SetNodeToken(original_node_name_, token_output));
  }

  // Updates the values of any resource variables modified by the loop.
  for (int i = 0; i < body.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
    if (update.modified) {
      // Resource i's loop result sits after the `resource_index` non-resource
      // outputs counted above.
      int pos = resource_index + i;
      OP_REQUIRES_OK(ctx,
                     resource->SetFromPack(
                         arguments[update.input_index].tensor_array_gradients,
                         xla::GetTupleElement(while_result, pos), builder));
    }
    VLOG(2) << "Loop-carried variable: pos: " << update.input_index
            << " name: " << resource->name() << " modified: " << update.modified
            << " type: " << DataTypeString(update.type)
            << " shape: " << update.shape.DebugString();
    // Copies the identity of the resource variable from input to output
    // unchanged, even if the variable was not modified.
    ctx->op_kernel_context()->set_output(
        update.input_index,
        ctx->op_kernel_context()->input(update.input_index));
  }

  VLOG(1) << "Done building while loop";
}
// The same XlaWhileOp kernel is registered for the While, StatelessWhile and
// XlaWhile ops. Resource-typed and variant-typed inputs are allowed; Compile()
// handles resource inputs and TensorList (variant) inputs explicitly.
REGISTER_XLA_OP(Name("While").AllowResourceTypes().AllowVariantTypes(),
XlaWhileOp);
REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes().AllowVariantTypes(),
XlaWhileOp);
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes().AllowVariantTypes(),
XlaWhileOp);
} // namespace tensorflow