/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
namespace data {
namespace experimental {
namespace {
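// Creates a dataset that interleaves elements from a list of input datasets,
// using a "selector" dataset of scalar int64 indices to choose which input
// supplies each element (used, e.g., by
// `tf.data.experimental.choose_from_datasets`).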
class DirectedInterleaveDatasetOp : public DatasetOpKernel {
public:
explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
: DatasetOpKernel(ctx) {}
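// Validates that the selector input yields scalar int64 values and that all
// data inputs produce the same output dtypes, then wraps them in a Dataset.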
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
DatasetBase* selector_input;
OP_REQUIRES_OK(ctx,
GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
OP_REQUIRES(
ctx,
selector_input->output_dtypes().size() == 1 &&
selector_input->output_dtypes()[0] == DT_INT64 &&
selector_input->output_shapes().size() == 1 &&
selector_input->output_shapes()[0].IsCompatibleWith(
PartialTensorShape({})),
errors::InvalidArgument(
"The selector input must be a dataset of scalar int64 elements."));
std::vector<DatasetBase*> data_inputs;
for (size_t i = 1; i < ctx->num_inputs(); ++i) {
DatasetBase* input;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
data_inputs.push_back(input);
OP_REQUIRES(
ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
errors::InvalidArgument(
"All inputs must have the same output_dtypes. First input "
"has types ",
DataTypeVectorString(data_inputs[0]->output_dtypes()),
", and input ", i - 1, " has types ",
DataTypeVectorString(input->output_dtypes())));
}
*output = new Dataset(ctx, selector_input, std::move(data_inputs));
}
private:
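// The dataset produced by this kernel. Holds a reference on the selector
// input and on every data input for the lifetime of the dataset.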
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
std::vector<DatasetBase*> data_inputs)
: DatasetBase(DatasetContext(ctx)),
selector_input_(selector_input),
data_inputs_(std::move(data_inputs)) {
selector_input_->Ref();
output_shapes_ = data_inputs_[0]->output_shapes();
data_inputs_[0]->Ref();
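// Merge the output shapes of all data inputs into the most specific shape
// that is compatible with every one of them.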
for (size_t i = 1; i < data_inputs_.size(); ++i) {
const DatasetBase* data_input = data_inputs_[i];
data_input->Ref();
for (size_t j = 0; j < output_shapes_.size(); ++j) {
output_shapes_[j] = MostSpecificCompatibleShape(
output_shapes_[j], data_input->output_shapes()[j]);
}
}
}
~Dataset() override {
selector_input_->Unref();
for (DatasetBase* data_input : data_inputs_) {
data_input->Unref();
}
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(Iterator::Params{
this, strings::StrCat(prefix, "::DirectedInterleave")});
}
const DataTypeVector& output_dtypes() const override {
return data_inputs_[0]->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override {
return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
}
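// Surfaces external state in any of the inputs that would prevent this
// dataset from being serialized.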
Status CheckExternalState() const override {
for (const auto& input : data_inputs_) {
TF_RETURN_IF_ERROR(input->CheckExternalState());
}
return selector_input_->CheckExternalState();
}
protected:
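// Serializes the dataset: the selector is added as the single dataset input
// and the data inputs are added as a list input (index 1).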
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* selector_input_node;
TF_RETURN_IF_ERROR(
b->AddInputDataset(ctx, selector_input_, &selector_input_node));
std::vector<Node*> data_input_nodes(data_inputs_.size());
for (size_t i = 0; i < data_inputs_.size(); ++i) {
TF_RETURN_IF_ERROR(
b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
}
TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
{{1, data_input_nodes}}, {}, output));
return Status::OK();
}
private:
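// Iterator that reads one index at a time from the selector iterator and
// forwards GetNext to the corresponding data input iterator.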
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
num_active_inputs_(params.dataset->data_inputs_.size()) {}
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
ctx, strings::StrCat(prefix(), ".selector"),
&selector_input_impl_));
data_input_impls_.resize(dataset()->data_inputs_.size());
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
const DatasetBase* data_input = dataset()->data_inputs_[i];
TF_RETURN_IF_ERROR(data_input->MakeIterator(
ctx, strings::StrCat(prefix(), "[", i, "]"),
&data_input_impls_[i]));
}
return Status::OK();
}
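// Repeatedly draws an index from the selector. If the selected input still
// has elements, its next element is returned; if it is already exhausted,
// the index is skipped and a new one is drawn. The sequence ends when the
// selector ends or when all data inputs are exhausted.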
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
if (!selector_input_impl_) {
*end_of_sequence = true;
return Status::OK();
}
while (true) {
std::vector<Tensor> selector_result;
*end_of_sequence = false;
TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
ctx, &selector_result, end_of_sequence));
if (*end_of_sequence) {
selector_input_impl_.reset();
for (auto& data_input_impl : data_input_impls_) {
data_input_impl.reset();
}
return Status::OK();
}
int64 selected_input = selector_result[0].scalar<int64>()();
if (selected_input < 0 || selected_input >= data_input_impls_.size()) {
return errors::InvalidArgument(
"Selector index out of range: ", selected_input,
" >= ", data_input_impls_.size());
}
if (data_input_impls_[selected_input]) {
bool end_of_selected_input = false;
TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
ctx, out_tensors, &end_of_selected_input));
if (!end_of_selected_input) {
return Status::OK();
}
data_input_impls_[selected_input].reset();
--num_active_inputs_;
if (num_active_inputs_ == 0) {
selector_input_impl_.reset();
*end_of_sequence = true;
return Status::OK();
}
}
VLOG(2) << "DirectedInterleave selected an exhausted input: "
<< selected_input;
}
}
protected:
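// Models this iterator as an interleave-many node for the tf.data
// performance model.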
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeInterleaveManyNode(std::move(args));
}
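// Checkpoints the selector iterator and each data input iterator, writing
// an explicit "empty" marker for any iterator that is already exhausted.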
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
if (selector_input_impl_) {
TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_));
} else {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
const auto& data_input_impl = data_input_impls_[i];
if (data_input_impl) {
TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl));
} else {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
""));
}
}
return Status::OK();
}
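// Restores the iterators written by SaveInternal, leaving any iterator that
// was recorded as exhausted in its reset state.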
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
if (!reader->Contains(full_name("selector_input_impl_empty"))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
} else {
selector_input_impl_.reset();
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
if (!reader->Contains(full_name(
strings::StrCat("data_input_impl_empty[", i, "]")))) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
} else {
data_input_impls_[i].reset();
}
}
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> selector_input_impl_ GUARDED_BY(mu_);
std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
GUARDED_BY(mu_);
int64 num_active_inputs_ GUARDED_BY(mu_);
};
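// Returns the most specific shape compatible with both `ts1` and `ts2`: an
// unknown-rank shape if the ranks differ or are unknown, otherwise a shape
// whose dimensions are kept where the inputs agree and set to unknown (-1)
// where they differ.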
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
return output_tensorshape;
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); d++) {
if (dims1[d] == dims2[d])
output_tensorshape.Concatenate(dims1[d]);
else
output_tensorshape.Concatenate(-1);
}
return output_tensorshape;
}
const DatasetBase* const selector_input_;
const std::vector<DatasetBase*> data_inputs_;
std::vector<PartialTensorShape> output_shapes_;
};
};
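// Register the kernel under both the current op name and the legacy
// "Experimental" op name so graphs using either resolve to the same kernel.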
REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
} // namespace
} // namespace experimental
} // namespace data
} // namespace tensorflow