blob: b86b03c8f0933d43b5fc1a6f631a66675515ec47 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
#define TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
#include <deque>
#include <vector>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {

// A FIFOQueue whose tuple components are described by PartialTensorShape
// instead of fully defined TensorShape.
//
// The protected helpers declared below relate to padding: zero-filling an
// element, mapping unknown dimensions to 0, and copying an element into a
// slice that may be larger than the element. This suggests that batched
// dequeues pad smaller elements with zeros. NOTE(review): confirm the exact
// padding behavior against the .cc implementation.
class PaddingFIFOQueue : public FIFOQueue {
 public:
  // `component_dtypes` and `component_shapes` describe each tuple component.
  // The shapes may be only partially defined. `capacity` and `name` are
  // presumably forwarded to the FIFOQueue base; verify in the .cc.
  PaddingFIFOQueue(int32 capacity, const DataTypeVector& component_dtypes,
                   const std::vector<PartialTensorShape>& component_shapes,
                   const string& name);

  // Post-construction setup. Any failure is reported through the Status.
  Status Initialize() override;

  // ---- QueueInterface overrides -------------------------------------------

  // Attempts to dequeue `num_elements` tuples as one batch and delivers the
  // outcome via `callback`. NOTE(review): `allow_small_batch` presumably
  // permits a batch shorter than `num_elements`; confirm the exact conditions
  // in the .cc.
  void TryDequeueMany(int num_elements, OpKernelContext* ctx,
                      bool allow_small_batch,
                      CallbackWithTuple callback) override;

  // Reports, via Status, whether `node_def` matches this queue.
  Status MatchesNodeDef(const NodeDef& node_def) override;

 protected:
  // ---- Tuple validation hooks (base-class overrides) ----------------------

  // Validates a single tuple.
  Status ValidateTuple(const Tuple& tuple) override;

  // Validates a tuple on the "many" (batched) path.
  Status ValidateManyTuple(const Tuple& tuple) override;

  // Presumably checks the shapes recorded in `node_def` against this queue's
  // component shapes. Verify in the .cc.
  Status CompatibleNodeDefShapes(const NodeDef& node_def) const;

  // ---- Padding helpers -----------------------------------------------------

  // Turns each PartialTensorShape into a TensorShape, replacing every
  // unknown dimension size with 0.
  // REQUIRED: every input shape has a well defined rank.
  static std::vector<TensorShape> ConvertShapesPartialDimensionsToZero(
      const gtl::ArraySlice<PartialTensorShape>& partial_shapes);

  // Sets every value in `*element` to zero.
  static Status SetElementZero(Tensor* element);

  // Copies `element` into the `index`-th slice of `*parent`, taken along the
  // first dimension. That slice is allowed to be larger than `element`. In
  // that case `element` is written into the slice's upper-left-hand corner.
  static Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
                                         int index);

  // Per-component partial shapes. These are presumably the `component_shapes`
  // given to the constructor; verify in the .cc.
  std::vector<PartialTensorShape> partial_shapes_;

 private:
  // The destructor is non-public. Instances are presumably released through
  // the base class's reference counting rather than by direct `delete`.
  // NOTE(review): confirm.
  ~PaddingFIFOQueue() override = default;

  // Fetches component `component` of `tuple` into `*out_tensor`, with `ctx`
  // available (presumably for allocation). Verify semantics in the .cc.
  static Status GetElementComponent(const PaddingFIFOQueue::Tuple& tuple,
                                    int component, OpKernelContext* ctx,
                                    PersistentTensor* out_tensor);

  // Judging by the name, this yields a non-OK Status unless `first` and
  // `second` agree in size, ignoring dimensions where `first` is zero.
  // NOTE(review): confirm against the .cc.
  static Status IsSameSizeExceptZerosInFirst(const TensorShape& first,
                                             const TensorShape& second);

  TF_DISALLOW_COPY_AND_ASSIGN(PaddingFIFOQueue);
};

}  // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_