/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/data_service_dataset_op.h"
#include <map>
#include <memory>
#include <queue>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/serialization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
namespace data {
/* static */ constexpr const char* const DataServiceDatasetOp::kDatasetType;
/* static */ constexpr const char* const DataServiceDatasetOp::kDatasetId;
/* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode;
/* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
/* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
/* static */ constexpr const char* const DataServiceDatasetOp::kJobName;
/* static */ constexpr const char* const
DataServiceDatasetOp::kMaxOutstandingRequests;
/* static */ constexpr const char* const
DataServiceDatasetOp::kIterationCounter;
/* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes;
/* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes;
namespace {
// Once we've spent `kRetryTimeoutMicros` in `GetNextInternal`, we will wait for
// the current attempt to complete and perform no more retries.
const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes.
// Default interval between task list refreshes.
const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
} // namespace
// Dataset for reading data from the tf.data service non-deterministically.
//
// This dataset interleaves dataset elements produced by multiple tf.data
// workers. We periodically query the dispatcher to determine which workers
// to read from (in case workers are added or removed).
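//
// A minimal usage sketch of the Python API that constructs this op (the
// service address and processing mode are illustrative):
//
//   dataset = dataset.apply(tf.data.experimental.service.distribute(
//       processing_mode="parallel_epochs", service="grpc://host:5000"))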
class DataServiceDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 dataset_id,
ProcessingMode processing_mode, const std::string& address,
const std::string& protocol, const std::string& job_name,
int64 max_outstanding_requests, int64 task_refresh_interval_ms,
IterationCounter* iteration_counter, bool owns_resource,
ResourceHandle iteration_counter_handle,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
dataset_id_(dataset_id),
processing_mode_(processing_mode),
address_(address),
protocol_(protocol),
job_name_(job_name),
max_outstanding_requests_(max_outstanding_requests),
task_refresh_interval_ms_(task_refresh_interval_ms),
iteration_counter_(iteration_counter),
owns_resource_(owns_resource),
iteration_counter_handle_(iteration_counter_handle),
resource_mgr_(ctx->resource_manager()),
output_types_(output_types),
output_shapes_(output_shapes) {}
~Dataset() override {
iteration_counter_->Unref();
if (owns_resource_) {
Status s = resource_mgr_->Delete<IterationCounter>(
iteration_counter_handle_.container(),
iteration_counter_handle_.name());
if (!s.ok()) {
LOG(WARNING) << "Failed to delete iteration counter resource: " << s;
}
}
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(
Iterator::Params{this,
name_utils::IteratorPrefix(kDatasetType, prefix)},
iteration_counter_->GetAndIncrement());
}
const DataTypeVector& output_dtypes() const override { return output_types_; }
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
return name_utils::DatasetDebugString(kDatasetType);
}
Status CheckExternalState() const override {
return Status(
error::FAILED_PRECONDITION,
strings::StrCat(DebugString(), " does not yet support serialization."));
}
protected:
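// Serializes this dataset as a graph node. The iteration counter resource
// handle is re-attached as a tensor input, and the task refresh interval is
// recorded as an attr hint.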
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* dataset_id;
TF_RETURN_IF_ERROR(b->AddScalar(dataset_id_, &dataset_id));
Node* processing_mode;
tstring processing_mode_str = ProcessingModeToString(processing_mode_);
TF_RETURN_IF_ERROR(b->AddScalar(processing_mode_str, &processing_mode));
Node* address;
TF_RETURN_IF_ERROR(b->AddScalar(address_, &address));
Node* protocol;
TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol));
Node* job_name;
TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name));
Node* max_outstanding_requests;
TF_RETURN_IF_ERROR(
b->AddScalar(max_outstanding_requests_, &max_outstanding_requests));
Node* iteration_counter_handle = nullptr;
Tensor handle(DT_RESOURCE, TensorShape({}));
handle.scalar<ResourceHandle>()() = iteration_counter_handle_;
TF_RETURN_IF_ERROR(b->AddTensor(handle, &iteration_counter_handle));
AttrValue task_refresh_interval_hint_ms;
b->BuildAttrValue(task_refresh_interval_ms_,
&task_refresh_interval_hint_ms);
TF_RETURN_IF_ERROR(
b->AddDataset(this,
{dataset_id, processing_mode, address, protocol, job_name,
max_outstanding_requests, iteration_counter_handle},
{std::make_pair(kTaskRefreshIntervalHintMs,
task_refresh_interval_hint_ms)},
output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params, int64 iterator_index)
: DatasetIterator<Dataset>(params),
iterator_index_(iterator_index),
max_outstanding_requests_(params.dataset->max_outstanding_requests_) {
}
~Iterator() override {
VLOG(1) << "Destroying data service dataset iterator for job id "
<< job_id_;
CancelThreads();
if (deregister_fn_) deregister_fn_();
// Thread destructors will block until the threads finish; there is no need
// to wait here.
}
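// Marks the iterator as cancelled and wakes all threads blocked on the
// iterator's condition variables; each thread observes `cancelled_` and
// exits on its own.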
void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
VLOG(1) << "Cancelling threads in DataServiceDataset::Iterator";
cancelled_ = true;
worker_thread_cv_.notify_all();
manager_thread_cv_.notify_all();
get_next_cv_.notify_all();
}
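// Registers a cancellation callback, then creates (or, when `job_name_` is
// set, gets or creates) the tf.data service job this iterator reads from,
// retrying transient dispatcher errors until `kRetryTimeoutMicros` elapses.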
Status Initialize(IteratorContext* ctx) override {
VLOG(3) << "Connecting to " << dataset()->address_
<< " in data service dataset op";
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(), [this]() { CancelThreads(); },
&deregister_fn_));
DataServiceDispatcherClient dispatcher(dataset()->address_,
dataset()->protocol_);
int64 deadline_micros = ctx->env()->NowMicros() + kRetryTimeoutMicros;
if (dataset()->job_name_.empty()) {
TF_RETURN_IF_ERROR(grpc_util::Retry(
[&]() {
return dispatcher.CreateJob(dataset()->dataset_id_,
dataset()->processing_mode_,
&job_id_);
},
"create job", deadline_micros));
} else {
TF_RETURN_IF_ERROR(grpc_util::Retry(
[&]() {
return dispatcher.GetOrCreateJob(
dataset()->dataset_id_, dataset()->processing_mode_,
dataset()->job_name_, iterator_index_, &job_id_);
},
"get or create job", deadline_micros));
}
VLOG(1) << "Created data service job with id " << job_id_;
return Status::OK();
}
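// Starts the task thread manager on the first call, then blocks until a
// result is buffered in `results_`, the job finishes, the iterator is
// cancelled, or a background thread reports an error through `status_`.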
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
VLOG(3) << "Calling GetNext in data service dataset op";
mutex_lock l(mu_);
if (!task_thread_manager_ && !cancelled_) {
task_thread_manager_ =
ctx->StartThread("task-thread-manager", [this, ctx]() {
TaskThreadManager(absl::make_unique<IteratorContext>(*ctx));
});
}
while (results_.empty() && !job_finished_ && !cancelled_ &&
status_.ok()) {
get_next_cv_.wait(l);
}
if (cancelled_) {
return errors::Cancelled("Data service iterator was cancelled");
}
if (!status_.ok()) {
return status_;
}
if (results_.empty()) {
*end_of_sequence = true;
return Status::OK();
}
DCHECK(!results_.empty());
*end_of_sequence = false;
out_tensors->swap(results_.front());
results_.pop();
worker_thread_cv_.notify_one();
return Status::OK();
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args),
/*ratio=*/1);
}
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal is not yet supported");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented("RestoreInternal is not yet supported");
}
private:
struct Task {
Task(int64 task_id, const std::string& address,
std::unique_ptr<DataServiceWorkerClient> worker)
: task_id(task_id), address(address), worker(std::move(worker)) {}
const int64 task_id;
// Address of the tf.data service worker for task `task_id`.
const std::string address;
// Client for fetching task elements from the tf.data service worker.
const std::unique_ptr<DataServiceWorkerClient> worker;
// Indicates whether a worker thread is currently processing the task.
bool in_use TF_GUARDED_BY(&Iterator::mu_) = false;
// Indicates whether the worker has returned end_of_sequence for the task.
bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
};
// Periodically refresh the task list.
// Maintain one thread fetching elements for each task.
// TODO(aaudibert): Instead of polling, have dispatcher send updates when
// the list of tasks changes.
void TaskThreadManager(std::unique_ptr<IteratorContext> ctx) {
auto cleanup =
gtl::MakeCleanup([] { VLOG(1) << "Task thread manager exiting"; });
VLOG(1) << "Starting task thread manager";
DataServiceDispatcherClient dispatcher(dataset()->address_,
dataset()->protocol_);
uint64 next_check = Env::Default()->NowMicros();
while (true) {
{
mutex_lock l(mu_);
// All units are microseconds.
while (!cancelled_ && Env::Default()->NowMicros() < next_check) {
int64 remaining_time = next_check - Env::Default()->NowMicros();
VLOG(3) << "Task thread manager waiting for " << remaining_time
<< "us";
manager_thread_cv_.wait_for(
l, std::chrono::microseconds(remaining_time));
}
if (cancelled_) {
VLOG(3) << "Task thread manager finished";
return;
}
}
UpdateTasks(&dispatcher);
UpdateWorkerThreads(ctx.get());
next_check = Env::Default()->NowMicros() +
dataset()->task_refresh_interval_ms_ * 1000;
}
}
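// Synchronizes the local `tasks_` list with the dispatcher: removes tasks
// that no longer exist, creates a worker client for each newly discovered
// task, and wakes blocked `GetNext` callers when the job finishes.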
void UpdateTasks(DataServiceDispatcherClient* dispatcher)
TF_LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Updating tasks";
std::vector<TaskInfo> tasks;
bool job_finished;
Status s = dispatcher->GetTasks(job_id_, &tasks, &job_finished);
if (!s.ok()) {
LOG(WARNING) << "Failed to get task info for job id " << job_id_ << ": "
<< s;
return;
}
absl::flat_hash_map<int64, TaskInfo> task_id_to_task;
for (auto& task : tasks) {
task_id_to_task[task.task_id()] = task;
}
mutex_lock l(mu_);
job_finished_ = job_finished;
if (job_finished) {
get_next_cv_.notify_all();
return;
}
for (int i = 0; i < tasks_.size(); ++i) {
std::shared_ptr<Task> task = tasks_[i];
if (task_id_to_task.contains(task->task_id)) {
// Remove already-known tasks from `task_id_to_task`, so that at the
// end of the loop, only new tasks remain.
task_id_to_task.erase(task->task_id);
} else {
// Task has been removed.
if (task->end_of_sequence) {
finished_tasks_--;
}
tasks_[i] = tasks_[tasks_.size() - 1];
tasks_.pop_back();
}
}
for (auto& new_task_entry : task_id_to_task) {
TaskInfo& task_info = new_task_entry.second;
std::unique_ptr<DataServiceWorkerClient> worker;
Status s = CreateDataServiceWorkerClient(task_info.worker_address(),
dataset()->protocol_, &worker);
if (!s.ok()) {
status_ = s;
get_next_cv_.notify_all();
continue;
}
tasks_.push_back(std::make_shared<Task>(task_info.task_id(),
task_info.worker_address(),
std::move(worker)));
}
if (dataset()->max_outstanding_requests_ == model::kAutotune) {
// Adjust max_outstanding_requests to account for newly added tasks.
max_outstanding_requests_ = tasks_.size();
}
}
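// Starts additional worker threads until there is one running thread per
// allowed outstanding request.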
void UpdateWorkerThreads(IteratorContext* ctx) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
while (num_running_worker_threads_ < max_outstanding_requests_) {
num_running_worker_threads_++;
outstanding_requests_++;
auto done = [this]() {
mutex_lock l(mu_);
num_running_worker_threads_--;
outstanding_requests_--;
};
worker_threads_.push_back(ctx->StartThread(
"tf-data-service-task_thread", [this, done = std::move(done)]() {
RunWorkerThread(std::move(done));
}));
}
}
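// Worker thread loop: claims an unclaimed, unfinished task (round-robin via
// `next_task_index_`), fetches one element from it, releases the task, and
// repeats. Blocks while the buffer is full or no task is available; exits
// when cancelled or when a fetch fails with a non-retryable error.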
void RunWorkerThread(std::function<void()> done) {
auto cleanup = gtl::MakeCleanup([done = std::move(done)]() {
done();
VLOG(1) << "Worker thread exiting";
});
VLOG(1) << "Starting worker thread";
std::shared_ptr<Task> task_to_process;
while (true) {
{
mutex_lock l(mu_);
if (task_to_process) {
task_to_process->in_use = false;
task_to_process = nullptr;
worker_thread_cv_.notify_one();
}
outstanding_requests_--;
while (!cancelled_ && !(SpaceInBuffer() && TaskAvailable())) {
if (VLOG_IS_ON(3)) {
VLOG(3) << "Sleeping with results_.size=" << results_.size()
<< ", outstanding_requests_=" << outstanding_requests_
<< ", max_oustanding_requests="
<< max_outstanding_requests_
<< " finished_tasks=" << finished_tasks_
<< " tasks_.size()=" << tasks_.size();
}
worker_thread_cv_.wait(l);
}
outstanding_requests_++;
if (cancelled_) {
return;
}
// Search for a task to update.
int num_tasks = tasks_.size();
for (int i = 0; i < num_tasks; ++i) {
int index = (next_task_index_ + i) % num_tasks;
std::shared_ptr<Task>& task = tasks_[index];
if (!task->in_use && !task->end_of_sequence) {
task->in_use = true;
task_to_process = task;
next_task_index_ = (index + 1) % num_tasks;
break;
}
}
DCHECK(task_to_process != nullptr);
VLOG(3) << "Processing task " << task_to_process->task_id;
}
int64 deadline_micros =
Env::Default()->NowMicros() + kRetryTimeoutMicros;
Status s = GetElement(task_to_process.get(), deadline_micros);
if (!s.ok()) {
mutex_lock l(mu_);
VLOG(1) << "Failed to get element for task "
<< task_to_process->task_id << ": " << s;
task_to_process->in_use = false;
status_ = s;
get_next_cv_.notify_all();
return;
}
}
}
// Gets an element from a task and adds the element to `results_`.
//
// If the task reaches end_of_sequence or is cancelled (e.g. due to a
// worker dying), GetElement returns Status::OK() without adding to
// `results_`.
Status GetElement(Task* task, int64 deadline_micros)
TF_LOCKS_EXCLUDED(mu_) {
VLOG(3) << "Getting an element for task id " << task->task_id;
tensorflow::profiler::TraceMe activity(
"GetDataServiceElement", tensorflow::profiler::TraceMeLevel::kInfo);
CompressedElement compressed;
bool end_of_sequence;
for (int num_retries = 0;; ++num_retries) {
Status s = task->worker->GetElement(task->task_id, &compressed,
&end_of_sequence);
if (s.ok()) {
break;
}
// Retry all errors that could indicate preemption.
if (!errors::IsUnavailable(s) && !errors::IsCancelled(s) &&
!errors::IsAborted(s)) {
return s;
}
{
mutex_lock l(mu_);
// Stop retrying if the task has already reached end_of_sequence or the
// iterator has been cancelled.
if (task->end_of_sequence || cancelled_) {
return Status::OK();
}
}
const int64 now_micros = EnvTime::NowMicros();
if (now_micros > deadline_micros) {
return s;
}
const int64 deadline_with_backoff_micros =
now_micros + ::tensorflow::ComputeBackoffMicroseconds(num_retries);
// Wait for a short period of time before retrying the RPC. If our
// backoff would put us past the RPC deadline, we truncate it to ensure
// our RPC starts before the deadline.
const auto backoff_until =
(deadline_micros > deadline_with_backoff_micros)
? deadline_with_backoff_micros
: deadline_micros;
Env::Default()->SleepForMicroseconds(backoff_until - now_micros);
}
std::vector<Tensor> element;
if (!end_of_sequence) {
Tensor tensor(DT_VARIANT, TensorShape{});
tensor.scalar<Variant>()() = std::move(compressed);
element.push_back(tensor);
}
mutex_lock l(mu_);
if (end_of_sequence) {
task->end_of_sequence = true;
finished_tasks_++;
return Status::OK();
}
results_.push(std::move(element));
get_next_cv_.notify_all();
VLOG(3) << "Got an element for task id " << task->task_id;
return Status::OK();
}
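// Returns true if fetching another element would keep the combined count of
// buffered results and in-flight requests below `max_outstanding_requests_`.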
bool SpaceInBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return results_.size() + outstanding_requests_ <
max_outstanding_requests_;
}
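// Returns true if some unfinished task is not currently claimed by another
// worker thread, i.e. there is a task this thread could process.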
bool TaskAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return finished_tasks_ + outstanding_requests_ < tasks_.size();
}
const int64 iterator_index_;
mutex mu_;
condition_variable get_next_cv_ TF_GUARDED_BY(mu_);
condition_variable worker_thread_cv_ TF_GUARDED_BY(mu_);
condition_variable manager_thread_cv_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;
// Method for deregistering the cancellation callback.
std::function<void()> deregister_fn_;
int64 outstanding_requests_ TF_GUARDED_BY(mu_) = 0;
// max_outstanding_requests_ controls how many elements may be held in memory
// at the same time. This count includes both in-progress requests for
// elements as well as completed results which haven't yet been consumed.
int64 max_outstanding_requests_ TF_GUARDED_BY(mu_);
// The number of threads in `worker_threads_` which are still running.
int64 num_running_worker_threads_ TF_GUARDED_BY(mu_) = 0;
// The index of the next task in `tasks_` to read from.
int64 next_task_index_ TF_GUARDED_BY(mu_) = 0;
// The number of tasks in the `tasks_` list that have reached end_of_sequence.
int64 finished_tasks_ TF_GUARDED_BY(mu_) = 0;
// List of tasks to read from.
std::vector<std::shared_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);
// A status to be returned from the next call to `GetNext`. This is set by
// asynchronous threads when they encounter errors.
Status status_ TF_GUARDED_BY(mu_) = Status::OK();
std::queue<std::vector<Tensor>> results_ TF_GUARDED_BY(mu_);
// Set once in Initialize().
int64 job_id_;
bool job_finished_ = false;
// Must be ordered second to last so that worker threads are joined before
// destroying other fields.
std::vector<std::unique_ptr<Thread>> worker_threads_ TF_GUARDED_BY(mu_);
// Must be ordered last so that the thread is joined before destroying other
// fields.
std::unique_ptr<Thread> task_thread_manager_ TF_GUARDED_BY(mu_);
};
const int64 dataset_id_;
const ProcessingMode processing_mode_;
const tstring address_;
const tstring protocol_;
const tstring job_name_;
const int64 max_outstanding_requests_;
const int64 task_refresh_interval_ms_;
IterationCounter* const iteration_counter_; // Owned
const bool owns_resource_;
const ResourceHandle iteration_counter_handle_;
ResourceMgr* const resource_mgr_; // Not owned
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx)
: DatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kTaskRefreshIntervalHintMs,
&task_refresh_interval_hint_ms_));
if (task_refresh_interval_hint_ms_ == model::kAutotune) {
task_refresh_interval_hint_ms_ = kDefaultTaskRefreshIntervalMs;
}
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}
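// Parses the op's inputs, then looks up the shared IterationCounter
// resource; if it does not exist yet, creates one that the dataset owns and
// deletes in its destructor.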
void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
int64 dataset_id;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id));
tstring processing_mode_str;
OP_REQUIRES_OK(
ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str));
ProcessingMode processing_mode;
OP_REQUIRES_OK(ctx,
ParseProcessingMode(processing_mode_str, &processing_mode));
tstring address;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
OP_REQUIRES(ctx, !address.empty(),
errors::InvalidArgument(kAddress, " must be non-empty."));
tstring protocol;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &protocol));
OP_REQUIRES(ctx, !protocol.empty(),
errors::InvalidArgument(kProtocol, " must be non-empty."));
tstring job_name;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kJobName, &job_name));
int64 max_outstanding_requests;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxOutstandingRequests,
&max_outstanding_requests));
ResourceHandle iteration_counter_handle;
OP_REQUIRES_OK(
ctx, HandleFromInput(ctx, kIterationCounter, &iteration_counter_handle));
IterationCounter* iteration_counter = nullptr;
Status s = ctx->resource_manager()->Lookup<IterationCounter>(
iteration_counter_handle.container(), iteration_counter_handle.name(),
&iteration_counter);
bool owns_resource = false;
if (errors::IsNotFound(s)) {
owns_resource = true;
static std::atomic<int64> resource_id_counter(0);
const std::string& container = ctx->resource_manager()->default_container();
std::string name =
strings::StrCat(ctx->op_kernel().name(), "/", kIterationCounter, "_",
resource_id_counter.fetch_add(1));
OP_REQUIRES_OK(ctx,
ctx->resource_manager()->LookupOrCreate<IterationCounter>(
container, name, &iteration_counter,
[](IterationCounter** counter) {
*counter = new IterationCounter();
return Status::OK();
}));
iteration_counter_handle =
MakeResourceHandle<IterationCounter>(ctx, container, name);
} else {
OP_REQUIRES_OK(ctx, s);
}
OP_REQUIRES(
ctx,
max_outstanding_requests == model::kAutotune ||
max_outstanding_requests > 0,
errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ",
model::kAutotune));
*output =
new Dataset(ctx, dataset_id, processing_mode, address, protocol, job_name,
max_outstanding_requests, task_refresh_interval_hint_ms_,
iteration_counter, owns_resource, iteration_counter_handle,
output_types_, output_shapes_);
}
REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU),
DataServiceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("DummyIterationCounter").Device(DEVICE_CPU),
DummyResourceOp<IterationCounter>);
} // namespace data
} // namespace tensorflow