/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// XLA TensorList operators.
#include <limits>
#include <vector>
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// GetTensorListDynamicDims collects the dynamic dimensions that a TensorList
// may carry and returns them in a 2D vector: XlaOp[ElementSize][DimSize]. If a
// dimension is static, a constant XlaOp holding its size is returned. If a
// dimension is dynamic, an XlaOp representing its runtime size is returned.
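// For example, for a list of rank-2 elements this returns a single inner
// vector of three XlaOps: {leading_dim, dim0, dim1}.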
StatusOr<std::vector<std::vector<xla::XlaOp>>> GetTensorListDynamicDims(
XlaOpKernelContext* ctx, const xla::Shape& element_shape,
const xla::Shape& list_shape, int64_t num_elements) {
std::vector<int64_t> dynamic_sizes;
  // Read the element shape sizes from input 0; whether each entry is dynamic
  // is resolved separately below.
TF_RETURN_IF_ERROR(ctx->ConstantInputAsIntVector(0, &dynamic_sizes));
std::vector<bool> dims_are_dynamic;
TF_RETURN_IF_ERROR(
ctx->ResolveInputDynamismIntoPredVector(0, &dims_are_dynamic));
bool leading_dim_is_dynamic;
TF_RETURN_IF_ERROR(
ctx->ResolveInputDynamismIntoPred(1, &leading_dim_is_dynamic));
std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
  // Collect one size op per dimension, starting with the leading (list
  // length) dimension.
std::vector<xla::XlaOp> dynamic_dims;
if (leading_dim_is_dynamic) {
dynamic_dims.push_back(ctx->Input(1));
} else {
dynamic_dims.push_back(
xla::ConstantR0<int32>(ctx->builder(), num_elements));
}
for (int64_t dim = 0; dim < element_shape.dimensions_size(); ++dim) {
if (dims_are_dynamic[dim]) {
auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1});
dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
dynamic_dims.push_back(dynamic_dim_size);
} else {
dynamic_dims.push_back(
xla::ConstantR0<int32>(ctx->builder(), dynamic_sizes[dim]));
}
}
list_dynamic_dims.push_back(dynamic_dims);
return list_dynamic_dims;
}
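// Lowers TensorListLength: returns the length of the list (its leading
// dimension) as a scalar. If the leading dimension is dynamic, the returned
// value is the dynamic size op rather than a compile-time constant.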
class TensorListLengthOp : public XlaOpKernel {
public:
explicit TensorListLengthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
int64_t leading_dim;
xla::XlaOp leading_dim_size;
bool leading_dim_is_dynamic;
OP_REQUIRES_OK(ctx, GetLeadingDimForTensorList(ctx->Input(0), &leading_dim,
&leading_dim_is_dynamic,
&leading_dim_size));
ctx->SetOutput(0, leading_dim_size);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(TensorListLengthOp);
};
REGISTER_XLA_OP(Name("TensorListLength").IsMetadataOp(), TensorListLengthOp);
// "input" is the element_shape input for the EmptyTensorList/TensorListReserve
// ops. If "input" is a compile time constant and not "unknown rank" (-1), sets
// "*got_shape" to true and returns the shape in "*shape"; otherwise sets
// "*got_shape" to false.
Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, xla::XlaOp input,
xla::PrimitiveType dtype, bool* got_shape,
xla::Shape* shape) {
auto is_compile_time_constant_or = input.builder()->IsConstant(input);
TF_RETURN_IF_ERROR(is_compile_time_constant_or.status());
bool is_compile_time_constant = is_compile_time_constant_or.ValueOrDie();
if (!is_compile_time_constant) {
*got_shape = false;
return Status::OK();
}
PartialTensorShape partial_shape;
TF_RETURN_IF_ERROR(ctx->ConstantInputAsPartialShape(0, &partial_shape));
if (!partial_shape.IsFullyDefined()) {
*got_shape = false;
return Status::OK();
}
*shape = xla::ShapeUtil::MakeShape(dtype, partial_shape.dim_sizes());
*got_shape = true;
return Status::OK();
}
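// Lowers TensorListReserve: creates a list of num_elements elements. If the
// element shape is statically known, the list is backed by a zero-initialized
// buffer; otherwise an uninitialized list is created and its buffer shape is
// deferred until the first write.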
class TensorListReserveOp : public XlaOpKernel {
public:
explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
// Only non-nested TensorList is supported for now.
OP_REQUIRES(
ctx, dtype_ != DT_VARIANT,
errors::Unimplemented(
"Only non-nested TensorList is supported for TensorListReserve."));
}
void Compile(XlaOpKernelContext* ctx) override {
int64_t num_elements;
OP_REQUIRES_OK(ctx,
ctx->ConstantInputAsIntScalar(
1, &num_elements, xla::ValueInferenceMode::kUpperBound));
bool num_element_is_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
    OP_REQUIRES(
        ctx, num_elements >= 0,
        errors::InvalidArgument(
            "XLA compilation requires a fixed tensor list size. Set the number "
            "of elements. This could also happen if you're using a TensorArray "
            "in a while loop that does not have its maximum_iteration set; you "
            "can fix this by setting maximum_iteration to a suitable value."));
// If element shape is compile time constant and it's not "unknown rank"
// shape (-1), create an initialized TensorList. Otherwise create an
// uninitialized TensorList.
xla::XlaOp element_shape_handle = ctx->Input(0);
xla::PrimitiveType type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
bool got_shape;
xla::Shape element_shape;
OP_REQUIRES_OK(ctx,
TryGetElementShapeFromInput(ctx, element_shape_handle, type,
&got_shape, &element_shape));
if (got_shape) {
xla::Shape list_shape;
OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
element_shape, num_elements,
num_element_is_dynamic, &list_shape));
// Set up dynamic dimension sizes to create the zero tensor.
auto list_dynamic_dims_or = GetTensorListDynamicDims(
ctx, element_shape, list_shape, num_elements);
OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
xla::XlaOp new_list;
OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
ctx->builder(), list_shape,
list_dynamic_dims_or.ValueOrDie(), &new_list));
xla::XlaOp result;
OP_REQUIRES_OK(
ctx,
SetTensorListPushIndex(
new_list, xla::ConstantR0<int32>(ctx->builder(), num_elements),
&result));
ctx->SetTensorListOutput(0, result);
return;
}
xla::XlaOp result = BuildUninitializedTensorList(
ctx->builder(), num_elements, num_element_is_dynamic, ctx->Input(1));
ctx->SetTensorListOutput(0, result);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp);
};
REGISTER_XLA_OP(Name("TensorListReserve")
.CompileTimeConstantInput("element_shape")
.CompileTimeConstantInput("num_elements"),
TensorListReserveOp);
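// Lowers EmptyTensorList: like TensorListReserve, but max_num_elements only
// bounds the buffer that is allocated up front; nested TensorLists are
// allowed and always start uninitialized.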
class EmptyTensorListOp : public XlaOpKernel {
public:
explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
int64_t max_num_elements;
OP_REQUIRES_OK(
ctx, ctx->ConstantInputAsIntScalar(
1, &max_num_elements, xla::ValueInferenceMode::kUpperBound));
bool num_element_is_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPred(1, &num_element_is_dynamic));
    OP_REQUIRES(ctx, max_num_elements >= 0,
                errors::InvalidArgument(
                    "XLA compilation requires a fixed tensor list size. Set "
                    "the max number of elements. This could also happen if "
                    "you're using a TensorArray in a while loop that does not "
                    "have its maximum_iteration set; you can fix this by "
                    "setting maximum_iteration to a suitable value."));
if (dtype_ != DT_VARIANT) {
// We are creating a non-nested TensorList.
// If element shape is compile time constant and it's not "unknown
// rank" shape (-1), create an initialized TensorList. Otherwise
// create an uninitialized TensorList.
xla::XlaOp element_shape_handle = ctx->Input(0);
xla::PrimitiveType type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type));
bool got_shape;
xla::Shape element_shape;
OP_REQUIRES_OK(
ctx, TryGetElementShapeFromInput(ctx, element_shape_handle, type,
&got_shape, &element_shape));
if (got_shape) {
xla::Shape list_shape;
OP_REQUIRES_OK(ctx, GetTensorListShapeFromElementShape(
element_shape, max_num_elements,
num_element_is_dynamic, &list_shape));
// Set up dynamic dimension sizes to create the zero tensor.
auto list_dynamic_dims_or = GetTensorListDynamicDims(
ctx, element_shape, list_shape, max_num_elements);
OP_REQUIRES_OK(ctx, list_dynamic_dims_or.status());
xla::XlaOp result;
OP_REQUIRES_OK(ctx, CreateZerosTensorListWithShape(
ctx->builder(), list_shape,
list_dynamic_dims_or.ValueOrDie(), &result));
ctx->SetTensorListOutput(0, result);
return;
}
}
// We are creating a nested TensorList or a non-nested TensorList with
// unknown shape. Just create an uninitialized TensorList.
xla::XlaOp result =
BuildUninitializedTensorList(ctx->builder(), max_num_elements,
num_element_is_dynamic, ctx->Input(1));
ctx->SetTensorListOutput(0, result);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp);
};
REGISTER_XLA_OP(Name("EmptyTensorList")
.CompileTimeConstantInput("element_shape")
.CompileTimeConstantInput("max_num_elements")
.AllowVariantTypes(),
EmptyTensorListOp);
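// Lowers TensorListElementShape: returns the element shape (the buffer shape
// without the leading list dimension) as a rank-1 tensor of type shape_type.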
class TensorListElementShapeOp : public XlaOpKernel {
public:
explicit TensorListElementShapeOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_));
}
void Compile(XlaOpKernelContext* ctx) override {
// Check that the TensorList is initialized.
bool is_initialized;
OP_REQUIRES_OK(ctx,
(IsTensorListInitialized(ctx->Input(0), &is_initialized)));
OP_REQUIRES(ctx, is_initialized,
errors::InvalidArgument("TensorList is not initialized"));
// Only non-nested TensorList is supported for now.
bool is_nested;
OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
OP_REQUIRES(ctx, !is_nested,
errors::Unimplemented("Only non-nested TensorList is supported "
"for TensorListElementShape."));
// For non-nested TensorList, element shape is the buffer shape without
// the first dimension.
xla::XlaBuilder* b = ctx->builder();
xla::Shape list_shape;
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(ctx->Input(0), &list_shape));
list_shape.DeleteDimension(0);
switch (shape_type_) {
case DT_INT64:
ctx->SetOutput(0, xla::ConstantR1<int64_t>(b, list_shape.dimensions()));
break;
case DT_INT32: {
std::vector<int32> size;
for (int64_t s : list_shape.dimensions()) {
size.push_back(s);
}
ctx->SetOutput(0, xla::ConstantR1<int32>(b, size));
break;
}
default:
ctx->CtxFailure(
errors::InvalidArgument("Unsupported shape type requested"));
return;
}
}
private:
DataType shape_type_;
TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp);
};
REGISTER_XLA_OP(Name("TensorListElementShape").IsMetadataOp(),
TensorListElementShapeOp);
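// Lowers TensorListGetItem: returns the element stored at position index.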
class TensorListGetItemOp : public XlaOpKernel {
public:
explicit TensorListGetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
// Check that the TensorList is initialized.
bool is_initialized;
OP_REQUIRES_OK(ctx,
(IsTensorListInitialized(ctx->Input(0), &is_initialized)));
OP_REQUIRES(ctx, is_initialized,
errors::InvalidArgument("TensorList is not initialized"));
// Only non-nested TensorList is supported for now.
bool is_nested;
OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
OP_REQUIRES(ctx, !is_nested,
errors::Unimplemented("Only non-nested TensorList is supported "
"for TensorListGetItem."));
xla::XlaOp list = ctx->Input(0);
xla::XlaOp index = ctx->Input(1);
xla::XlaOp result;
OP_REQUIRES_OK(ctx, ExecuteTensorListGetItem(list, index, &result));
ctx->SetOutput(0, result);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(TensorListGetItemOp);
};
REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp);
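// Lowers TensorListGather: stacks the elements selected by the rank-1 indices
// input into a single tensor, gathering along the list's leading dimension.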
class TensorListGatherOp : public XlaOpKernel {
public:
explicit TensorListGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
// Check that the TensorList is initialized.
bool is_initialized;
OP_REQUIRES_OK(ctx,
(IsTensorListInitialized(ctx->Input(0), &is_initialized)));
OP_REQUIRES(ctx, is_initialized,
errors::InvalidArgument("TensorList is not initialized"));
// Only non-nested TensorList is supported for now.
bool is_nested;
OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
OP_REQUIRES(ctx, !is_nested,
errors::Unimplemented("Only non-nested TensorList is supported "
"for TensorListGather."));
DataType indices_type = ctx->input_type(1);
const TensorShape indices_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, indices_shape.dims() == 1,
errors::InvalidArgument("indices must be rank 1"));
xla::XlaOp list = ctx->Input(0);
xla::XlaOp indices = ctx->Input(1);
xla::XlaOp buffer;
OP_REQUIRES_OK(ctx, GetTensorListBuffer(list, &buffer));
xla::Shape buffer_xla_shape;
OP_REQUIRES_OK(ctx, GetTensorListBufferShape(list, &buffer_xla_shape));
TensorShape buffer_shape;
OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(buffer_xla_shape, &buffer_shape));
xla::XlaOp result;
OP_REQUIRES_OK(
ctx, XlaGather(buffer, buffer_shape, indices, indices_shape, /*axis=*/0,
/*indices_are_nd=*/false, dtype_, indices_type,
ctx->builder(), &result));
ctx->SetOutput(0, result);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(TensorListGatherOp);
};
REGISTER_XLA_OP(Name("TensorListGather"), TensorListGatherOp);
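// Lowers TensorListStack: returns the list's backing buffer directly, since
// it already holds all elements stacked along the leading dimension.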
class TensorListStackOp : public XlaOpKernel {
public:
explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
// Check that the TensorList is initialized.
bool is_initialized;
OP_REQUIRES_OK(ctx,
(IsTensorListInitialized(ctx->Input(0), &is_initialized)));
OP_REQUIRES(ctx, is_initialized,
errors::InvalidArgument("TensorList is not initialized"));
// Only non-nested TensorList is supported for now.
bool is_nested;
OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested));
    OP_REQUIRES(ctx, !is_nested,
                errors::Unimplemented("Only non-nested TensorList is supported "
                                      "for TensorListStack."));
xla::XlaOp buffer;
OP_REQUIRES_OK(ctx, GetTensorListBuffer(ctx->Input(0), &buffer));
ctx->SetOutput(0, buffer);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(TensorListStackOp);
};
REGISTER_XLA_OP(Name("TensorListStack"), TensorListStackOp);
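// Lowers TensorListConcatV2: concatenates all elements along their first
// dimension by collapsing the two leading buffer dimensions. For example, a
// buffer of shape [3, 4, 5] (3 elements of shape [4, 5]) reshapes to [12, 5],
// and the lengths output is [4, 4, 4]. Elements with differing leading sizes
// are not representable here, since the backing buffer is dense.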
class TensorListConcatOp : public XlaOpKernel {
public:
explicit TensorListConcatOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp input = ctx->Input(0);
// Check that the TensorList is initialized.
bool is_initialized;
OP_REQUIRES_OK(ctx, (IsTensorListInitialized(input, &is_initialized)));
OP_REQUIRES(ctx, is_initialized,
errors::InvalidArgument("TensorList is not initialized"));
// Only non-nested TensorList is supported for now.
bool is_nested;
OP_REQUIRES_OK(ctx, IsNestedTensorList(input, &is_nested));
OP_REQUIRES(ctx, !is_nested,
errors::Unimplemented("Only non-nested TensorList is supported "
"for TensorListConcat."));
xla::XlaOp buffer;
OP_REQUIRES_OK(ctx, GetTensorListBuffer(input, &buffer));
xla::XlaBuilder* b = input.builder();
auto shape_or = b->GetShape(buffer);
OP_REQUIRES_OK(ctx, shape_or.status());
xla::Shape element_shape = shape_or.ConsumeValueOrDie();
std::vector<int64_t> element_dims =
xla::SpanToVector(element_shape.dimensions());
OP_REQUIRES(
ctx, element_dims.size() > 1,
errors::Unimplemented("TensorList of scalars is not supported"));
int64_t num_elements = element_dims[0];
int64_t tensor_lengths = element_dims[1];
std::vector<int64_t> new_dims = {num_elements * tensor_lengths};
for (int i = 2; i < element_dims.size(); i++) {
new_dims.push_back(element_dims[i]);
}
xla::XlaOp out = xla::Reshape(buffer, new_dims);
ctx->SetOutput(0, out);
// Second output is a tensor of lengths of returned tensors.
xla::XlaOp lengths = xla::ConstantR1(b, num_elements, tensor_lengths);
ctx->SetOutput(1, lengths);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(TensorListConcatOp);
};
REGISTER_XLA_OP(Name("TensorListConcatV2"), TensorListConcatOp);
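// Lowers TensorListSplit, the inverse of TensorListConcatV2: splits the input
// along its first dimension into equal-length chunks. For example, a [6, 5]
// tensor with lengths=[2, 2, 2] reshapes to [3, 2, 5] and becomes a list of
// three [2, 5] elements. Only uniform lengths are supported.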
class TensorListSplitOp : public XlaOpKernel {
public:
explicit TensorListSplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
    // Only non-nested TensorList is supported for now.
    OP_REQUIRES(
        ctx, dtype_ != DT_VARIANT,
        errors::Unimplemented(
            "Only non-nested TensorList is supported for TensorListSplit."));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp input_tensor = ctx->Input(0);
xla::XlaBuilder* b = input_tensor.builder();
auto shape_or = b->GetShape(input_tensor);
OP_REQUIRES_OK(ctx, shape_or.status());
xla::Shape element_shape = shape_or.ConsumeValueOrDie();
std::vector<int64_t> element_dims =
xla::SpanToVector(element_shape.dimensions());
OP_REQUIRES(
ctx, !element_dims.empty(),
errors::Unimplemented("Element dimensions have to be non-empty"));
std::vector<int64_t> lengths;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &lengths));
OP_REQUIRES(ctx, !lengths.empty(),
errors::Unimplemented("Length has to be non-empty"));
int64_t length = lengths[0];
for (int64_t len : lengths) {
OP_REQUIRES(ctx, len == length,
errors::Unimplemented("All lengths have to be the same"));
}
OP_REQUIRES(
ctx, element_dims[0] % length == 0,
errors::Unimplemented("Buffer size has to be a multiple of length"));
std::vector<int64_t> new_dims = {element_dims[0] / length, length};
for (int i = 1; i < element_dims.size(); i++) {
new_dims.push_back(element_dims[i]);
}
xla::XlaOp reshaped = xla::Reshape(input_tensor, new_dims);
xla::XlaOp result;
OP_REQUIRES_OK(ctx, ExecuteTensorListFromTensor(length, reshaped, &result));
ctx->SetTensorListOutput(0, result);
}
private:
DataType dtype_;
TF_DISALLOW_COPY_AND_ASSIGN(TensorListSplitOp);
};
REGISTER_XLA_OP(Name("TensorListSplit")
.CompileTimeConstantInput("element_shape")
.CompileTimeConstantInput("lengths"),
TensorListSplitOp);
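// Lowers TensorListFromTensor: wraps a tensor as a list whose elements are
// the slices of the tensor along its first dimension.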
class TensorListFromTensorOp : public XlaOpKernel {
public:
explicit TensorListFromTensorOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape& tensor_shape = ctx->InputShape(0);
int num_elements = tensor_shape.dim_size(0);
const xla::XlaOp tensor = ctx->Input(0);
xla::XlaOp result;
OP_REQUIRES_OK(ctx,
ExecuteTensorListFromTensor(num_elements, tensor, &result));
ctx->SetTensorListOutput(0, result);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(TensorListFromTensorOp);
};
REGISTER_XLA_OP(
Name("TensorListFromTensor").CompileTimeConstantInput("element_shape"),
TensorListFromTensorOp);
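// Lowers TensorListSetItem: writes element at position index, initializing
// the list's buffer from the element's shape first if the list is still
// uninitialized.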
class TensorListSetItemOp : public XlaOpKernel {
public:
explicit TensorListSetItemOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp list = ctx->Input(0);
xla::XlaOp index = ctx->Input(1);
xla::XlaOp element = ctx->Input(2);
xla::XlaOp initialized_list;
OP_REQUIRES_OK(ctx, GetInitializedTensorListForElement(
list, element, /*element_is_tensor_list=*/false,
&initialized_list));
// Only non-nested TensorList is supported for now.
bool is_nested;
OP_REQUIRES_OK(ctx, IsNestedTensorList(initialized_list, &is_nested));
OP_REQUIRES(ctx, !is_nested,
errors::Unimplemented("Only non-nested TensorList is supported "
"for TensorListSetItem."));
xla::XlaOp result;
OP_REQUIRES_OK(ctx, ExecuteTensorListSetItem(initialized_list, index,
element, &result));
ctx->SetTensorListOutput(0, result);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(TensorListSetItemOp);
};
REGISTER_XLA_OP(Name("TensorListSetItem"), TensorListSetItemOp);
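// Lowers TensorListPushBack: appends element (a tensor, or a TensorList for
// nested lists) at the current push index and advances the index.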
class TensorListPushBackOp : public XlaOpKernel {
public:
explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp list = ctx->Input(0);
xla::XlaOp element = ctx->Input(1);
bool element_is_tensor_list = IsTensorListInput(ctx, 1);
xla::XlaOp initialized_list;
OP_REQUIRES_OK(
ctx, GetInitializedTensorListForElement(
list, element, element_is_tensor_list, &initialized_list));
xla::XlaOp result;
OP_REQUIRES_OK(ctx,
ExecuteTensorListPushBack(initialized_list, element,
element_is_tensor_list, &result));
ctx->SetTensorListOutput(0, result);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp);
};
REGISTER_XLA_OP(Name("TensorListPushBack").AllowVariantTypes(),
TensorListPushBackOp);
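// Lowers TensorListPopBack: returns both the updated list and the removed
// element; the element is itself a TensorList when the input is nested.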
class TensorListPopBackOp : public XlaOpKernel {
public:
explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
// Check that the TensorList is initialized.
bool is_initialized;
OP_REQUIRES_OK(ctx,
(IsTensorListInitialized(ctx->Input(0), &is_initialized)));
OP_REQUIRES(ctx, is_initialized,
errors::InvalidArgument("TensorList is not initialized"));
xla::XlaOp list = ctx->Input(0);
xla::XlaOp list_result, element_result;
bool element_is_tensor_list;
OP_REQUIRES_OK(ctx,
ExecuteTensorListPopBack(list, &list_result, &element_result,
&element_is_tensor_list));
ctx->SetTensorListOutput(0, list_result);
if (element_is_tensor_list) {
ctx->SetTensorListOutput(1, element_result);
} else {
ctx->SetOutput(1, element_result);
}
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp);
};
REGISTER_XLA_OP(Name("TensorListPopBack").AllowVariantTypes(),
TensorListPopBackOp);
} // anonymous namespace
} // namespace tensorflow