// blob: 8b417bb1c2804a64b75b25843172a7f8ef9904b4 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/window_dataset.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
public:
// Reads the kernel's attributes at construction time: the three function
// attrs ("key_func", "reduce_func", "window_size_func") plus "output_types"
// and "output_shapes". Each GetAttr status is checked with OP_REQUIRES_OK.
// Also records the graph def version reported by `ctx`.
explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size_func", &window_size_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
// Builds the GroupByWindow dataset on top of `input` and returns it through
// `*output`.
//
// Each of the three function attrs is wrapped in a CapturedFunction. The
// string passed to CapturedFunction::Create presumably names the op input
// list that holds that function's captured arguments (verify against
// CapturedFunction::Create). The wrappers are created in the order key,
// reduce, window-size, and ownership of all three is handed to the new
// Dataset.
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                 DatasetBase** output) override {
  std::unique_ptr<CapturedFunction> key_fn;
  std::unique_ptr<CapturedFunction> reduce_fn;
  std::unique_ptr<CapturedFunction> window_size_fn;

  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(
                     key_func_, ctx, "key_func_other_arguments", &key_fn));
  OP_REQUIRES_OK(
      ctx, CapturedFunction::Create(reduce_func_, ctx,
                                    "reduce_func_other_arguments", &reduce_fn));
  OP_REQUIRES_OK(ctx, CapturedFunction::Create(
                          window_size_func_, ctx,
                          "window_size_func_other_arguments", &window_size_fn));

  Dataset* const dataset = new Dataset(
      ctx, input, key_func_, reduce_func_, window_size_func_,
      std::move(key_fn), std::move(reduce_fn), std::move(window_size_fn),
      output_types_, output_shapes_);
  *output = dataset;
}
private:
class Dataset : public DatasetBase {
public:
// Stores the input dataset, the three function attrs, their CapturedFunction
// wrappers (ownership is moved in) and the declared output types/shapes.
// Takes a reference on `input` via Ref(); the destructor calls the matching
// Unref().
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& key_func, const NameAttrList& reduce_func,
const NameAttrList& window_size_func,
std::unique_ptr<CapturedFunction> captured_key_func,
std::unique_ptr<CapturedFunction> captured_reduce_func,
std::unique_ptr<CapturedFunction> captured_window_size_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
input_(input),
key_func_(key_func),
reduce_func_(reduce_func),
window_size_func_(window_size_func),
captured_key_func_(std::move(captured_key_func)),
captured_reduce_func_(std::move(captured_reduce_func)),
captured_window_size_func_(std::move(captured_window_size_func)),
output_types_(output_types),
output_shapes_(output_shapes) {
input_->Ref();
}
// Drops the reference on the input dataset that the constructor took.
~Dataset() override { input_->Unref(); }
// Creates a new iterator over this dataset. The iterator's prefix is
// `prefix` with "::GroupByWindow" appended.
std::unique_ptr<IteratorBase> MakeIteratorInternal(
    const string& prefix) const override {
  const string iterator_prefix = strings::StrCat(prefix, "::GroupByWindow");
  std::unique_ptr<IteratorBase> iterator(
      new Iterator({this, iterator_prefix}));
  return iterator;
}
// Returns the output types this dataset was constructed with.
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
// Returns the output shapes this dataset was constructed with.
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
// Returns a fixed string identifying this dataset type.
string DebugString() const override {
return "GroupByWindowDatasetOp::Dataset";
}
protected:
// Serializes this dataset into the graph being built by `b`.
//
// Steps, in order: register the three functions, add the input dataset, turn
// each function's captured inputs into graph nodes, then add the dataset node
// itself. On that node, input 0 is the input dataset and list inputs 1, 2 and
// 3 carry the captured arguments of key_func, reduce_func and
// window_size_func respectively; the function attrs and the dtype lists of
// the captured arguments are attached by name. The new node is returned
// through `*output`. The first failing step's status is returned.
Status AsGraphDefInternal(SerializationContext* ctx,
                          DatasetGraphDefBuilder* b,
                          Node** output) const override {
  TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func_.name()));
  TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func_.name()));
  TF_RETURN_IF_ERROR(b->AddFunction(ctx, window_size_func_.name()));

  Node* input_node = nullptr;
  TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));

  // Captured inputs of each function, as graph nodes plus their dtypes.
  std::vector<Node*> key_func_inputs;
  DataTypeVector key_func_input_types;
  TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
      b, captured_key_func_, &key_func_inputs, &key_func_input_types));

  std::vector<Node*> reduce_func_inputs;
  DataTypeVector reduce_func_input_types;
  TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
      b, captured_reduce_func_, &reduce_func_inputs,
      &reduce_func_input_types));

  std::vector<Node*> window_size_func_inputs;
  DataTypeVector window_size_func_input_types;
  TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
      b, captured_window_size_func_, &window_size_func_inputs,
      &window_size_func_input_types));

  // Function attrs.
  AttrValue key_func_attr;
  AttrValue reduce_func_attr;
  AttrValue window_size_func_attr;
  b->BuildAttrValue(key_func_, &key_func_attr);
  b->BuildAttrValue(reduce_func_, &reduce_func_attr);
  b->BuildAttrValue(window_size_func_, &window_size_func_attr);

  // Dtype-list attrs for the captured arguments.
  AttrValue key_func_input_types_attr;
  AttrValue reduce_func_input_types_attr;
  AttrValue window_size_func_input_types_attr;
  b->BuildAttrValue(key_func_input_types, &key_func_input_types_attr);
  b->BuildAttrValue(reduce_func_input_types, &reduce_func_input_types_attr);
  b->BuildAttrValue(window_size_func_input_types,
                    &window_size_func_input_types_attr);

  return b->AddDataset(
      this, {{0, input_node}},
      {{1, key_func_inputs},
       {2, reduce_func_inputs},
       {3, window_size_func_inputs}},
      {{"key_func", key_func_attr},
       {"reduce_func", reduce_func_attr},
       {"window_size_func", window_size_func_attr},
       {"Tkey_func_other_arguments", key_func_input_types_attr},
       {"Treduce_func_other_arguments", reduce_func_input_types_attr},
       {"Twindow_size_func_other_arguments",
        window_size_func_input_types_attr}},
      output);
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
// Forwards `params` to the DatasetIterator<Dataset> base; all other state
// uses its in-class defaults.
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
// Creates the iterator over the input dataset, then instantiates the key,
// reduce and window-size captured functions in that order. Returns the first
// non-OK status encountered.
Status Initialize(IteratorContext* ctx) override {
  const auto* ds = dataset();
  TF_RETURN_IF_ERROR(ds->input_->MakeIterator(ctx, prefix(), &input_impl_));
  TF_RETURN_IF_ERROR(ds->captured_key_func_->Instantiate(ctx));
  TF_RETURN_IF_ERROR(ds->captured_reduce_func_->Instantiate(ctx));
  return ds->captured_window_size_func_->Instantiate(ctx);
}
// Produces the next output element.
//
// While a group is being flushed, elements come from the iterator over the
// dataset that `reduce_func` returned for that group. Otherwise input
// elements are pulled one at a time, keyed with `key_func`, and buffered per
// key until some key has collected as many elements as `window_size_func`
// returned for it; that group is then flushed. The window size is computed
// once per key and cached. After the input is exhausted, the remaining
// (partial) groups are flushed one by one.
//
// Sets `*end_of_sequence` to false when an element was written to
// `out_tensors`, and to true once the input is exhausted and no groups
// remain.
//
// Returns InvalidArgument if `key_func` or `window_size_func` does not return
// exactly one single-element int64 tensor, or if the window size is not
// positive. Errors from the input iterator, the captured functions and
// StartFlushingGroup are propagated.
Status GetNextInternal(IteratorContext* ctx,
                       std::vector<Tensor>* out_tensors,
                       bool* end_of_sequence) override {
  mutex_lock l(mu_);
  do {
    if (current_group_iterator_) {
      // We are currently processing a group, so try to get the next element.
      bool end_of_group;
      TF_RETURN_IF_ERROR(
          current_group_iterator_->GetNext(ctx, out_tensors, &end_of_group));
      if (!end_of_group) {
        // Produce the subelement as output.
        *end_of_sequence = false;
        return Status::OK();
      }
      // The current group is exhausted: drop its iterator and its buffered
      // elements, then look for the next group.
      current_group_iterator_.reset();
      groups_.erase(current_key_);
    }

    // Iterate through the input dataset until we get a full group, or reach
    // the end.
    while (!end_of_input_) {
      std::vector<Tensor> next_input_element;
      TF_RETURN_IF_ERROR(
          input_impl_->GetNext(ctx, &next_input_element, &end_of_input_));
      if (end_of_input_) {
        break;
      }

      // Run the key function on the input element to identify its group.
      std::vector<Tensor> key_func_output;
      TF_RETURN_IF_ERROR(dataset()->captured_key_func_->RunWithBorrowedArgs(
          ctx, next_input_element, &key_func_output));
      if (key_func_output.size() != 1 ||
          key_func_output[0].dtype() != DT_INT64 ||
          key_func_output[0].NumElements() != 1) {
        // TODO(b/78665031): Support non-int64 keys.
        return errors::InvalidArgument(
            "`key_func` must return a scalar int64.");
      }
      const int64 key = key_func_output[0].scalar<int64>()();

      // Look the window size up once; only run `window_size_func` the first
      // time a key is seen.
      auto window_size_it = window_sizes_.find(key);
      if (window_size_it == window_sizes_.end()) {
        std::vector<Tensor> window_size_func_output;
        TF_RETURN_IF_ERROR(dataset()->captured_window_size_func_->Run(
            ctx, std::move(key_func_output), &window_size_func_output));
        if (window_size_func_output.size() != 1 ||
            window_size_func_output[0].dtype() != DT_INT64 ||
            window_size_func_output[0].NumElements() != 1) {
          // TODO(mrry): Support non-int64 window sizes.
          return errors::InvalidArgument(
              "`window_size_func` must return a scalar int64.");
        }
        const int64 new_window_size =
            window_size_func_output[0].scalar<int64>()();
        if (new_window_size <= 0) {
          return errors::InvalidArgument(
              "Window size must be greater than zero, but got ",
              new_window_size, ".");
        }
        window_size_it = window_sizes_.emplace(key, new_window_size).first;
      }
      const int64 window_size = window_size_it->second;

      std::vector<std::vector<Tensor>>& group = groups_[key];
      group.push_back(std::move(next_input_element));
      // Explicit cast keeps the comparison signed; window sizes computed
      // above are always positive.
      if (static_cast<int64>(group.size()) == window_size) {
        current_key_ = key;
        TF_RETURN_IF_ERROR(StartFlushingGroup(ctx, key));
        break;
      }
    }

    if (end_of_input_ && !groups_.empty()) {
      // We have consumed all of the input, so flush the first remaining
      // group in the map's iteration order.
      current_key_ = groups_.begin()->first;
      TF_RETURN_IF_ERROR(StartFlushingGroup(ctx, current_key_));
    }
  } while (current_group_iterator_ || !end_of_input_);

  *end_of_sequence = true;
  return Status::OK();
}
protected:
// Writes the iterator's state to `writer`.
//
// Saved, in order: the input iterator; an "end_of_input" marker if the input
// is exhausted; every buffered group with its key (preceded by
// "groups_size", omitted when there are no groups); every cached window size
// with its key (preceded by "window_sizes_size", omitted when the cache is
// empty); and finally either the in-flight group iterator together with
// "current_key", or a "current_iterator_not_initialized" marker when no group
// is being flushed.
Status SaveInternal(IteratorStateWriter* writer) override {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));

  if (end_of_input_) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("end_of_input"), ""));
  }

  // Buffered groups, numbered in map iteration order.
  if (!groups_.empty()) {
    TF_RETURN_IF_ERROR(
        writer->WriteScalar(full_name("groups_size"), groups_.size()));
    int idx = 0;
    for (const auto& entry : groups_) {
      const string slot = strings::StrCat("groups_[", idx, "]");
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(strings::StrCat(slot, "->key")), entry.first));
      TF_RETURN_IF_ERROR(SaveGroup(writer, full_name(slot), entry.second));
      ++idx;
    }
  }

  // Cached window sizes, numbered in map iteration order.
  if (!window_sizes_.empty()) {
    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("window_sizes_size"),
                                           window_sizes_.size()));
    int idx = 0;
    for (const auto& entry : window_sizes_) {
      const string slot = strings::StrCat("window_sizes_[", idx, "]");
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(strings::StrCat(slot, "->key")), entry.first));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(strings::StrCat(slot, "->value")), entry.second));
      ++idx;
    }
  }

  if (!current_group_iterator_) {
    return writer->WriteScalar(full_name("current_iterator_not_initialized"),
                               "");
  }
  // A group is being flushed: save its iterator and the key it belongs to.
  TF_RETURN_IF_ERROR(SaveInput(writer, current_group_iterator_));
  return writer->WriteScalar(full_name("current_key"), current_key_);
}
// Restores the iterator's state from `reader`; the inverse of SaveInternal.
//
// After the input iterator has been restored, all locally held state
// (`end_of_input_`, `groups_`, `window_sizes_`) is reset before being
// repopulated, so the result depends only on the checkpoint and not on
// whatever the iterator had buffered beforehand.
//
// If the checkpoint recorded an in-flight group (the
// "current_iterator_not_initialized" marker is absent), the group iterator is
// rebuilt by calling StartFlushingGroup for the saved key, and its own saved
// state is then restored on top.
//
// Returns the first non-OK status from the reader, RestoreInput,
// RestoreGroup or StartFlushingGroup.
Status RestoreInternal(IteratorContext* ctx,
                       IteratorStateReader* reader) override {
  mutex_lock l(mu_);
  TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));

  // Discard any pre-existing state so it cannot leak into the restored one.
  end_of_input_ = reader->Contains(full_name("end_of_input"));
  groups_.clear();
  window_sizes_.clear();

  // Restoring groups
  if (reader->Contains(full_name("groups_size"))) {
    int64 size;
    TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("groups_size"), &size));
    for (int64 idx = 0; idx < size; ++idx) {
      int64 key;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(strings::StrCat("groups_[", idx, "]->key")), &key));
      std::vector<std::vector<Tensor>> group;
      TF_RETURN_IF_ERROR(RestoreGroup(
          reader, full_name(strings::StrCat("groups_[", idx, "]")), &group));
      // Move rather than copy: `group` is not used again.
      groups_[key] = std::move(group);
    }
  }

  // Restoring window sizes
  if (reader->Contains(full_name("window_sizes_size"))) {
    int64 size;
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name("window_sizes_size"), &size));
    for (int64 idx = 0; idx < size; ++idx) {
      int64 key;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(strings::StrCat("window_sizes_[", idx, "]->key")), &key));
      // Read into a local first so a failed read does not leave a
      // default-initialized entry in the cache.
      int64 window_size;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(strings::StrCat("window_sizes_[", idx, "]->value")),
          &window_size));
      window_sizes_[key] = window_size;
    }
  }

  if (reader->Contains(full_name("current_iterator_not_initialized"))) {
    current_group_iterator_.reset();
  } else {
    // Restore current_key_
    TF_RETURN_IF_ERROR(
        reader->ReadScalar(full_name("current_key"), &current_key_));
    // Initialize current_group_iterator_
    TF_RETURN_IF_ERROR(StartFlushingGroup(ctx, current_key_));
    // Restore current_group_iterator_ state
    TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_group_iterator_));
  }
  return Status::OK();
}
private:
// Writes one buffered group under the key prefix `name`.
//
// Layout: "<name>_size" holds the number of elements in the group,
// "<name>[i]_size" the number of tensors in element i, and "<name>[i][j]"
// tensor j of element i.
Status SaveGroup(IteratorStateWriter* writer, const string& name,
                 const std::vector<std::vector<Tensor>>& group)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  const int64 num_elements = group.size();
  TF_RETURN_IF_ERROR(
      writer->WriteScalar(strings::StrCat(name, "_size"), num_elements));
  for (int64 i = 0; i < num_elements; ++i) {
    const std::vector<Tensor>& element = group[i];
    const int64 num_components = element.size();
    TF_RETURN_IF_ERROR(writer->WriteScalar(
        strings::StrCat(name, "[", i, "]_size"), num_components));
    for (int64 j = 0; j < num_components; ++j) {
      TF_RETURN_IF_ERROR(writer->WriteTensor(
          strings::StrCat(name, "[", i, "][", j, "]"), element[j]));
    }
  }
  return Status::OK();
}
// Reads one buffered group stored under the key prefix `name` into `*group`;
// the inverse of SaveGroup.
//
// `*group` is resized to the stored element count and each element to its
// stored tensor count before the tensors are read into place.
Status RestoreGroup(IteratorStateReader* reader, const string& name,
                    std::vector<std::vector<Tensor>>* group)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64 num_elements;
  TF_RETURN_IF_ERROR(
      reader->ReadScalar(strings::StrCat(name, "_size"), &num_elements));
  group->resize(num_elements);
  for (int64 i = 0; i < num_elements; ++i) {
    int64 num_components;
    TF_RETURN_IF_ERROR(reader->ReadScalar(
        strings::StrCat(name, "[", i, "]_size"), &num_components));
    std::vector<Tensor>& element = (*group)[i];
    element.resize(num_components);
    for (int64 j = 0; j < num_components; ++j) {
      TF_RETURN_IF_ERROR(reader->ReadTensor(
          strings::StrCat(name, "[", i, "][", j, "]"), &element[j]));
    }
  }
  return Status::OK();
}
// Starts emitting the group buffered under `key`.
//
// Wraps the buffered elements in a window dataset, calls `reduce_func` with
// (key, window dataset), and points `current_group_iterator_` at an iterator
// over the dataset that `reduce_func` returns. The buffered elements stay in
// `groups_`; this function does not remove them.
//
// Returns InvalidArgument unless `reduce_func` returns exactly one scalar
// DT_VARIANT tensor; other errors are propagated from the helpers called.
Status StartFlushingGroup(IteratorContext* ctx, int64 key)
    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  // Window dataset over the elements collected for this key.
  DatasetBase* window_dataset;
  TF_RETURN_IF_ERROR(NewWindowDataset(
      groups_[key], dataset()->input_->output_dtypes(),
      dataset()->input_->output_shapes(), &window_dataset));

  // Arguments for `reduce_func`: the scalar key and the window dataset
  // stored in a scalar variant tensor.
  Tensor key_tensor(DT_INT64, TensorShape({}));
  key_tensor.scalar<int64>()() = key;

  Tensor window_tensor(DT_VARIANT, TensorShape({}));
  TF_RETURN_IF_ERROR(
      StoreDatasetInVariantTensor(window_dataset, &window_tensor));

  std::vector<Tensor> reduce_args(
      {std::move(key_tensor), std::move(window_tensor)});
  std::vector<Tensor> reduce_outputs;
  TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Run(
      ctx, std::move(reduce_args), &reduce_outputs));

  if (reduce_outputs.size() != 1 ||
      reduce_outputs[0].dtype() != DT_VARIANT ||
      !TensorShapeUtils::IsScalar(reduce_outputs[0].shape())) {
    return errors::InvalidArgument(
        "`reduce_func` must return a single scalar of dtype "
        "DT_VARIANT.");
  }

  // The dataset produced by `reduce_func`; the pointer is borrowed from
  // `reduce_outputs[0]`.
  DatasetBase* reduced_dataset;
  TF_RETURN_IF_ERROR(
      GetDatasetFromVariantTensor(reduce_outputs[0], &reduced_dataset));

  // Iterate over that dataset from now on.
  return reduced_dataset->MakeIterator(ctx, prefix(),
                                       &current_group_iterator_);
}
// Guards every data member declared below.
mutex mu_;
// Iterator over the input dataset; created in Initialize().
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
// TODO(mrry): Optimize for dense key space if appropriate.
// NOTE(review): the TODO above presumably concerns the key-indexed maps
// (`groups_`, `window_sizes_`) rather than `end_of_input_` -- confirm.
// True once the input iterator has reported end of sequence.
bool end_of_input_ GUARDED_BY(mu_) = false;
// Key of the group being flushed. It has no initializer and is only assigned
// right before StartFlushingGroup() or when restored from a checkpoint.
int64 current_key_ GUARDED_BY(mu_);
// Input elements buffered so far for each key, in arrival order.
std::map<int64, std::vector<std::vector<Tensor>>> groups_ GUARDED_BY(mu_);
// Iterator over the dataset `reduce_func` returned for the group being
// flushed; null when no group is being flushed.
std::unique_ptr<IteratorBase> current_group_iterator_ GUARDED_BY(mu_);
// Window size per key, as returned by `window_size_func` the first time the
// key was seen.
std::map<int64, int64> window_sizes_ GUARDED_BY(mu_);
};
// Converts the captured inputs of `captured_func` into graph nodes.
//
// For every captured tensor, a node is added to `b` and appended to
// `*other_arguments_node`, and the tensor's dtype is appended to
// `*other_arguments_types`, so the two outputs stay index-aligned. Returns
// the first error from AddTensor; entries appended before the failure are
// left in place.
Status OtherArgumentsNodeAndType(
    DatasetGraphDefBuilder* b,
    const std::unique_ptr<CapturedFunction>& captured_func,
    std::vector<Node*>* other_arguments_node,
    DataTypeVector* other_arguments_types) const {
  const auto& captured_inputs = captured_func->captured_inputs();
  const auto num_inputs = captured_inputs.size();
  other_arguments_node->reserve(num_inputs);
  other_arguments_types->reserve(num_inputs);
  for (const Tensor& input : captured_inputs) {
    Node* tensor_node;
    TF_RETURN_IF_ERROR(b->AddTensor(input, &tensor_node));
    other_arguments_node->push_back(tensor_node);
    other_arguments_types->push_back(input.dtype());
  }
  return Status::OK();
}
// Input dataset; a reference is held from construction to destruction.
const DatasetBase* const input_;
// Function attrs as received from the op; used when serializing the dataset.
const NameAttrList key_func_;
const NameAttrList reduce_func_;
const NameAttrList window_size_func_;
// Wrappers, owned by this dataset, through which the iterator calls the
// three functions.
const std::unique_ptr<CapturedFunction> captured_key_func_;
const std::unique_ptr<CapturedFunction> captured_reduce_func_;
const std::unique_ptr<CapturedFunction> captured_window_size_func_;
// Declared output signature, returned by output_dtypes()/output_shapes().
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
};
// Graph def version captured from the construction context. It is not read
// anywhere else in the code visible in this file.
const int graph_def_version_;
// Values of the "output_types" / "output_shapes" attrs.
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
// Values of the "key_func" / "reduce_func" / "window_size_func" attrs.
NameAttrList key_func_;
NameAttrList reduce_func_;
NameAttrList window_size_func_;
};
// Registers GroupByWindowDatasetOp as the CPU kernel for the
// "GroupByWindowDataset" op.
REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU),
GroupByWindowDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow