| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #define EIGEN_USE_THREADS |
| |
| #include <atomic> |
| #include <utility> |
| |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/framework/partial_tensor_shape.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/kernels/data/captured_function.h" |
| #include "tensorflow/core/kernels/data/dataset.h" |
| #include "tensorflow/core/kernels/inplace_ops_functor.h" |
| #include "tensorflow/core/lib/core/blocking_counter.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/lib/random/random.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/platform/cpu_info.h" |
| #include "tensorflow/core/platform/tracing.h" |
| |
| namespace tensorflow { |
| namespace data { |
| namespace { |
| |
| // See documentation in ../ops/dataset_ops.cc for a high-level |
| // description of the following op. |
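| // |
| // The kernel fuses the `map` and `batch` transformations: up to |
| // `num_parallel_calls` invocations of the mapped function `f` run |
| // concurrently, and each result is written directly into its row of the |
| // preallocated batch tensors rather than being materialized as a separate |
| // element first. At the Python level this backs the fused `map_and_batch` |
| // tf.data transformation. |
| // |
| // Illustrative sketch of the per-batch work (pseudocode only): |
| // |
| //   for offset in [0, batch_size): |
| //     element = input.GetNext() |
| //     batch[offset] = f(element)   // up to num_parallel_calls of these |
| //                                  // run concurrently |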
| |
| class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { |
| public: |
| explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) |
| : UnaryDatasetOpKernel(ctx), |
| op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) { |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); |
| OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); |
| } |
| |
| protected: |
| void MakeDataset(OpKernelContext* ctx, DatasetBase* input, |
| DatasetBase** output) override { |
| int64 batch_size; |
| OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size)); |
| OP_REQUIRES( |
| ctx, batch_size > 0, |
| errors::InvalidArgument("batch_size must be greater than zero.")); |
| |
| int64 num_parallel_calls; |
| switch (op_version_) { |
| case 1: |
| int64 num_parallel_batches; |
| OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_batches", |
| &num_parallel_batches)); |
| OP_REQUIRES(ctx, num_parallel_batches > 0, |
| errors::InvalidArgument( |
| "num_parallel_batches must be greater than zero.")); |
| num_parallel_calls = num_parallel_batches * batch_size; |
| break; |
| case 2: |
| OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", |
| &num_parallel_calls)); |
| OP_REQUIRES(ctx, |
| num_parallel_calls > 0 || num_parallel_calls == kAutoTune, |
| errors::InvalidArgument( |
| "num_parallel_calls must be greater than zero or the " |
| "autotune sentinel value.")); |
| break; |
| default: |
| OP_REQUIRES(ctx, false, |
| errors::Unimplemented("Unsupported operation version ", |
| op_version_, ".")); |
| } |
| |
| bool drop_remainder; |
| OP_REQUIRES_OK(ctx, |
| ParseScalarArgument(ctx, "drop_remainder", &drop_remainder)); |
| |
| std::unique_ptr<CapturedFunction> captured_func; |
| OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", |
| &captured_func)); |
| |
| *output = new Dataset(ctx, input, batch_size, num_parallel_calls, |
| drop_remainder, output_types_, output_shapes_, func_, |
| std::move(captured_func), &ctx->eigen_cpu_device()); |
| } |
| |
| private: |
| class Dataset : public DatasetBase { |
| public: |
| Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, |
| int64 num_parallel_calls, bool drop_remainder, |
| const DataTypeVector& output_types, |
| const std::vector<PartialTensorShape>& output_shapes, |
| const NameAttrList& func, |
| std::unique_ptr<CapturedFunction> captured_func, |
| const Eigen::ThreadPoolDevice* device) |
| : DatasetBase(DatasetContext(ctx)), |
| input_(input), |
| batch_size_(batch_size), |
| num_parallel_calls_(num_parallel_calls), |
| drop_remainder_(drop_remainder), |
| output_types_(output_types), |
| output_shapes_(output_shapes), |
| map_fn_(func), |
| captured_func_(std::move(captured_func)), |
| device_(device) { |
| input_->Ref(); |
| } |
| |
| ~Dataset() override { input_->Unref(); } |
| |
| std::unique_ptr<IteratorBase> MakeIteratorInternal( |
| const string& prefix) const override { |
| return std::unique_ptr<IteratorBase>( |
| new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")})); |
| } |
| |
| const DataTypeVector& output_dtypes() const override { |
| return output_types_; |
| } |
| |
| const std::vector<PartialTensorShape>& output_shapes() const override { |
| return output_shapes_; |
| } |
| |
| string DebugString() const override { |
| return "MapAndBatchDatasetOp::Dataset"; |
| } |
| |
| protected: |
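| // Serializes this dataset into a GraphDef node. Inputs are added by |
| // positional index to match the op signature: 0 = input_dataset, |
| // 1 = other_arguments (the captured inputs, as a tensor list), |
| // 2 = batch_size, 3 = num_parallel_calls, 4 = drop_remainder. |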
| Status AsGraphDefInternal(SerializationContext* ctx, |
| DatasetGraphDefBuilder* b, |
| Node** output) const override { |
| TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name())); |
| Node* input_graph_node = nullptr; |
| TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); |
| Node* batch_size_node; |
| TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node)); |
| Node* num_parallel_calls_node; |
| TF_RETURN_IF_ERROR( |
| b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); |
| Node* drop_remainder_node; |
| TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node)); |
| |
| DataTypeVector other_arguments_types; |
| other_arguments_types.reserve(captured_func_->captured_inputs().size()); |
| std::vector<Node*> other_arguments; |
| other_arguments.reserve(captured_func_->captured_inputs().size()); |
| for (const Tensor& t : captured_func_->captured_inputs()) { |
| Node* node; |
| TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); |
| other_arguments.emplace_back(node); |
| other_arguments_types.emplace_back(t.dtype()); |
| } |
| AttrValue f; |
| b->BuildAttrValue(map_fn_, &f); |
| AttrValue other_arguments_types_attr; |
| b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); |
| |
| TF_RETURN_IF_ERROR(b->AddDataset( |
| this, |
| {std::make_pair(0, input_graph_node), |
| std::make_pair(2, batch_size_node), |
| std::make_pair(3, num_parallel_calls_node), |
| std::make_pair(4, drop_remainder_node)}, // Single tensor inputs. |
| {std::make_pair(1, other_arguments)}, // Tensor list inputs. |
| {std::make_pair("f", f), |
| std::make_pair("Targuments", other_arguments_types_attr)}, // Attrs |
| output)); |
| return Status::OK(); |
| } |
| |
| private: |
| class Iterator : public DatasetIterator<Dataset> { |
| public: |
| explicit Iterator(const Params& params) |
| : DatasetIterator<Dataset>(params), |
| num_parallel_calls_(params.dataset->num_parallel_calls_) {} |
| |
| ~Iterator() override { |
| mutex_lock l(mu_); |
| // Cancel the runner thread. |
| cancelled_ = true; |
| cond_var_.notify_all(); |
| // Wait for all in-flight calls to complete. |
| while (num_calls_ > 0) { |
| cond_var_.wait(l); |
| } |
| } |
| |
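| // Registers this iterator's parameters with the tf.data performance |
| // model. When `num_parallel_calls_` is the `kAutoTune` sentinel, the |
| // parallelism starts at 1 and is exposed as a tunable parameter in |
| // [1, number of schedulable CPUs]; otherwise it is reported as a |
| // constant. Also creates the input iterator and instantiates the |
| // captured function. |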
| Status Initialize(IteratorContext* ctx) override { |
| mutex_lock l(mu_); |
| AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); |
| if (num_parallel_calls_ == kAutoTune) { |
| num_parallel_calls_ = 1; |
| AddTunableParameter(ctx, "parallelism", |
| &num_parallel_calls_ /* value */, 1 /* min */, |
| port::NumSchedulableCPUs() /* max */, &cond_var_); |
| } else { |
| AddConstantParameter(ctx, "parallelism", num_parallel_calls_); |
| } |
| TF_RETURN_IF_ERROR( |
| dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); |
| return dataset()->captured_func_->Instantiate(ctx); |
| } |
| |
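| // Blocks until the oldest buffered batch has no outstanding calls, then |
| // removes it from `batch_results_` and hands it to `ProcessResult` to |
| // produce the output tensors. |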
| Status GetNextInternal(IteratorContext* ctx, |
| std::vector<Tensor>* out_tensors, |
| bool* end_of_sequence) override { |
| std::shared_ptr<BatchResult> result; |
| { |
| mutex_lock l(mu_); |
| EnsureRunnerThreadStarted(ctx); |
| while (batch_results_.empty() || |
| batch_results_.front()->num_calls > 0) { |
| RecordStop(ctx); |
| cond_var_.wait(l); |
| RecordStart(ctx); |
| } |
| std::swap(result, batch_results_.front()); |
| batch_results_.pop_front(); |
| cond_var_.notify_all(); |
| } |
| return ProcessResult(ctx, result, out_tensors, end_of_sequence); |
| } |
| |
| protected: |
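| // Checkpointing: waits for all in-flight calls to drain, then saves the |
| // input iterator state, the call counter, and every buffered |
| // `BatchResult`. |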
| Status SaveInternal(IteratorStateWriter* writer) override { |
| mutex_lock l(mu_); |
| // Wait for all in-flight calls to complete. |
| while (num_calls_ > 0) { |
| cond_var_.wait(l); |
| } |
| CHECK_EQ(num_calls_, 0); |
| TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); |
| TF_RETURN_IF_ERROR( |
| writer->WriteScalar(full_name("call_counter"), call_counter_)); |
| TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("batch_results_size"), |
| batch_results_.size())); |
| for (size_t i = 0; i < batch_results_.size(); ++i) { |
| TF_RETURN_IF_ERROR(WriteBatchResult(writer, i)); |
| } |
| return Status::OK(); |
| } |
| |
| Status RestoreInternal(IteratorContext* ctx, |
| IteratorStateReader* reader) override { |
| mutex_lock l(mu_); |
| TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); |
| TF_RETURN_IF_ERROR( |
| reader->ReadScalar(full_name("call_counter"), &call_counter_)); |
| int64 batch_results_size; |
| TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("batch_results_size"), |
| &batch_results_size)); |
| for (int i = 0; i < batch_results_size; ++i) { |
| TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i)); |
| } |
| return Status::OK(); |
| } |
| |
| private: |
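| // Accumulates the results of up to `batch_size_` invocations of the |
| // mapped function. `num_calls` starts at `batch_size` and is decremented |
| // as calls complete; the batch becomes consumable once it reaches zero. |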
| struct BatchResult { |
| explicit BatchResult(int64 batch_size) { |
| end_of_input = false; |
| num_calls = batch_size; |
| num_elements = 0; |
| output_allocated = false; |
| status = Status::OK(); |
| } |
| |
| void UpdateStatus(const Status& s) { |
| mutex_lock l(mu); |
| status.Update(s); |
| } |
| |
| mutex mu; |
| bool end_of_input GUARDED_BY(mu); |
| int64 num_elements GUARDED_BY(mu); |
| std::vector<Tensor> output; |
| bool output_allocated GUARDED_BY(mu); |
| Status status GUARDED_BY(mu); |
| // Counts the number of outstanding calls for this batch. |
| int64 num_calls; // access guarded by owner's mutex |
| }; |
| |
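| // Completion callback for the mapped function applied to the element at |
| // position `offset`. On success it copies each component of |
| // `return_values` into row `offset` of the corresponding batch tensor |
| // (allocating the batch tensors on first use) and increments |
| // `num_elements`; in all cases it marks the call as completed. |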
| void Callback(const std::shared_ptr<IteratorContext>& ctx, |
| const std::shared_ptr<BatchResult>& result, |
| const std::shared_ptr<std::vector<Tensor>>& return_values, |
| int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) { |
| result->UpdateStatus(status); |
| if (status.ok()) { |
| EnsureOutputAllocated(ctx, result, return_values); |
| for (size_t i = 0; i < return_values->size(); ++i) { |
| const Tensor& tensor = return_values->at(i); |
| Tensor* batch = &(result->output)[i]; |
| if (tensor.NumElements() != |
| (batch->NumElements() / batch->dim_size(0))) { |
| TensorShape batch_shape = batch->shape(); |
| batch_shape.RemoveDim(0); |
| result->UpdateStatus(errors::InvalidArgument( |
| "Cannot add tensor to the batch: number of elements does not " |
| "match. Shapes are: [tensor]: ", |
| tensor.shape().DebugString(), |
| ", [batch]: ", batch_shape.DebugString())); |
| break; |
| } |
| // TODO(mrry): Add a version of DoParallelConcat that allows us to |
| // move `tensor` where possible, to speed up string tensor batching. |
| Status copy_status = ::tensorflow::functor::DoParallelConcat( |
| *dataset()->device_, tensor, offset, batch); |
| if (!copy_status.ok()) { |
| result->UpdateStatus(copy_status); |
| break; |
| } |
| } |
| { |
| mutex_lock l(result->mu); |
| result->num_elements++; |
| } |
| } |
| CallCompleted(result); |
| } |
| |
| void CallCompleted(const std::shared_ptr<BatchResult>& result) |
| LOCKS_EXCLUDED(mu_) { |
| mutex_lock l(mu_); |
| num_calls_--; |
| result->num_calls--; |
| cond_var_.notify_all(); |
| } |
| |
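| // Reads the next element from the input iterator and schedules an |
| // asynchronous invocation of the mapped function whose result will be |
| // stored at row `offset` of `result`. If the input is exhausted or the |
| // read fails, the call is completed immediately. |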
| void CallFunction(std::shared_ptr<IteratorContext> ctx, |
| const std::shared_ptr<BatchResult>& result, |
| int64 offset) LOCKS_EXCLUDED(mu_) { |
| // Get the next input element. |
| std::vector<Tensor> input_element; |
| bool end_of_input; |
| Status status = |
| input_impl_->GetNext(ctx.get(), &input_element, &end_of_input); |
| bool return_early; |
| { |
| mutex_lock l(result->mu); |
| result->end_of_input = result->end_of_input || end_of_input; |
| result->status.Update(status); |
| return_early = result->end_of_input || !result->status.ok(); |
| } |
| if (return_early) { |
| CallCompleted(result); |
| return; |
| } |
| |
| // Call `captured_func_(input_element)`, using `Callback` to store the |
| // result in `result`. |
| (*ctx->runner())(std::bind( |
| [this, result, offset](std::shared_ptr<IteratorContext> ctx, |
| std::vector<Tensor> input_element) { |
| std::shared_ptr<std::vector<Tensor>> return_values( |
| new std::vector<Tensor>()); |
| dataset()->captured_func_->RunAsync( |
| ctx.get(), std::move(input_element), return_values.get(), |
| [this, ctx, result, return_values, offset](Status status) { |
| Callback(ctx, result, return_values, offset, status); |
| }, |
| prefix()); |
| }, |
| ctx, std::move(input_element))); |
| } |
| |
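| // Copies the first `num_elements` rows of `value` into `output`. Used |
| // when emitting a partial final batch and when restoring a partially |
| // filled batch from a checkpoint. |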
| Status CopyPartialBatch(Tensor* output, const Tensor& value, |
| int64 num_elements) { |
| switch (value.dtype()) { |
| #define HANDLE_TYPE(type) \ |
| case DataTypeToEnum<type>::value: { \ |
| auto output_t = output->flat_outer_dims<type>(); \ |
| auto value_t = value.flat_outer_dims<type>(); \ |
| for (int64 i = 0; i < num_elements; i++) { \ |
| output_t.template chip<0>(i) = value_t.template chip<0>(i); \ |
| } \ |
| return Status::OK(); \ |
| } |
| TF_CALL_DATASET_TYPES(HANDLE_TYPE); |
| #undef HANDLE_TYPE |
| default: |
| return errors::InvalidArgument("Unsupported data type: ", |
| DataTypeString(value.dtype())); |
| } |
| return Status::OK(); |
| } |
| |
| void EnsureRunnerThreadStarted(IteratorContext* ctx) |
| EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| if (!runner_thread_) { |
| auto ctx_copy = std::make_shared<IteratorContext>(*ctx); |
| runner_thread_.reset(ctx->env()->StartThread( |
| {}, "runner_thread", |
| std::bind(&Iterator::RunnerThread, this, ctx_copy))); |
| } |
| } |
| |
| void EnsureOutputAllocated( |
| const std::shared_ptr<IteratorContext>& ctx, |
| const std::shared_ptr<BatchResult>& result, |
| const std::shared_ptr<std::vector<Tensor>>& return_values) { |
| mutex_lock l(result->mu); |
| if (result->output_allocated) { |
| return; |
| } |
| const size_t num_components = return_values->size(); |
| for (size_t i = 0; i < num_components; ++i) { |
| TensorShape component_shape({dataset()->batch_size_}); |
| component_shape.AppendShape(return_values->at(i).shape()); |
| AllocatorAttributes attr; |
| attr.set_gpu_compatible(true); |
| Tensor component(ctx->allocator(attr), return_values->at(i).dtype(), |
| component_shape); |
| result->output.emplace_back(std::move(component)); |
| } |
| result->output_allocated = true; |
| } |
| |
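| // Converts a completed `BatchResult` into the iterator's output. Errors |
| // other than `OutOfRange` are propagated; `OutOfRange` raised by `f` is |
| // treated as a request to end the iteration early. A final batch with |
| // fewer than `batch_size_` elements is either dropped (when |
| // `drop_remainder_` is set) or trimmed to `num_elements` rows. |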
| Status ProcessResult(IteratorContext* ctx, |
| const std::shared_ptr<BatchResult>& result, |
| std::vector<Tensor>* out_tensors, |
| bool* end_of_sequence) { |
| mutex_lock l(result->mu); |
| if (result->num_elements == 0) { |
| *end_of_sequence = true; |
| return Status::OK(); |
| } |
| // `f` may deliberately raise `errors::OutOfRange` to indicate that we |
| // should terminate the iteration early. |
| if (!result->status.ok() && !errors::IsOutOfRange(result->status)) { |
| // Deallocate tensors allocated for the output. |
| result->output.clear(); |
| *end_of_sequence = false; |
| return result->status; |
| } |
| if (result->num_elements < dataset()->batch_size_) { |
| if (dataset()->drop_remainder_) { |
| // Deallocate tensors allocated for the output. |
| result->output.clear(); |
| *end_of_sequence = true; |
| return Status::OK(); |
| } |
| const std::vector<Tensor>& output = result->output; |
| for (size_t i = 0; i < output.size(); ++i) { |
| TensorShape component_shape(result->output[i].shape()); |
| component_shape.set_dim(0, result->num_elements); |
| AllocatorAttributes attr; |
| attr.set_gpu_compatible(true); |
| Tensor component(ctx->allocator(attr), output[i].dtype(), |
| component_shape); |
| TF_RETURN_IF_ERROR( |
| CopyPartialBatch(&component, output[i], result->num_elements)); |
| out_tensors->emplace_back(std::move(component)); |
| } |
| // Deallocate tensors allocated for the output. |
| result->output.clear(); |
| } else { |
| *out_tensors = std::move(result->output); |
| } |
| *end_of_sequence = result->num_elements == 0; |
| return Status::OK(); |
| } |
| |
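| // Background thread that keeps up to `num_parallel_calls_` mapped-function |
| // calls in flight, starting a new `BatchResult` every `batch_size_` calls. |
| // The `busy()` predicate stops scheduling once that many calls are in |
| // flight or once roughly ceil(num_parallel_calls / batch_size) batches |
| // are buffered; for example, with batch_size = 32 and |
| // num_parallel_calls = 40, max_batch_results = (40 + 32 - 1) / 32 = 2. |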
| void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) |
| LOCKS_EXCLUDED(mu_) { |
| std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls; |
| RecordStart(ctx.get()); |
| auto stop_cleanup = |
| gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); }); |
| new_calls.reserve(num_parallel_calls_); |
| auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool { |
| int64 num_parallel_calls = num_parallel_calls_; |
| int64 max_batch_results = |
| (num_parallel_calls + dataset()->batch_size_ - 1) / |
| dataset()->batch_size_; |
| return num_calls_ >= num_parallel_calls || |
| (batch_results_.size() > max_batch_results || |
| (batch_results_.size() == max_batch_results && |
| call_counter_ % dataset()->batch_size_ == 0)); |
| }; |
| while (true) { |
| { |
| mutex_lock l(mu_); |
| while (!cancelled_ && busy()) { |
| RecordStop(ctx.get()); |
| cond_var_.wait(l); |
| RecordStart(ctx.get()); |
| } |
| |
| if (cancelled_) { |
| return; |
| } |
| |
| while (!busy()) { |
| if (call_counter_ % dataset()->batch_size_ == 0) { |
| batch_results_.emplace_back( |
| new BatchResult(dataset()->batch_size_)); |
| } |
| int64 offset = call_counter_++ % dataset()->batch_size_; |
| new_calls.emplace_back(batch_results_.back(), offset); |
| num_calls_++; |
| } |
| } |
| |
| for (const auto& call : new_calls) { |
| CallFunction(ctx, call.first, call.second); |
| } |
| new_calls.clear(); |
| } |
| } |
| |
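| // Restores a single `BatchResult` from a checkpoint. Each batch is stored |
| // under the key prefix `batch_results_<index>` with the suffixes |
| // `_end_of_input`, `_num_calls`, `_num_elements`, `_output_allocated`, |
| // `_output_size`, `_output_<i>` (one tensor per component), and |
| // `_status_code` / `_status_msg`. |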
| Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, |
| size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| batch_results_.emplace_back(new BatchResult(dataset()->batch_size_)); |
| std::shared_ptr<BatchResult> result = batch_results_.back(); |
| string prefix = strings::StrCat("batch_results_", index); |
| mutex_lock l(result->mu); |
| result->end_of_input = reader->Contains( |
| full_name(strings::StrCat(prefix, "_end_of_input"))); |
| TF_RETURN_IF_ERROR( |
| reader->ReadScalar(full_name(strings::StrCat(prefix, "_num_calls")), |
| &result->num_calls)); |
| TF_RETURN_IF_ERROR(reader->ReadScalar( |
| full_name(strings::StrCat(prefix, "_num_elements")), |
| &result->num_elements)); |
| result->output_allocated = reader->Contains( |
| full_name(strings::StrCat(prefix, "_output_allocated"))); |
| int64 output_size; |
| TF_RETURN_IF_ERROR(reader->ReadScalar( |
| full_name(strings::StrCat(prefix, "_output_size")), &output_size)); |
| result->output.reserve(output_size); |
| for (int i = 0; i < output_size; i++) { |
| Tensor t; |
| TF_RETURN_IF_ERROR(reader->ReadTensor( |
| full_name(strings::StrCat(prefix, "_output_", i)), &t)); |
| // If the batch was not full, we may have stored only the relevant |
| // slice. Since tensors in `BatchResult.output` are expected to |
| // have the leading dimension of size batch_size, we build a larger |
| // tensor and copy the slice read from the checkpoint into it. |
| if (t.dim_size(0) < dataset()->batch_size_) { |
| TensorShape component_shape(t.shape()); |
| component_shape.set_dim(0, dataset()->batch_size_); |
| AllocatorAttributes attr; |
| attr.set_gpu_compatible(true); |
| Tensor new_t(ctx->allocator(attr), t.dtype(), component_shape); |
| TF_RETURN_IF_ERROR(CopyPartialBatch(&new_t, t, t.dim_size(0))); |
| result->output.emplace_back(std::move(new_t)); |
| } else { |
| result->output.emplace_back(std::move(t)); |
| } |
| } |
| TF_RETURN_IF_ERROR(ReadStatus( |
| reader, strings::StrCat(prefix, "_status"), &result->status)); |
| return Status::OK(); |
| } |
| |
| Status ReadStatus(IteratorStateReader* reader, const string& prefix, |
| Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| int64 code_int; |
| TF_RETURN_IF_ERROR(reader->ReadScalar( |
| full_name(strings::StrCat(prefix, "_code")), &code_int)); |
| error::Code code = static_cast<error::Code>(code_int); |
| |
| if (code != error::Code::OK) { |
| string error_message; |
| TF_RETURN_IF_ERROR(reader->ReadScalar( |
| full_name(strings::StrCat(prefix, "_msg")), &error_message)); |
| *status = Status(code, error_message); |
| } else { |
| *status = Status::OK(); |
| } |
| return Status::OK(); |
| } |
| |
| Status WriteBatchResult(IteratorStateWriter* writer, size_t index) |
| EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| std::shared_ptr<BatchResult> result = batch_results_[index]; |
| string prefix = strings::StrCat("batch_results_", index); |
| mutex_lock l(result->mu); |
| if (result->end_of_input) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| full_name(strings::StrCat(prefix, "_end_of_input")), "")); |
| } |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| full_name(strings::StrCat(prefix, "_num_calls")), |
| result->num_calls)); |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| full_name(strings::StrCat(prefix, "_num_elements")), |
| result->num_elements)); |
| if (result->output_allocated) { |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| full_name(strings::StrCat(prefix, "_output_allocated")), "")); |
| } |
| TF_RETURN_IF_ERROR(writer->WriteScalar( |
| full_name(strings::StrCat(prefix, "_output_size")), |
| result->output.size())); |
| for (int i = 0; i < result->output.size(); i++) { |
| // If the batch is not full, we only store the first `num_elements` |
| // values. The rest of the batch tensor is *uninitialized*, and |
| // accessing it would trigger msan errors. |
| if (result->num_elements < dataset()->batch_size_) { |
| TF_RETURN_IF_ERROR(writer->WriteTensor( |
| full_name(strings::StrCat(prefix, "_output_", i)), |
| result->output[i].Slice(0, result->num_elements))); |
| } else { |
| TF_RETURN_IF_ERROR(writer->WriteTensor( |
| full_name(strings::StrCat(prefix, "_output_", i)), |
| result->output[i])); |
| } |
| } |
| TF_RETURN_IF_ERROR(WriteStatus( |
| writer, strings::StrCat(prefix, "_status"), result->status)); |
| return Status::OK(); |
| } |
| |
| Status WriteStatus(IteratorStateWriter* writer, const string& prefix, |
| const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| TF_RETURN_IF_ERROR( |
| writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")), |
| static_cast<int64>(status.code()))); |
| if (!status.ok()) { |
| TF_RETURN_IF_ERROR( |
| writer->WriteScalar(full_name(strings::StrCat(prefix, "_msg")), |
| status.error_message())); |
| } |
| return Status::OK(); |
| } |
| |
| // Used for coordination between the main thread, the runner thread, and |
| // the callback threads. |
| mutex mu_; |
| // Used for coordination between the main thread, the runner thread, and |
| // the callback threads. In particular, the runner thread should only |
| // schedule new calls when the number of in-flight calls is less than the |
| // user specified level of parallelism and there are slots available in |
| // the `batch_results_` buffer. |
| condition_variable cond_var_; |
| // Identifies the maximum number of parallel calls. |
| std::atomic<int64> num_parallel_calls_; |
| // Counts the number of mapped-function calls currently in flight, |
| // across all pending batches. |
| int64 num_calls_ GUARDED_BY(mu_) = 0; |
| // Counts the total number of calls. |
| int64 call_counter_ GUARDED_BY(mu_) = 0; |
| std::unique_ptr<IteratorBase> input_impl_; |
| // Buffer for storing the (intermediate) batch results. |
| std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_); |
| std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); |
| bool cancelled_ GUARDED_BY(mu_) = false; |
| }; |
| |
| const DatasetBase* const input_; |
| const NameAttrList func_; |
| const int64 batch_size_; |
| const int64 num_parallel_calls_; |
| const bool drop_remainder_; |
| const DataTypeVector output_types_; |
| const std::vector<PartialTensorShape> output_shapes_; |
| const NameAttrList map_fn_; |
| const std::unique_ptr<CapturedFunction> captured_func_; |
| const Eigen::ThreadPoolDevice* device_; // not owned |
| }; |
| |
| const int op_version_; |
| DataTypeVector output_types_; |
| std::vector<PartialTensorShape> output_shapes_; |
| NameAttrList func_; |
| }; |
| |
| REGISTER_KERNEL_BUILDER(Name("MapAndBatchDataset").Device(DEVICE_CPU), |
| MapAndBatchDatasetOp); |
| |
| REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU), |
| MapAndBatchDatasetOp); |
| |
| } // namespace |
| } // namespace data |
| } // namespace tensorflow |