| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <limits> |
| |
| #define EIGEN_USE_THREADS |
| #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
| #define EIGEN_USE_GPU |
| |
| #include "tensorflow/core/kernels/list_kernels.h" |
| |
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/register_types.h" |
| #include "tensorflow/core/framework/tensor_types.h" |
| #include "tensorflow/core/framework/variant.h" |
| #include "tensorflow/core/framework/variant_op_registry.h" |
| #include "tensorflow/core/kernels/concat_lib.h" |
| #include "tensorflow/core/lib/core/coding.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/util/util.h" |
| |
| namespace tensorflow { |
| |
| typedef Eigen::GpuDevice GPUDevice; |
| |
| #define REGISTER_TENSOR_LIST_OPS_GPU(T) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListStack") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("element_shape"), \ |
| TensorListStack<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListGather") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("indices") \ |
| .HostMemory("element_shape"), \ |
| TensorListGather<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListGetItem") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("index") \ |
| .HostMemory("element_shape"), \ |
| TensorListGetItem<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListPopBack") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("element_shape"), \ |
| TensorListPopBack<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListConcat") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("lengths"), \ |
| TensorListConcat<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("leading_dims") \ |
| .HostMemory("element_shape") \ |
| .HostMemory("lengths"), \ |
| TensorListConcat<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU), \ |
| TensorListPushBackBatch<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("element_shape"), \ |
| TensorListFromTensor<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("element_shape") \ |
| .HostMemory("indices"), \ |
| TensorListScatter<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListScatterV2") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("element_shape") \ |
| .HostMemory("num_elements") \ |
| .HostMemory("indices"), \ |
| TensorListScatter<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListScatterIntoExistingList") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("indices"), \ |
| TensorListScatterIntoExistingList<GPUDevice, T>) \ |
| REGISTER_KERNEL_BUILDER(Name("TensorListSplit") \ |
| .TypeConstraint<T>("element_dtype") \ |
| .Device(DEVICE_GPU) \ |
| .HostMemory("element_shape") \ |
| .HostMemory("lengths"), \ |
| TensorListSplit<GPUDevice, T>) |
| |
| TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_OPS_GPU); |
| REGISTER_TENSOR_LIST_OPS_GPU(bfloat16); |
| TF_CALL_complex64(REGISTER_TENSOR_LIST_OPS_GPU); |
| TF_CALL_complex128(REGISTER_TENSOR_LIST_OPS_GPU); |
| TF_CALL_int32(REGISTER_TENSOR_LIST_OPS_GPU); |
| TF_CALL_int64(REGISTER_TENSOR_LIST_OPS_GPU); |
| REGISTER_TENSOR_LIST_OPS_GPU(bool); |
| |
| #undef REGISTER_TENSOR_LIST_OPS_GPU |
| |
| REGISTER_KERNEL_BUILDER(Name("TensorListPopBack") |
| .TypeConstraint<Variant>("element_dtype") |
| .Device(DEVICE_GPU) |
| .HostMemory("element_shape"), |
| TensorListPopBack<GPUDevice, Variant>) |
| |
| REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, |
| TensorList, |
| TensorListBinaryAdd<GPUDevice>); |
| REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, |
| DEVICE_GPU, TensorList, |
| TensorListZerosLike<GPUDevice>); |
| |
| } // namespace tensorflow |
| #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |