blob: c5d21ef8f5ac30f3e58c584d4c40f3bd0b1cc568 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace {
// Checks that `shapes_and_types` looks like TensorList handle data: exactly
// one ShapeAndType entry, whose dtype equals `element_dtype`.
//
// `c` is not used by the check; it is kept so the signature stays stable for
// existing callers.
//
// Returns InvalidArgument when the entry count is not 1 or the dtype differs,
// and OK otherwise.
Status VerifyHandleData(
    shape_inference::InferenceContext* c,
    const std::vector<shape_inference::ShapeAndType>& shapes_and_types,
    DataType element_dtype) {
  const auto num_entries = shapes_and_types.size();
  if (num_entries != 1) {
    return errors::InvalidArgument(
        "Invalid handle_data for input list. Expected length of "
        "shape_and_types: ",
        1, " Saw: ", num_entries);
  }
  const DataType list_dtype = shapes_and_types.front().dtype;
  if (list_dtype == element_dtype) {
    return Status::OK();
  }
  return errors::InvalidArgument("Expected list with element dtype ",
                                 DataTypeString(element_dtype),
                                 " but got list with element dtype ",
                                 DataTypeString(list_dtype));
}
// Returns the element shape stored in TensorList handle data.
// Precondition: `shapes_and_types` is non-empty (for example, it has already
// passed VerifyHandleData). The first entry is read without a bounds check.
shape_inference::ShapeHandle GetElementShapeFromHandleData(
    const std::vector<shape_inference::ShapeAndType>& shapes_and_types) {
  const shape_inference::ShapeAndType& first_entry = shapes_and_types.front();
  return first_entry.shape;
}
REGISTER_OP("EmptyTensorList")
    .Input("element_shape: shape_type")
    .Input("max_num_elements: int32")
    .Output("handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The list is a scalar variant handle. Its element shape and dtype are
      // attached to that handle as handle data.
      c->set_output(0, c->Scalar());
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &dtype));
      // A scalar `element_shape` tensor stands for an unknown-rank shape.
      shape_inference::ShapeHandle shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          0, &shape));
      const shape_inference::ShapeAndType handle_entry{shape, dtype};
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{handle_entry});
      return Status::OK();
    });
REGISTER_OP("TensorListPushBack")
    .Input("input_handle: variant")
    .Input("tensor: element_dtype")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Output 0 is a scalar variant handle. Its handle data carries
      // `element_dtype` and the input list's element shape, which stays
      // unknown when the input has no single-entry handle data.
      c->set_output(0, c->Scalar());
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &dtype));
      shape_inference::ShapeHandle result_shape = c->UnknownShape();
      const auto* list_data = c->input_handle_shapes_and_types(0);
      if (list_data != nullptr) {
        if (list_data->size() > 1) {
          return errors::InvalidArgument(
              "Trying to push to list with wrong variant data.");
        }
        if (list_data->size() == 1) {
          const shape_inference::ShapeAndType& entry = list_data->front();
          if (entry.dtype != dtype) {
            return errors::InvalidArgument(
                "Trying to push to list with wrong element dtype. List has type ",
                DataTypeString(entry.dtype),
                " but trying to push element with type ",
                DataTypeString(dtype));
          }
          // result_shape is still UnknownShape() here, so this Merge only
          // checks compatibility against an unknown shape; the merged handle
          // is discarded.
          shape_inference::ShapeHandle unused;
          TF_RETURN_IF_ERROR(c->Merge(result_shape, entry.shape, &unused));
          result_shape = entry.shape;
        }
      }
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{{result_shape, dtype}});
      return Status::OK();
    });
REGISTER_OP("TensorListPushBackBatch")
    .Input("input_handles: variant")
    .Input("tensor: element_dtype")
    .Output("output_handles: variant")
    .Attr("element_dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // `input_handles` must be a vector, and `tensor` must have rank >= 1.
      // The leading dimension of `tensor` is merged with the length of the
      // handle vector. The refined handle-vector shape becomes output 0.
      shape_inference::ShapeHandle handles_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handles_shape));
      shape_inference::ShapeHandle values_shape;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &values_shape));
      TF_RETURN_IF_ERROR(c->MergePrefix(values_shape, handles_shape,
                                        &values_shape, &handles_shape));
      c->set_output(0, handles_shape);

      // The output handle data uses the element shape from input 0's handle
      // data when it has exactly one entry. Otherwise the shape is unknown.
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &dtype));
      shape_inference::ShapeHandle result_shape = c->UnknownShape();
      const auto* list_data = c->input_handle_shapes_and_types(0);
      if (list_data != nullptr) {
        if (list_data->size() > 1) {
          return errors::InvalidArgument(
              "Trying to push to list with wrong variant data.");
        }
        if (list_data->size() == 1) {
          const shape_inference::ShapeAndType& entry = list_data->front();
          if (entry.dtype != dtype) {
            return errors::InvalidArgument(
                "Trying to push to list with wrong element dtype. List has type ",
                DataTypeString(entry.dtype),
                " but trying to push element with type ",
                DataTypeString(dtype));
          }
          // result_shape is still unknown here, so this is only a
          // compatibility check; the merged handle is discarded.
          shape_inference::ShapeHandle unused;
          TF_RETURN_IF_ERROR(c->Merge(result_shape, entry.shape, &unused));
          result_shape = entry.shape;
        }
      }
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{{result_shape, dtype}});
      return Status::OK();
    });
REGISTER_OP("TensorListLength")
.Input("input_handle: variant")
.Output("length: int32")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("TensorListPopBack")
    .Input("input_handle: variant")
    .Input("element_shape: int32")
    .Output("output_handle: variant")
    .Output("tensor: element_dtype")
    .Attr("element_dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Output 0 is the scalar list handle after the pop. Output 1 is the
      // popped element. Its shape is taken from the input list's handle data
      // when that data has exactly one entry, and is unknown otherwise.
      // NOTE(review): the `element_shape` input (index 1) is not consulted.
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle tensor_shape = c->UnknownShape();
      auto* handle_data = c->input_handle_shapes_and_types(0);
      if (handle_data != nullptr && handle_data->size() > 1) {
        return errors::InvalidArgument(
            "Trying to read from list with invalid variant data.");
      }
      if (handle_data != nullptr && handle_data->size() == 1) {
        const shape_inference::ShapeAndType& list_shape_type =
            (*handle_data)[0];
        if (list_shape_type.dtype != element_dtype) {
          // This op reads from the list, so the message reports the expected
          // dtype rather than talking about a push (matches TensorListStack).
          return errors::InvalidArgument(
              "Trying to read from list with wrong element dtype. List has "
              "type ",
              DataTypeString(list_shape_type.dtype), " but expected type ",
              DataTypeString(element_dtype));
        }
        shape_inference::ShapeHandle ignored;
        TF_RETURN_IF_ERROR(
            c->Merge(tensor_shape, list_shape_type.shape, &ignored));
        // Popping does not change the list's element shape/dtype, so the
        // handle data is forwarded as-is.
        c->set_output_handle_shapes_and_types(0, *handle_data);
        tensor_shape = list_shape_type.shape;
      }
      c->set_output(1, tensor_shape);
      c->set_output(0, c->Scalar());
      return Status::OK();
    });
REGISTER_OP("TensorListStack")
    .Input("input_handle: variant")
    .Input("element_shape: int32")
    .Output("tensor: element_dtype")
    .Attr("element_dtype: type")
    .Attr("num_elements: int = -1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Output shape is [num_elements] + element_shape. element_shape is the
      // merge of the list's handle data (when it has exactly one entry) and
      // the `element_shape` input. num_elements comes from the attr, where -1
      // means unknown.
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle element_shape = c->UnknownShape();
      auto* handle_data = c->input_handle_shapes_and_types(0);
      if (handle_data != nullptr && handle_data->size() > 1) {
        return errors::InvalidArgument(
            "Trying to read from list with wrong variant data.");
      }
      if (handle_data != nullptr && handle_data->size() == 1) {
        const shape_inference::ShapeAndType& list_shape_type =
            (*handle_data)[0];
        if (list_shape_type.dtype != element_dtype) {
          return errors::InvalidArgument(
              "Trying to read from list with wrong element dtype. List has "
              "type ",
              DataTypeString(list_shape_type.dtype), " but expected type ",
              DataTypeString(element_dtype));
        }
        shape_inference::ShapeHandle ignored;
        TF_RETURN_IF_ERROR(
            c->Merge(element_shape, list_shape_type.shape, &ignored));
        element_shape = list_shape_type.shape;
      }
      // A scalar `element_shape` input stands for an unknown-rank shape.
      shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          1, &element_shape_input));
      TF_RETURN_IF_ERROR(
          c->Merge(element_shape, element_shape_input, &element_shape));
      int expected_num_elements = -1;
      TF_RETURN_IF_ERROR(c->GetAttr("num_elements", &expected_num_elements));
      // Only -1 (unknown) or a non-negative count is meaningful. Any other
      // negative value would otherwise be passed to MakeShape as an invalid
      // dimension.
      if (expected_num_elements < -1) {
        return errors::InvalidArgument(
            "num_elements must be -1 or non-negative, but got ",
            expected_num_elements);
      }
      shape_inference::ShapeHandle num_elements;
      if (expected_num_elements == -1) {
        num_elements = c->MakeShape({c->UnknownDim()});
      } else {
        num_elements = c->MakeShape({expected_num_elements});
      }
      shape_inference::ShapeHandle result;
      TF_RETURN_IF_ERROR(c->Concatenate(num_elements, element_shape, &result));
      c->set_output(0, result);
      return Status::OK();
    });
// Shared shape function for TensorListConcat and TensorListConcatV2.
//
// `element_shape` is the caller-supplied element shape. It is refined with the
// shape recorded in input 0's handle data when that data has exactly one
// entry.
// Output 0 is [?] + element_shape[1:] when the rank is known, and an unknown
// shape otherwise. Output 1 is always a vector of unknown length.
Status TensorListConcatShapeInference(
    shape_inference::InferenceContext* c,
    shape_inference::ShapeHandle element_shape) {
  DataType element_dtype;
  TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
  const auto* list_data = c->input_handle_shapes_and_types(0);
  const bool has_list_data = list_data != nullptr;
  if (has_list_data && list_data->size() > 1) {
    return errors::InvalidArgument(
        "Trying to read from list with wrong variant data.");
  }
  if (has_list_data && list_data->size() == 1) {
    const shape_inference::ShapeAndType& entry = list_data->front();
    if (entry.dtype != element_dtype) {
      return errors::InvalidArgument(
          "Trying to read from list with wrong element dtype. List has "
          "type ",
          DataTypeString(entry.dtype), " but expected type ",
          DataTypeString(element_dtype));
    }
    TF_RETURN_IF_ERROR(c->Merge(element_shape, entry.shape, &element_shape));
  }

  shape_inference::ShapeHandle concatenated;
  if (!c->RankKnown(element_shape)) {
    concatenated = c->UnknownShape();
  } else {
    // Drop the leading dimension of the element shape and replace it with an
    // unknown one.
    shape_inference::ShapeHandle tail;
    TF_RETURN_IF_ERROR(c->Subshape(element_shape, 1, &tail));
    TF_RETURN_IF_ERROR(
        c->Concatenate(c->MakeShape({c->UnknownDim()}), tail, &concatenated));
  }
  c->set_output(0, concatenated);
  c->set_output(1, c->MakeShape({c->UnknownDim()}));
  return Status::OK();
}
REGISTER_OP("TensorListConcat")
    .Input("input_handle: variant")
    .Output("tensor: element_dtype")
    .Output("lengths: int64")
    .Attr("element_dtype: type")
    .Attr("element_shape: shape = { unknown_rank: true }")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The element shape comes from the `element_shape` attr, which has
      // unknown rank by default. It is converted to a ShapeHandle before
      // delegating to the shared concat shape function.
      PartialTensorShape attr_shape;
      TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &attr_shape));
      shape_inference::ShapeHandle shape_handle;
      TF_RETURN_IF_ERROR(
          c->MakeShapeFromPartialTensorShape(attr_shape, &shape_handle));
      return TensorListConcatShapeInference(c, shape_handle);
    });
REGISTER_OP("TensorListConcatV2")
    .Input("input_handle: variant")
    .Input("element_shape: shape_type")
    .Input("leading_dims: int64")
    .Output("tensor: element_dtype")
    .Output("lengths: int64")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Unlike TensorListConcat, the element shape here comes from input 1.
      // A scalar in that input stands for an unknown-rank shape.
      // `leading_dims` (input 2) is not used by shape inference.
      shape_inference::ShapeHandle shape_from_input;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          1, &shape_from_input));
      return TensorListConcatShapeInference(c, shape_from_input);
    });
REGISTER_OP("TensorListSplit")
    .Input("tensor: element_dtype")
    .Input("element_shape: shape_type")
    .Input("lengths: int64")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &dtype));
      const shape_inference::ShapeHandle input_shape = c->input(0);
      shape_inference::ShapeHandle unused;
      // `tensor` must have rank >= 1, and `lengths` must be exactly rank 1.
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(input_shape, 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      // Each piece is shaped [?] + tensor.shape[1:].
      shape_inference::ShapeHandle trailing_dims;
      TF_RETURN_IF_ERROR(c->Subshape(input_shape, 1, &trailing_dims));
      shape_inference::ShapeHandle piece_shape;
      TF_RETURN_IF_ERROR(c->Concatenate(c->MakeShape({c->UnknownDim()}),
                                        trailing_dims, &piece_shape));
      shape_inference::ShapeHandle declared_shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          1, &declared_shape));
      // This Merge is only a compatibility check; the merged shape is not
      // propagated. The handle data records the declared `element_shape`
      // input unchanged.
      TF_RETURN_IF_ERROR(c->Merge(piece_shape, declared_shape, &piece_shape));
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {declared_shape, dtype}});
      return Status::OK();
    });
REGISTER_OP("TensorListFromTensor")
    .Input("tensor: element_dtype")
    .Input("element_shape: shape_type")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Output 0 is the scalar list handle. Its handle data records the
      // `element_shape` input and `element_dtype`. The shape of `tensor`
      // without its leading dimension is only checked for compatibility
      // against that input.
      c->set_output(0, c->Scalar());
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      shape_inference::ShapeHandle tensor_shape = c->input(0);
      // Elements are taken from dimension 0 of `tensor`, so it must be at
      // least a vector. This mirrors the rank check in TensorListSplit.
      shape_inference::ShapeHandle ignored;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(tensor_shape, 1, &ignored));
      shape_inference::ShapeHandle tensor_shape_except_first_dim;
      TF_RETURN_IF_ERROR(
          c->Subshape(tensor_shape, 1, &tensor_shape_except_first_dim));
      // A scalar `element_shape` input stands for an unknown-rank shape.
      shape_inference::ShapeHandle element_shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          1, &element_shape));
      TF_RETURN_IF_ERROR(c->Merge(tensor_shape_except_first_dim, element_shape,
                                  &tensor_shape_except_first_dim));
      c->set_output_handle_shapes_and_types(
          0, std::vector<shape_inference::ShapeAndType>{
                 {element_shape, element_dtype}});
      return Status::OK();
    });
REGISTER_OP("TensorListElementShape")
    .Input("input_handle: variant")
    .Output("element_shape: shape_type")
    .Attr("shape_type: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The output is promised to be a 1-D tensor of length rank(element_shape)
      // only when that rank is known. For an unknown-rank list the kernel
      // presumably emits a scalar (-1) rather than a vector; NOTE(review):
      // confirm against the TensorListElementShape kernel. So the output shape
      // is left unknown in that case, and whenever handle data is unavailable.
      const auto* handle_data = c->input_handle_shapes_and_types(0);
      // Require exactly one entry, as the other list ops in this file do. This
      // also prevents reading element [0] of an empty vector.
      if (handle_data == nullptr || handle_data->size() != 1) {
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }
      const shape_inference::ShapeHandle element_shape =
          (*handle_data)[0].shape;
      if (!c->RankKnown(element_shape)) {
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }
      c->set_output(0, c->Vector(c->Rank(element_shape)));
      return Status::OK();
    });
REGISTER_OP("TensorListReserve")
    .Input("element_shape: shape_type")
    .Input("num_elements: int32")
    .Output("handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The reserved list is a scalar variant handle. Its handle data records
      // the requested element shape (a scalar input means unknown rank) and
      // `element_dtype`.
      c->set_output(0, c->Scalar());
      shape_inference::ShapeHandle requested_shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          0, &requested_shape));
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &dtype));
      std::vector<shape_inference::ShapeAndType> handle_entries;
      handle_entries.emplace_back(requested_shape, dtype);
      c->set_output_handle_shapes_and_types(0, handle_entries);
      return Status::OK();
    });
REGISTER_OP("TensorListGetItem")
    .Input("input_handle: variant")
    .Input("index: int32")
    .Input("element_shape: int32")
    .Output("item: element_dtype")
    .Attr("element_dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The item's shape is the list's element shape (from handle data, when
      // usable) merged with the `element_shape` input.
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      auto* handle_data = c->input_handle_shapes_and_types(0);
      shape_inference::ShapeHandle element_shape = c->UnknownShape();
      // Only consult handle data that has exactly one entry, as the other list
      // ops here do. Reading [0] unconditionally would go out of bounds on an
      // empty vector.
      if (handle_data != nullptr && handle_data->size() == 1) {
        const shape_inference::ShapeAndType& list_shape_type =
            (*handle_data)[0];
        element_shape = list_shape_type.shape;
        if (list_shape_type.dtype != element_dtype) {
          return errors::InvalidArgument("Expected list with element dtype ",
                                         DataTypeString(element_dtype),
                                         " but got list with element dtype ",
                                         DataTypeString(list_shape_type.dtype));
        }
      }
      // A scalar `element_shape` input stands for an unknown-rank shape.
      shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          2, &element_shape_input));
      TF_RETURN_IF_ERROR(
          c->Merge(element_shape, element_shape_input, &element_shape));
      c->set_output(0, element_shape);
      return Status::OK();
    });
REGISTER_OP("TensorListResize")
    .Input("input_handle: variant")
    .Input("size: int32")
    .Output("output_handle: variant")
    .SetIsStateful()
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // `size` (input 1) must be a scalar.
      shape_inference::ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      c->set_output(0, c->Scalar());
      // Any input handle data is forwarded unchanged to the output handle.
      if (const auto* forwarded = c->input_handle_shapes_and_types(0)) {
        c->set_output_handle_shapes_and_types(0, *forwarded);
      }
      return Status::OK();
    });
REGISTER_OP("TensorListSetItem")
    .Input("input_handle: variant")
    .Input("index: int32")
    .Input("item: element_dtype")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      auto* handle_data = c->input_handle_shapes_and_types(0);
      c->set_output(0, c->Scalar());
      // If handle data is missing, or does not have exactly one entry, fall
      // back to an unknown element shape. The size check also guards the [0]
      // access below against an empty vector.
      if (handle_data == nullptr || handle_data->size() != 1) {
        c->set_output_handle_shapes_and_types(
            0, {{c->UnknownShape(), element_dtype}});
        return Status::OK();
      }
      const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0];
      // The stored item must be shape-compatible with the list's elements.
      shape_inference::ShapeHandle item_shape = c->input(2);
      TF_RETURN_IF_ERROR(
          c->Merge(item_shape, list_shape_type.shape, &item_shape));
      // The list's handle data is forwarded unchanged.
      c->set_output_handle_shapes_and_types(0, *handle_data);
      return Status::OK();
    });
REGISTER_OP("TensorListGather")
    .Input("input_handle: variant")
    .Input("indices: int32")
    .Input("element_shape: int32")
    .Output("values: element_dtype")
    .Attr("element_dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Output shape is shape(indices) + element_shape. element_shape is the
      // list's element shape (from handle data, when usable) merged with the
      // `element_shape` input.
      DataType element_dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
      auto* handle_data = c->input_handle_shapes_and_types(0);
      shape_inference::ShapeHandle element_shape = c->UnknownShape();
      // Only consult handle data that has exactly one entry, as the other list
      // ops here do. Reading [0] unconditionally would go out of bounds on an
      // empty vector.
      if (handle_data != nullptr && handle_data->size() == 1) {
        const shape_inference::ShapeAndType& list_shape_type =
            (*handle_data)[0];
        element_shape = list_shape_type.shape;
        if (list_shape_type.dtype != element_dtype) {
          return errors::InvalidArgument("Expected list with element dtype ",
                                         DataTypeString(element_dtype),
                                         " but got list with element dtype ",
                                         DataTypeString(list_shape_type.dtype));
        }
      }
      // A scalar `element_shape` input stands for an unknown-rank shape.
      shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          2, &element_shape_input));
      TF_RETURN_IF_ERROR(
          c->Merge(element_shape, element_shape_input, &element_shape));
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out));
      c->set_output(0, out);
      return Status::OK();
    });
REGISTER_OP("TensorListScatter")
    .Input("tensor: element_dtype")
    .Input("indices: int32")
    .Input("element_shape: shape_type")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The new list's handle data is built only from `element_shape`
      // (input 2; a scalar means unknown rank) and `element_dtype`. The
      // shapes of `tensor` and `indices` are not inspected here.
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &dtype));
      shape_inference::ShapeHandle declared_shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          2, &declared_shape));
      const std::vector<shape_inference::ShapeAndType> list_entries{
          {declared_shape, dtype}};
      c->set_output_handle_shapes_and_types(0, list_entries);
      c->set_output(0, c->Scalar());
      return Status::OK();
    });
REGISTER_OP("TensorListScatterV2")
    .Input("tensor: element_dtype")
    .Input("indices: int32")
    .Input("element_shape: shape_type")
    .Input("num_elements: int32")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .Attr("shape_type: {int32, int64}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Same inference as TensorListScatter. `num_elements` (input 3) does
      // not affect the inferred shapes.
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &dtype));
      shape_inference::ShapeHandle declared_shape;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
          2, &declared_shape));
      std::vector<shape_inference::ShapeAndType> list_entries;
      list_entries.emplace_back(declared_shape, dtype);
      c->set_output_handle_shapes_and_types(0, list_entries);
      c->set_output(0, c->Scalar());
      return Status::OK();
    });
REGISTER_OP("TensorListScatterIntoExistingList")
    .Input("input_handle: variant")
    .Input("tensor: element_dtype")
    .Input("indices: int32")
    .Output("output_handle: variant")
    .Attr("element_dtype: type")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // `tensor` (input 1) must be at least a vector, and `indices` (input 2)
      // must be exactly a vector.
      shape_inference::ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &dtype));
      // Reuse the existing list's element shape when it has handle data,
      // after validating that data. Otherwise the element shape is unknown.
      const auto* existing = c->input_handle_shapes_and_types(0);
      shape_inference::ShapeHandle list_element_shape;
      if (existing == nullptr) {
        list_element_shape = c->UnknownShape();
      } else {
        TF_RETURN_IF_ERROR(VerifyHandleData(c, *existing, dtype));
        list_element_shape = GetElementShapeFromHandleData(*existing);
      }
      c->set_output_handle_shapes_and_types(0,
                                            {{list_element_shape, dtype}});
      c->set_output(0, c->Scalar());
      return Status::OK();
    });
REGISTER_OP("TensorListConcatLists")
    .Input("input_a: variant")
    .Input("input_b: variant")
    .Attr("element_dtype: type")
    .Output("output: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // The two handle tensors must have compatible shapes; the output takes
      // the merged shape.
      shape_inference::ShapeHandle merged_handles;
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &merged_handles));
      c->set_output(0, merged_handles);

      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &dtype));
      const auto* data_a = c->input_handle_shapes_and_types(0);
      const auto* data_b = c->input_handle_shapes_and_types(1);
      const bool a_known = data_a != nullptr && !data_a->empty();
      const bool b_known = data_b != nullptr && !data_b->empty();
      if (!a_known && !b_known) {
        c->set_output_handle_shapes_and_types(0,
                                              {{c->UnknownShape(), dtype}});
        return Status::OK();
      }
      // If only one side has handle data, it stands in for both sides.
      // entry_a is a copy because its shape is overwritten by the merge below.
      shape_inference::ShapeAndType entry_a =
          a_known ? data_a->at(0) : data_b->at(0);
      const shape_inference::ShapeAndType& entry_b =
          b_known ? data_b->at(0) : data_a->at(0);
      if (entry_a.dtype != dtype) {
        return errors::InvalidArgument("input_a.type != element_dtype: ",
                                       DataTypeString(entry_a.dtype),
                                       " vs. ", DataTypeString(dtype));
      }
      if (entry_b.dtype != dtype) {
        return errors::InvalidArgument("input_b.type != element_dtype: ",
                                       DataTypeString(entry_b.dtype),
                                       " vs. ", DataTypeString(dtype));
      }
      TF_RETURN_IF_ERROR(
          c->Merge(entry_a.shape, entry_b.shape, &entry_a.shape));
      c->set_output_handle_shapes_and_types(0, {entry_a});
      return Status::OK();
    });
} // namespace
} // namespace tensorflow