| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/util/batch_util.h" |
| |
| #include "tensorflow/core/framework/register_types.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| |
| #define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m) |
| |
| namespace tensorflow { |
| namespace batch_util { |
| |
| namespace { |
| |
| Status ValidateInput(const Tensor& parent, const Tensor& element, int64 index) { |
| DCHECK_NE(parent.dim_size(0), 0); |
| DCHECK_GE(index, 0); |
| if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) { |
| TensorShape chip_shape = parent.shape(); |
| chip_shape.RemoveDim(0); |
| return errors::Internal( |
| "ValidateInput Cannot perform copy: number of elements does not match. " |
| " Shapes are: [element]: ", |
| element.shape().DebugString(), |
| ", [parent slice]: ", chip_shape.DebugString()); |
| } |
| return Status::OK(); |
| } |
| |
| template <typename T> |
| Status HandleElementToSlice(T* src, T* dest, int64 num_values, |
| bool /* can_move */) { |
| static_assert(is_simple_type<T>::value, "Memcpy requires a simple type."); |
| memcpy(dest, src, num_values * sizeof(T)); |
| return Status::OK(); |
| } |
| |
| template <> |
| Status HandleElementToSlice<string>(string* src, string* dest, int64 num_values, |
| bool can_move) { |
| if (can_move) { |
| for (int64 i = 0; i < num_values; ++i) { |
| *dest++ = std::move(*src++); |
| } |
| } else { |
| std::copy_n(src, num_values, dest); |
| } |
| return Status::OK(); |
| } |
| |
| template <> |
| Status HandleElementToSlice<Variant>(Variant* src, Variant* dest, |
| int64 num_values, bool can_move) { |
| if (can_move) { |
| for (int64 i = 0; i < num_values; ++i) { |
| *dest++ = std::move(*src++); |
| } |
| } else { |
| std::copy_n(src, num_values, dest); |
| } |
| return Status::OK(); |
| } |
| |
| template <> |
| Status HandleElementToSlice<ResourceHandle>(ResourceHandle* src, |
| ResourceHandle* dest, |
| int64 num_values, |
| bool /* can_move */) { |
| std::copy_n(src, num_values, dest); |
| return Status::OK(); |
| } |
| |
| template <> |
| Status HandleElementToSlice<Eigen::half>(Eigen::half* src, Eigen::half* dest, |
| int64 num_values, |
| bool /* can_move */) { |
| std::copy_n(src, num_values, dest); |
| return Status::OK(); |
| } |
| |
| // TODO(b/78245576): Consider removing this overload. |
| template <typename T> |
| void HandleSliceToElement(const Tensor& parent, Tensor* element, int64 index) { |
| element->flat<T>() = parent.flat_outer_dims<T>().chip(index, 0); |
| } |
| |
| template <typename T> |
| void HandleSliceToElement(Tensor* parent, Tensor* element, int64 index, |
| bool can_move) { |
| element->flat<T>() = parent->flat_outer_dims<T>().chip(index, 0); |
| } |
| |
| template <> |
| void HandleSliceToElement<string>(Tensor* parent, Tensor* element, int64 index, |
| bool can_move) { |
| auto parent_as_matrix = parent->flat_outer_dims<tstring>(); |
| auto element_flat = element->flat<tstring>(); |
| if (can_move) { |
| for (int64 i = 0; i < element->NumElements(); ++i) { |
| element_flat(i) = std::move(parent_as_matrix(index, i)); |
| } |
| } else { |
| element_flat = parent_as_matrix.chip(index, 0); |
| } |
| } |
| |
| template <> |
| void HandleSliceToElement<Variant>(Tensor* parent, Tensor* element, int64 index, |
| bool can_move) { |
| auto parent_as_matrix = parent->flat_outer_dims<Variant>(); |
| auto element_flat = element->flat<Variant>(); |
| if (can_move) { |
| for (int64 i = 0; i < element->NumElements(); ++i) { |
| element_flat(i) = std::move(parent_as_matrix(index, i)); |
| } |
| } else { |
| element_flat = parent_as_matrix.chip(index, 0); |
| } |
| } |
| |
| } // namespace |
| |
| // Copies element into the index^th slice of parent (in the 0th dimension). |
| Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) { |
| TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index)); |
| const int64 num_values = element.NumElements(); |
| bool can_move = element.RefCountIsOne(); |
| #define HANDLE_TYPE(T) \ |
| case DataTypeToEnum<T>::value: { \ |
| T* src = element.base<T>(); \ |
| T* dest = parent->base<T>() + (num_values * index); \ |
| return HandleElementToSlice<T>(src, dest, num_values, can_move); \ |
| } |
| |
| switch (element.dtype()) { |
| TF_CALL_ALL_TYPES(HANDLE_TYPE); |
| TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); |
| TF_CALL_uint32(HANDLE_TYPE); |
| TF_CALL_uint64(HANDLE_TYPE); |
| #undef HANDLE_TYPE |
| default: |
| return errors::Unimplemented("CopyElementToSlice Unhandled data type: ", |
| element.dtype()); |
| } |
| } |
| |
| // Copies the index^th slice of parent (in the 0th dimension) into element. |
| Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) { |
| TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index)); |
| |
| #define HANDLE_TYPE(T) \ |
| case DataTypeToEnum<T>::value: { \ |
| HandleSliceToElement<T>(parent, element, index); \ |
| return Status::OK(); \ |
| } |
| |
| switch (parent.dtype()) { |
| TF_CALL_ALL_TYPES(HANDLE_TYPE); |
| TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); |
| TF_CALL_uint32(HANDLE_TYPE); |
| TF_CALL_uint64(HANDLE_TYPE); |
| #undef HANDLE_TYPE |
| default: |
| return errors::Unimplemented("CopySliceToElement Unhandled data type: ", |
| element->dtype()); |
| } |
| } |
| |
| // Copies the index^th slice of parent (in the 0th dimension) into element. |
| // |
| // NOTE(mrry): The implementation may be able to optimize the copy to a move. |
| // This is particularly important for DT_STRING tensors. |
| Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) { |
| TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index)); |
| bool can_move = parent->RefCountIsOne(); |
| |
| #define HANDLE_TYPE(T) \ |
| case DataTypeToEnum<T>::value: { \ |
| HandleSliceToElement<T>(parent, element, index, can_move); \ |
| return Status::OK(); \ |
| } |
| |
| switch (parent->dtype()) { |
| TF_CALL_ALL_TYPES(HANDLE_TYPE); |
| TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE); |
| TF_CALL_uint32(HANDLE_TYPE); |
| TF_CALL_uint64(HANDLE_TYPE); |
| #undef HANDLE_TYPE |
| default: |
| return errors::Unimplemented( |
| "MaybeMoveSliceToElement Unhandled data type: ", element->dtype()); |
| } |
| } |
| |
| // The following five functions are copied from padding_fifo_queue.cc. |
| // TODO(mrry): Reconcile these functions with the similar methods in the |
| // queue implementation. |
| Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) { |
| DCHECK_NE(parent->dim_size(0), 0); |
| if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { |
| TensorShape chip_shape = parent->shape(); |
| chip_shape.RemoveDim(0); |
| return errors::Internal( |
| "HandleElementToLargerSlice Cannot copy slice: number of entries in " |
| "element is greater than number of elements in parent slice. ", |
| "Shapes are: [element]: ", element.shape().DebugString(), |
| ", [parent slice]: ", chip_shape.DebugString()); |
| } |
| return Status::OK(); |
| } |
| |
// Writes `element` (rank NDIMS) into the index^th slice of `parent`
// (rank NDIMS + 1). The parent slice may be larger than the element in its
// trailing dimensions; only the element-shaped leading region is written.
template <typename T, int NDIMS>
Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
                                  int index) {
  TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent));
  if (element.NumElements() == 0) {
    // Nothing to copy for an empty element.
    return Status::OK();
  }
  auto element_t = element.tensor<T, NDIMS>();
  auto parent_t = parent->tensor<T, NDIMS + 1>();
  // Start of the target region: [index, 0, ..., 0] (non-leading entries are
  // Eigen-default-initialized to zero).
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
  slice_indices[0] = index;
  // Extent of the target region: one batch entry with the element's dims.
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
  slice_size[0] = 1;
  for (size_t i = 1; i < slice_size.size(); ++i) {
    slice_size[i] = element_t.dimension(i - 1);
  }
  // Reshape adds the leading unit dimension so shapes match for assignment.
  parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
  return Status::OK();
}
| |
| template <int NDIMS> |
| Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, |
| int index) { |
| #define HANDLE_TYPE(T) \ |
| case DataTypeToEnum<T>::value: { \ |
| return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \ |
| } |
| |
| switch (element.dtype()) { |
| TF_CALL_DATASET_TYPES(HANDLE_TYPE); |
| #undef HANDLE_TYPE |
| default: |
| return errors::Unimplemented( |
| "HandleElementToLargerSliceWithRank Unhandled data type: ", |
| element.dtype()); |
| } |
| } |
| |
| Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, |
| int index) { |
| if (parent->dims() != element.dims() + 1) { |
| return errors::Internal( |
| "Mismatched ranks. Element's rank is: ", element.dims(), |
| " but element is meant to be a slice in output Tensor having rank: ", |
| parent->dims(), " (should be: ", element.dims() + 1, ")"); |
| } |
| |
| #define HANDLE_DIMS(NDIMS) \ |
| case NDIMS: { \ |
| TF_RETURN_IF_ERROR( \ |
| HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \ |
| return Status::OK(); \ |
| } |
| |
| switch (element.dims()) { |
| HANDLE_DIMS(0); |
| HANDLE_DIMS(1); |
| HANDLE_DIMS(2); |
| HANDLE_DIMS(3); |
| HANDLE_DIMS(4); |
| HANDLE_DIMS(5); |
| #undef HANDLE_DIMS |
| default: |
| return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ", |
| element.dims()); |
| } |
| } |
| |
| Status SetElementZero(Tensor* element, const Tensor& padding) { |
| #define HANDLE_TYPE(T) \ |
| if (element->dtype() == DataTypeToEnum<T>::value) { \ |
| element->flat<T>().setConstant(padding.scalar<T>()()); \ |
| return Status::OK(); \ |
| } |
| TF_CALL_DATASET_TYPES(HANDLE_TYPE); |
| #undef HANDLE_TYPE |
| return errors::Unimplemented("SetElementZero Unhandled data type: ", |
| element->dtype()); |
| } |
| |
| } // namespace batch_util |
| } // namespace tensorflow |