/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>

#include "tensorflow/core/framework/allocator.h"

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/list_kernels.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

TensorList::~TensorList() {
  if (tensors_) tensors_->Unref();
}

void TensorList::Encode(VariantTensorData* data) const {
  data->set_type_name(TypeName());
  std::vector<size_t> invalid_indices;
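  // Slots holding DT_INVALID act as placeholders for unset elements (e.g.
  // those created by TensorListReserve). They carry no data, so only their
  // indices are serialized; Decode uses them to rebuild the placeholders.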
  for (size_t i = 0; i < tensors().size(); i++) {
    if (tensors().at(i).dtype() != DT_INVALID) {
      *data->add_tensors() = tensors().at(i);
    } else {
      invalid_indices.push_back(i);
    }
  }
  string metadata;
  // TODO(b/118838800): Add a proto for storing the metadata.
  // Metadata format:
  // <num_invalid_tensors><invalid_indices><element_dtype><max_num_elements>
  // <element_shape_proto>
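  // For example (illustrative only), a four-element list in which indices 1
  // and 3 are DT_INVALID serializes its two valid tensors into `data` and
  // encodes the metadata varints as
  //   2, 1, 3, <element_dtype>, <max_num_elements>
  // followed by the serialized TensorShapeProto of `element_shape`.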
  core::PutVarint64(&metadata, static_cast<uint64>(invalid_indices.size()));
  for (size_t i : invalid_indices) {
    core::PutVarint64(&metadata, static_cast<uint64>(i));
  }
  core::PutVarint64(&metadata, static_cast<uint64>(element_dtype));
  core::PutVarint64(&metadata, static_cast<uint64>(max_num_elements));
  TensorShapeProto element_shape_proto;
  element_shape.AsProto(&element_shape_proto);
  element_shape_proto.AppendToString(&metadata);
  data->set_metadata(metadata);
}

static Status TensorListDeviceCopy(
    const TensorList& from, TensorList* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  to->element_shape = from.element_shape;
  to->element_dtype = from.element_dtype;
  to->max_num_elements = from.max_num_elements;
  to->tensors().reserve(from.tensors().size());
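  // Copy elements one by one. DT_INVALID placeholders own no buffer, so they
  // are recreated directly instead of being scheduled for a device copy.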
  for (const Tensor& t : from.tensors()) {
    to->tensors().emplace_back(t.dtype());
    if (t.dtype() != DT_INVALID) {
      TF_RETURN_IF_ERROR(copy(t, &to->tensors().back()));
    }
  }
  return Status::OK();
}

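// Register the device-copy function for all three copy directions so that
// TensorList variants can be moved between host and device memory (and
// between devices).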
#define REGISTER_LIST_COPY(DIRECTION)                                         \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
                                                       TensorListDeviceCopy)

REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(TensorList, TensorList::kTypeName);

bool TensorList::Decode(const VariantTensorData& data) {
  // TODO(srbs): Change the signature to Decode(VariantTensorData data) so
  // that we do not have to copy each tensor individually below. This would
  // require changing VariantTensorData::tensors() as well.
  string metadata;
  data.get_metadata(&metadata);
  uint64 scratch;
  StringPiece iter(metadata);
  std::vector<size_t> invalid_indices;
  core::GetVarint64(&iter, &scratch);
  size_t num_invalid_tensors = static_cast<size_t>(scratch);
  invalid_indices.resize(num_invalid_tensors);
  for (size_t i = 0; i < num_invalid_tensors; i++) {
    core::GetVarint64(&iter, &scratch);
    invalid_indices[i] = static_cast<size_t>(scratch);
  }

  size_t total_num_tensors = data.tensors().size() + num_invalid_tensors;
  tensors().reserve(total_num_tensors);
  std::vector<size_t>::iterator invalid_indices_it = invalid_indices.begin();
  std::vector<Tensor>::const_iterator tensors_it = data.tensors().begin();
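  // Re-interleave the serialized tensors with DT_INVALID placeholders: at
  // each position, emit a placeholder if the metadata marked the index as
  // invalid, otherwise take the next serialized tensor in order.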
  for (size_t i = 0; i < total_num_tensors; i++) {
    if (invalid_indices_it != invalid_indices.end() &&
        *invalid_indices_it == i) {
      tensors().emplace_back(Tensor(DT_INVALID));
      invalid_indices_it++;
    } else if (tensors_it != data.tensors().end()) {
      tensors().emplace_back(*tensors_it);
      tensors_it++;
    } else {
      // VariantTensorData is corrupted.
      return false;
    }
  }

  core::GetVarint64(&iter, &scratch);
  element_dtype = static_cast<DataType>(scratch);
  core::GetVarint64(&iter, &scratch);
  max_num_elements = static_cast<int>(scratch);
  TensorShapeProto element_shape_proto;
  element_shape_proto.ParseFromString(string(iter.data(), iter.size()));
  element_shape = PartialTensorShape(element_shape_proto);
  return true;
}

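// Builds a PartialTensorShape from the shape tensor `t`. `t` must either be
// a scalar -1, denoting a shape of unknown rank, or an int32/int64 vector of
// dimension sizes in which -1 marks an individual unknown dimension. For
// example, the vector [2, -1] yields the partial shape [2, ?].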
Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
  if (t.shape() == TensorShape({})) {
    if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
        (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1)) {
      *out = PartialTensorShape();
      return Status::OK();
    }
    return errors::InvalidArgument(
        "The only valid scalar shape tensor is the fully unknown shape "
        "specified as -1.");
  }
  if (t.dtype() == DT_INT32) {
    return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
                                                t.NumElements(), out);
  } else if (t.dtype() == DT_INT64) {
    return PartialTensorShape::MakePartialShape(t.vec<int64>().data(),
                                                t.NumElements(), out);
  }
  return errors::InvalidArgument(
      "Expected an int32 or int64 shape tensor; found ",
      DataTypeString(t.dtype()));
}

Status GetElementShapeFromInput(OpKernelContext* c,
                                const TensorList& tensor_list, int index,
                                PartialTensorShape* element_shape) {
  TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(index), element_shape));
  // Check that `element_shape` and `tensor_list.element_shape` are
  // compatible and store the merged shape in `element_shape`.
  PartialTensorShape tmp = *element_shape;
  TF_RETURN_IF_ERROR(tmp.MergeWith(tensor_list.element_shape, element_shape));
  return Status::OK();
}

Status GetInputList(OpKernelContext* c, int index, const TensorList** list) {
  if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
    return errors::InvalidArgument("Input list must be a scalar. Saw: ",
                                   c->input(index).shape().DebugString());
  }
  const TensorList* l = c->input(index).scalar<Variant>()().get<TensorList>();
  if (l == nullptr) {
    return errors::InvalidArgument(
        "Input handle is not a list. Saw: '",
        c->input(index).scalar<Variant>()().DebugString(), "'");
  }
  *list = l;
  return Status::OK();
}

Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
                                   int32 output_index,
                                   const TensorList& input_list,
                                   TensorList** output_list) {
  // Attempt to forward the input tensor to the output if possible.
  std::unique_ptr<Tensor> maybe_output = c->forward_input(
      input_index, output_index, DT_VARIANT, TensorShape{},
      c->input_memory_type(input_index), AllocatorAttributes());
  Tensor* output_tensor;
  if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
      maybe_output->NumElements() == 1) {
    output_tensor = maybe_output.get();
    TensorList* tmp_out = output_tensor->scalar<Variant>()().get<TensorList>();
    if (tmp_out == nullptr) {
      return errors::InvalidArgument(
          "Expected input ", input_index, " to be a TensorList but saw ",
          output_tensor->scalar<Variant>()().TypeName());
    }
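    // Mutating the list in place is only safe when this kernel holds the
    // sole reference; a shared list could be observed through another copy
    // of the variant, so fall through to a deep copy in that case.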
    if (tmp_out->RefCountIsOne()) {
      // Woohoo, forwarding succeeded!
      c->set_output(output_index, *output_tensor);
      *output_list = tmp_out;
      return Status::OK();
    }
  }

  // If forwarding is not possible allocate a new output tensor and copy
  // the `input_list` to it.
  AllocatorAttributes attr;
  attr.set_on_host(true);
  TF_RETURN_IF_ERROR(
      c->allocate_output(output_index, {}, &output_tensor, attr));
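  // TensorList::Copy is shallow in the sense that the element tensors of the
  // new list share their refcounted buffers with the originals, so this
  // copies list metadata rather than tensor data.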
  output_tensor->scalar<Variant>()() = input_list.Copy();

  *output_list = output_tensor->scalar<Variant>()().get<TensorList>();
  return Status::OK();
}

class EmptyTensorList : public OpKernel {
 public:
  explicit EmptyTensorList(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& max_num_elements_t = ctx->input(1);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(max_num_elements_t.shape()),
        errors::InvalidArgument(
            "max_num_elements expected to be a scalar ",
            "but got shape: ", max_num_elements_t.shape().DebugString()));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
    TensorList empty;
    empty.element_dtype = element_dtype_;
    empty.max_num_elements = max_num_elements_t.scalar<int32>()();
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(ctx, TensorShapeFromTensor(ctx->input(0), &element_shape));
    empty.element_shape = element_shape;
    result->scalar<Variant>()() = std::move(empty);
  }

 private:
  DataType element_dtype_;
};

const char TensorList::kTypeName[] = "tensorflow::TensorList";

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList").Device(DEVICE_CPU),
                        EmptyTensorList);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("EmptyTensorList")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("max_num_elements"),
                        EmptyTensorList);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class TensorListPushBack : public OpKernel {
 public:
  explicit TensorListPushBack(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  ~TensorListPushBack() override {}

  void Compute(OpKernelContext* c) override {
    const Tensor& input = c->input(1);
    OP_REQUIRES(c, element_dtype_ == input.dtype(),
                errors::InvalidArgument("Invalid data types; list elements ",
                                        DataTypeString(element_dtype_),
                                        " but tried to append ",
                                        DataTypeString(input.dtype())));

    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(input.shape()),
                errors::InvalidArgument(
                    "Tried to append a tensor with incompatible shape to a "
                    "list. Op element shape: ",
                    input.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));

    if (l->max_num_elements != -1) {
      OP_REQUIRES(
          c, l->tensors().size() < l->max_num_elements,
          errors::InvalidArgument("Tried to push item into a full list.",
                                  " List size: ", l->tensors().size(),
                                  " max_num_elements: ", l->max_num_elements));
    }

    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors().push_back(input);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_CPU),
                        TensorListPushBack);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListPushBack").Device(DEVICE_GPU),
                        TensorListPushBack);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class TensorListLength : public OpKernel {
 public:
  explicit TensorListLength(OpKernelConstruction* c) : OpKernel(c) {}
  ~TensorListLength() override {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
    result->scalar<int32>()() = l->tensors().size();
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListLength").Device(DEVICE_CPU),
                        TensorListLength);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListLength").Device(DEVICE_GPU).HostMemory("length"),
    TensorListLength);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class TensorListElementShape : public OpKernel {
 public:
  explicit TensorListElementShape(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    Tensor* result;
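    // Mirror the convention of TensorShapeFromTensor: report a fully unknown
    // shape as a scalar -1, and otherwise emit a vector of dimension sizes
    // in which -1 marks each unknown dimension.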
    if (l->element_shape.unknown_rank()) {
      OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &result));
      if (result->dtype() == DT_INT32) {
        result->scalar<int32>()() = -1;
      } else {
        result->scalar<int64>()() = -1;
      }
    } else {
      OP_REQUIRES_OK(c, c->allocate_output(
                            0, TensorShape{l->element_shape.dims()}, &result));
      for (int i = 0; i < l->element_shape.dims(); ++i) {
        if (result->dtype() == DT_INT32) {
          result->flat<int32>()(i) = l->element_shape.dim_size(i);
        } else {
          result->flat<int64>()(i) = l->element_shape.dim_size(i);
        }
      }
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape").Device(DEVICE_CPU),
                        TensorListElementShape);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListElementShape")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape"),
                        TensorListElementShape);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class TensorListReserve : public OpKernel {
 public:
  explicit TensorListReserve(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    PartialTensorShape element_shape;
    OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(0), &element_shape));
    int32 num_elements = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, num_elements >= 0,
                errors::InvalidArgument("The num_elements to reserve must be "
                                        "a non-negative number, but got ",
                                        num_elements));
    TensorList output;
    output.element_shape = element_shape;
    output.element_dtype = element_dtype_;
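    // Pre-size the list with DT_INVALID placeholders; the slots are filled
    // in later (e.g. via TensorListSetItem or TensorListScatter).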
    output.tensors().resize(num_elements, Tensor(DT_INVALID));
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    result->scalar<Variant>()() = std::move(output);
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListReserve").Device(DEVICE_CPU),
                        TensorListReserve);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListReserve")
                            .Device(DEVICE_GPU)
                            .HostMemory("element_shape")
                            .HostMemory("num_elements"),
                        TensorListReserve);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class TensorListResize : public OpKernel {
 public:
  explicit TensorListResize(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) override {
    const TensorList* input_list = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
    int32 size = c->input(1).scalar<int32>()();
    OP_REQUIRES(
        c, size >= 0,
        errors::InvalidArgument(
            "TensorListResize expects size to be non-negative. Got: ", size));

    std::unique_ptr<Tensor> maybe_result =
        c->forward_input(0, 0, DT_VARIANT, TensorShape{},
                         c->input_memory_type(0), AllocatorAttributes());
    if (maybe_result != nullptr) {
      TensorList* out = maybe_result->scalar<Variant>()().get<TensorList>();
      // Defensive check: forwarding should hand back the input list that
      // GetInputList already validated, but guard the dereference anyway.
      OP_REQUIRES(c, out != nullptr,
                  errors::InvalidArgument(
                      "Expected the forwarded input to contain a TensorList "
                      "but saw: ",
                      maybe_result->scalar<Variant>()().TypeName()));
      if (out->RefCountIsOne()) {
        // We are able to forward the input.
        out->tensors().resize(size, Tensor(DT_INVALID));
        c->set_output(0, *maybe_result);
        return;
      }
    }

    // We were not able to forward the input. Will have to resize from scratch.
    Tensor* result;
    AllocatorAttributes attr;
    attr.set_on_host(true);
    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
    TensorList output_list;
    output_list.element_shape = input_list->element_shape;
    output_list.element_dtype = input_list->element_dtype;
    output_list.max_num_elements = input_list->max_num_elements;
    if (size > input_list->tensors().size()) {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().end());
      // Add DT_INVALID tensors to the end of the list if the requested size
      // is larger than the list length.
      output_list.tensors().resize(size, Tensor(DT_INVALID));
    } else {
      output_list.tensors().insert(output_list.tensors().begin(),
                                   input_list->tensors().begin(),
                                   input_list->tensors().begin() + size);
    }
    result->scalar<Variant>()() = std::move(output_list);
  }
};

REGISTER_KERNEL_BUILDER(Name("TensorListResize").Device(DEVICE_CPU),
                        TensorListResize);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(
    Name("TensorListResize").Device(DEVICE_GPU).HostMemory("size"),
    TensorListResize);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class TensorListSetItem : public OpKernel {
 public:
  explicit TensorListSetItem(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorList* l = nullptr;
    OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
    OP_REQUIRES(c, element_dtype_ == l->element_dtype,
                errors::InvalidArgument("Invalid data types; op elements ",
                                        DataTypeString(element_dtype_),
                                        " but list elements ",
                                        DataTypeString(l->element_dtype)));
    int32 index = c->input(1).scalar<int32>()();
    OP_REQUIRES(c, 0 <= index && index < l->tensors().size(),
                errors::InvalidArgument("Trying to modify element ", index,
                                        " in a list with ",
                                        l->tensors().size(), " elements."));
    const Tensor& value = c->input(2);
    OP_REQUIRES(c, l->element_shape.IsCompatibleWith(value.shape()),
                errors::InvalidArgument(
                    "Tried to set a tensor with incompatible shape at a "
                    "list index. Item element shape: ",
                    value.shape().DebugString(),
                    " list shape: ", l->element_shape.DebugString()));
    TensorList* output_list = nullptr;
    OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
    output_list->tensors()[index] = value;
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListSetItem").Device(DEVICE_CPU),
                        TensorListSetItem);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_SET_ITEM_GPU(T)                      \
  REGISTER_KERNEL_BUILDER(Name("TensorListSetItem")               \
                              .TypeConstraint<T>("element_dtype") \
                              .Device(DEVICE_GPU)                 \
                              .HostMemory("index"),               \
                          TensorListSetItem);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_complex64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_complex128(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int32(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
TF_CALL_int64(REGISTER_TENSOR_LIST_SET_ITEM_GPU);
REGISTER_TENSOR_LIST_SET_ITEM_GPU(bfloat16)
#undef REGISTER_TENSOR_LIST_SET_ITEM_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class TensorListConcatLists : public OpKernel {
 public:
  explicit TensorListConcatLists(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
  }

  void Compute(OpKernelContext* c) override {
    const TensorShape& tl_a_shape = c->input(0).shape();
    const TensorShape& tl_b_shape = c->input(1).shape();
    OP_REQUIRES(
        c, tl_a_shape == tl_b_shape,
        errors::InvalidArgument("Incompatible input TensorList tensor shapes: ",
                                tl_a_shape.DebugString(), " vs. ",
                                tl_b_shape.DebugString()));
    AllocatorAttributes attr;
    std::unique_ptr<Tensor> tl_alias = c->forward_input(
        0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tl_a_shape,
        DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);

    // tl_a may be aliased by tl_alias.
    const Tensor& tl_a = c->input(0);
    const Tensor& tl_b = c->input(1);

    Tensor* output = nullptr;
    bool ok_to_alias = tl_alias != nullptr;
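    // Aliasing the input buffer is only safe if every batch entry holds a
    // TensorList that this kernel uniquely owns; appending in place to a
    // shared list would be visible through other copies of the variant.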
    if (tl_alias && tl_alias->dtype() == DT_VARIANT &&
        tl_alias->NumElements() > 0) {
      auto tl_a_t = tl_alias->flat<Variant>();
      for (int64 i = 0; i < tl_alias->NumElements(); ++i) {
        TensorList* aliased = tl_a_t(i).get<TensorList>();
        if (aliased == nullptr || !aliased->RefCountIsOne()) {
          ok_to_alias = false;
          break;
        }
      }
      if (ok_to_alias) {
        c->set_output(0, *tl_alias);
        output = tl_alias.get();
      }
    }
    if (!ok_to_alias) {
      // Couldn't alias the entire Tensor. We'll be conservative and not try
      // to alias individual batch entries.
      attr.set_on_host(true);
      OP_REQUIRES_OK(c, c->allocate_output(0, tl_a_shape, &output, attr));
    }

    auto output_t = output->flat<Variant>();
    auto tl_a_t = tl_a.flat<Variant>();
    auto tl_b_t = tl_b.flat<Variant>();

    for (int64 i = 0; i < tl_a.NumElements(); ++i) {
      const TensorList* l_a = tl_a_t(i).get<TensorList>();
      const TensorList* l_b = tl_b_t(i).get<TensorList>();
      OP_REQUIRES(
          c, l_a != nullptr,
          errors::InvalidArgument("input_a is not a TensorList at index ", i,
                                  ". Saw: '", tl_a_t(i).DebugString(), "'"));
      OP_REQUIRES(
          c, l_b != nullptr,
          errors::InvalidArgument("input_b is not a TensorList at index ", i,
                                  ". Saw: '", tl_b_t(i).DebugString(), "'"));
      OP_REQUIRES(c, l_a->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_a[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_a->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_b->element_dtype == element_dtype_,
                  errors::InvalidArgument(
                      "input_b[", i, "].dtype != element_dtype. Saw: ",
                      DataTypeString(l_b->element_dtype), " vs. ",
                      DataTypeString(element_dtype_)));
      OP_REQUIRES(c, l_a->element_shape.IsIdenticalTo(l_b->element_shape),
                  errors::InvalidArgument(
                      "input_a and input_b TensorList element shapes are not "
                      "identical at index ",
                      i, ". Saw ", l_a->element_shape.DebugString(), " vs. ",
                      l_b->element_shape.DebugString()));
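      // When aliasing, output_t(i) already holds l_a's elements, so only
      // l_b's elements are appended; otherwise concatenate into a fresh
      // copy of l_a.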
      if (ok_to_alias) {
        TensorList* out = output_t(i).get<TensorList>();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out->tensors()));
      } else {
        TensorList out = l_a->Copy();
        std::copy(l_b->tensors().begin(), l_b->tensors().end(),
                  std::back_inserter(out.tensors()));
        output_t(i) = std::move(out);
      }
    }
  }

 private:
  DataType element_dtype_;
};

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_CPU),
                        TensorListConcatLists);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

REGISTER_KERNEL_BUILDER(Name("TensorListConcatLists").Device(DEVICE_GPU),
                        TensorListConcatLists);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_TENSOR_LIST_OPS_CPU(T)                                    \
  REGISTER_KERNEL_BUILDER(Name("TensorListStack")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListStack<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListGather")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListGather<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcat")                         \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2")                       \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListConcat<CPUDevice, T>)                  \
  REGISTER_KERNEL_BUILDER(Name("TensorListGetItem")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListGetItem<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListPopBack")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListPopBack<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor")                     \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListFromTensor<CPUDevice, T>)              \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatter")                        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2")                      \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatter<CPUDevice, T>)                 \
  REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList")        \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListScatterIntoExistingList<CPUDevice, T>) \
  REGISTER_KERNEL_BUILDER(Name("TensorListSplit")                          \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListSplit<CPUDevice, T>)                   \
  REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch")                  \
                              .TypeConstraint<T>("element_dtype")          \
                              .Device(DEVICE_CPU),                         \
                          TensorListPushBackBatch<CPUDevice, T>)

TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_OPS_CPU);
REGISTER_TENSOR_LIST_OPS_CPU(quint8);
REGISTER_TENSOR_LIST_OPS_CPU(qint8);
REGISTER_TENSOR_LIST_OPS_CPU(quint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint16);
REGISTER_TENSOR_LIST_OPS_CPU(qint32);
REGISTER_TENSOR_LIST_OPS_CPU(Variant);

#undef REGISTER_TENSOR_LIST_OPS_CPU

#define REGISTER_TENSOR_LIST_OPS_CPU(T)

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          TensorList,
                                          TensorListBinaryAdd<CPUDevice>);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, TensorList,
                                         TensorListZerosLike<CPUDevice>);

}  // namespace tensorflow