/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/window_dataset.h"

namespace tensorflow {
namespace data {
namespace {

// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
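//
// For example, with an input of {0, 1, ..., 9}, size = 3, shift = 2,
// stride = 1, and drop_remainder = false, the op produces the windows
// {0, 1, 2}, {2, 3, 4}, {4, 5, 6}, {6, 7, 8}, and the trailing partial
// window {8, 9}; with drop_remainder = true the partial window is dropped.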

class WindowDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit WindowDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    int64 window_size = 0;
    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size));
    OP_REQUIRES(
        ctx, window_size > 0,
        errors::InvalidArgument("Window size must be greater than zero."));

    int64 window_shift = 0;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<int64>(ctx, "shift", &window_shift));
    OP_REQUIRES(
        ctx, window_shift > 0,
        errors::InvalidArgument("Window shift must be greater than zero."));

    int64 window_stride = 0;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<int64>(ctx, "stride", &window_stride));
    OP_REQUIRES(
        ctx, window_stride > 0,
        errors::InvalidArgument("Window stride must be greater than zero."));

    bool drop_remainder = false;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));

    *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
                          drop_remainder);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size,
            int64 window_shift, int64 window_stride, bool drop_remainder)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          window_size_(window_size),
          window_shift_(window_shift),
          window_stride_(window_stride),
          drop_remainder_(drop_remainder) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(new Iterator(
          Iterator::Params{this, strings::StrCat(prefix, "::Window")}));
    }

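    // Each output element consists of scalar DT_VARIANT tensors wrapping
    // window datasets, one per component of the input element tuple (see
    // GetNextInternal below).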
    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* output_dtypes = new DataTypeVector({DT_VARIANT});
      return *output_dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* output_shapes =
          new std::vector<PartialTensorShape>({TensorShape({})});
      return *output_shapes;
    }

    string DebugString() const override {
      return strings::StrCat("WindowDatasetOp(", window_size_, ", ",
                             window_shift_, ", ", window_stride_, ", ",
                             drop_remainder_, ")::Dataset");
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
      Node* window_size_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
      Node* window_shift_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
      Node* window_stride_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
      Node* drop_remainder_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
      TF_RETURN_IF_ERROR(
          b->AddDataset(this,
                        {input_graph_node, window_size_node, window_shift_node,
                         window_stride_node, drop_remainder_node},
                        output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
      }

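      // Produces the next window in three steps: (1) fill the buffer with
      // enough input elements to assemble one window, (2) gather every
      // `window_stride`-th buffered element into the window, and (3) shift
      // the buffer forward by `window_shift` elements.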
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        const int64 window_size = dataset()->window_size_;
        const int64 window_shift = dataset()->window_shift_;
        const int64 window_stride = dataset()->window_stride_;
        std::vector<std::vector<Tensor>> window_elements;
        Status status = Status::OK();
        {
          mutex_lock l(mu_);
          if (!input_impl_ && buffer_.empty()) {
            *end_of_sequence = true;
            return Status::OK();
          }

          // Add elements to the buffer until it holds enough input elements
          // for one window, or the input is exhausted.
          size_t target_size = TargetBufferSize(window_size, window_stride);
          if (input_impl_) {
            *end_of_sequence = false;
            for (size_t i = buffer_.size();
                 i < target_size && !*end_of_sequence; ++i) {
              std::vector<Tensor> element;
              Status input_status =
                  input_impl_->GetNext(ctx, &element, end_of_sequence);
              if (!*end_of_sequence) {
                // Buffer the element together with its status, so that a
                // non-OK status is surfaced in the window that contains it.
                buffer_.emplace_back(std::move(element), input_status);
              } else {
                input_impl_.reset();
              }
            }
          }

          // If there are not enough elements and `drop_remainder` is set, we
          // do not wish to return a smaller window. This branch is also
          // reached on calls after the end of the sequence, so signal
          // end-of-sequence explicitly rather than assuming the input loop
          // above has set it.
          if (buffer_.empty() ||
              (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
            *end_of_sequence = true;
            return Status::OK();
          }

          // The window consists of every `window_stride`-th element of the
          // buffer, starting at the front.
          const size_t num_elements = 1 + (buffer_.size() - 1) / window_stride;
          window_elements.reserve(num_elements);
          for (size_t i = 0; i < num_elements; ++i) {
            status.Update(buffer_[window_stride * i].status);
            if (!status.ok()) {
              break;
            }
            window_elements.emplace_back(buffer_[window_stride * i].result);
          }

          // Shift the window forward by `window_shift` elements, discarding
          // buffered (and, if necessary, upcoming input) elements that no
          // window will use.
          const int64 buffer_size = buffer_.size();
          if (window_shift >= buffer_size) {
            for (int64 i = buffer_size; input_impl_ && i < window_shift; ++i) {
              bool end_of_input;
              std::vector<Tensor> element;
              // Ignore non-error status of discarded elements.
              input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
              if (end_of_input) {
                input_impl_.reset();
              }
            }
            buffer_.clear();
          } else {
            buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
          }
        }

        if (!status.ok()) {
          return status;
        }

        // Construct output tensors: one scalar variant tensor per component
        // of the input element tuple.
        const size_t num_tuple_components = window_elements[0].size();
        const size_t num_window_elements = window_elements.size();
        *end_of_sequence = false;
        for (size_t idx = 0; idx < num_tuple_components; ++idx) {
          DatasetBase* window_dataset;
          std::vector<std::vector<Tensor>> window_component_elements;
          window_component_elements.reserve(num_window_elements);
          // Build the output tuple component by copying one slice
          // from each input element in the window.
          for (size_t i = 0; i < num_window_elements; ++i) {
            std::vector<Tensor> component_element;
            component_element.push_back(std::move(window_elements[i][idx]));
            window_component_elements.push_back(std::move(component_element));
          }
          DataTypeVector output_types(
              {dataset()->input_->output_dtypes()[idx]});
          std::vector<PartialTensorShape> output_shapes(
              {dataset()->input_->output_shapes()[idx]});
          TF_RETURN_IF_ERROR(NewWindowDataset(window_component_elements,
                                              output_types, output_shapes,
                                              &window_dataset));
          out_tensors->emplace_back(DT_VARIANT, TensorShape({}));
          TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset,
                                                         &out_tensors->back()));
        }
        return Status::OK();
      }

     protected:
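      // Checkpoints the input iterator state together with every buffered
      // element (tensors and statuses) so that iteration can resume
      // mid-window after a restore.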
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        if (!input_impl_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("input_impl_empty"), ""));
        } else {
          TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
        }
        // Save buffer. Keys are prefixed with `full_name` so that they are
        // unique to this iterator.
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
        for (size_t i = 0; i < buffer_.size(); i++) {
          TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
          TF_RETURN_IF_ERROR(writer->WriteScalar(
              full_name(strings::StrCat("buffer[", i, "].size")),
              buffer_[i].result.size()));
          for (size_t j = 0; j < buffer_[i].result.size(); j++) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat("buffer[", i, "][", j, "]")),
                buffer_[i].result[j]));
          }
        }
        return Status::OK();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        if (!reader->Contains(full_name("input_impl_empty"))) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }
        // Restore buffer.
        int64 buffer_size;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("buffer_size"), &buffer_size));
        buffer_.resize(buffer_size);
        for (int64 i = 0; i < buffer_size; i++) {
          int64 vector_size;
          TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              full_name(strings::StrCat("buffer[", i, "].size")),
              &vector_size));
          buffer_[i].result.resize(vector_size);
          for (int64 j = 0; j < vector_size; j++) {
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                full_name(strings::StrCat("buffer[", i, "][", j, "]")),
                &buffer_[i].result[j]));
          }
        }
        return Status::OK();
      }

     private:
      // An input element buffered for use by one or more windows, together
      // with the status returned when it was produced.
      struct InvocationResult {
        InvocationResult() = default;
        InvocationResult(std::vector<Tensor>&& result, const Status& status)
            : result(std::move(result)), status(status) {}

        std::vector<Tensor> result;
        Status status;
      };

      Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
                               const Status& status)
          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            CodeKey(index), static_cast<int64>(status.code())));
        if (!status.ok()) {
          TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
                                                 status.error_message()));
        }
        return Status::OK();
      }

      Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
                              Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        int64 code_int;
        TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
        error::Code code = static_cast<error::Code>(code_int);

        if (code != error::Code::OK) {
          string error_message;
          TF_RETURN_IF_ERROR(
              reader->ReadScalar(ErrorMessageKey(index), &error_message));
          *status = Status(code, error_message);
        } else {
          *status = Status::OK();
        }
        return Status::OK();
      }

      string CodeKey(size_t index) {
        return full_name(strings::StrCat("buffer[", index, "].code"));
      }

      string ErrorMessageKey(size_t index) {
        return full_name(strings::StrCat("buffer[", index, "].error_message"));
      }

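      // Number of input elements that must be buffered to assemble one full
      // window: taking `window_size` elements, one every `window_stride`
      // elements, spans (window_size - 1) * window_stride + 1 consecutive
      // inputs (e.g. size = 3, stride = 2 spans 5 inputs).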
      size_t TargetBufferSize(int64 window_size, int64 window_stride) {
        return (window_size - 1) * window_stride + 1;
      }

      mutex mu_;
      std::deque<InvocationResult> buffer_ GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
    };

    const DatasetBase* const input_;
    const int64 window_size_;
    const int64 window_shift_;
    const int64 window_stride_;
    const bool drop_remainder_;
  };
};

REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
                        WindowDatasetOp);

}  // namespace
}  // namespace data
}  // namespace tensorflow