blob: cbb53e19937c288c4f50b3e35096ad3148580f26 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
namespace data {
namespace experimental {
namespace {
class SleepDatasetOp : public UnaryDatasetOpKernel {
public:
using UnaryDatasetOpKernel::UnaryDatasetOpKernel;
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 sleep_microseconds;
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "sleep_microseconds",
&sleep_microseconds));
OP_REQUIRES(ctx, sleep_microseconds >= 0,
errors::InvalidArgument("`sleep_microseconds` must be >= 0"));
*output = new Dataset(ctx, input, sleep_microseconds);
}
private:
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
int64 sleep_microseconds)
: DatasetBase(DatasetContext(ctx)),
input_(input),
sleep_microseconds_(sleep_microseconds) {
input_->Ref();
}
~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(
Iterator::Params{this, strings::StrCat(prefix, "::Sleep")});
}
const DataTypeVector& output_dtypes() const override {
return input_->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return input_->output_shapes();
}
string DebugString() const override { return "SleepDatasetOp::Dataset"; }
int64 Cardinality() const override { return input_->Cardinality(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* sleep_microseconds = nullptr;
TF_RETURN_IF_ERROR(
b->AddScalar(sleep_microseconds_, &sleep_microseconds));
return b->AddDataset(this,
{{0, input_graph_node},
{1, sleep_microseconds}}, // Single tensor inputs.
{}, // Tensor list inputs.
{}, // Attrs
output);
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
RecordStop(ctx);
ctx->env()->SleepForMicroseconds(dataset()->sleep_microseconds_);
RecordStart(ctx);
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args),
/*ratio=*/1);
}
Status SaveInternal(IteratorStateWriter* writer) override {
return SaveInput(writer, input_impl_);
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return RestoreInput(ctx, reader, input_impl_);
}
private:
std::unique_ptr<IteratorBase> input_impl_;
};
const DatasetBase* const input_;
// TODO(b/117612213): Investigate autotuning for this value.
const int64 sleep_microseconds_;
};
};
REGISTER_KERNEL_BUILDER(Name("SleepDataset").Device(DEVICE_CPU),
SleepDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalSleepDataset").Device(DEVICE_CPU),
SleepDatasetOp);
REGISTER_KERNEL_BUILDER(Name("SleepDataset")
.Device(DEVICE_GPU)
.HostMemory("sleep_microseconds")
.HostMemory("input_dataset")
.HostMemory("handle"),
SleepDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalSleepDataset")
.Device(DEVICE_GPU)
.HostMemory("sleep_microseconds")
.HostMemory("input_dataset")
.HostMemory("handle"),
SleepDatasetOp);
} // namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow