/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
| #include "tensorflow/core/framework/dataset.h" |
| |
| #include "rdkafkacpp.h" |
| |
| namespace tensorflow { |
| |
| class KafkaDatasetOp : public DatasetOpKernel { |
| public: |
| using DatasetOpKernel::DatasetOpKernel; |
| |
| void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { |
| const Tensor* topics_tensor; |
| OP_REQUIRES_OK(ctx, ctx->input("topics", &topics_tensor)); |
| OP_REQUIRES( |
| ctx, topics_tensor->dims() <= 1, |
| errors::InvalidArgument("`topics` must be a scalar or a vector.")); |
| |
| std::vector<string> topics; |
| topics.reserve(topics_tensor->NumElements()); |
| for (int i = 0; i < topics_tensor->NumElements(); ++i) { |
| topics.push_back(topics_tensor->flat<tstring>()(i)); |
| } |
| |
| std::string servers = ""; |
| OP_REQUIRES_OK( |
| ctx, data::ParseScalarArgument<std::string>(ctx, "servers", &servers)); |
| std::string group = ""; |
| OP_REQUIRES_OK( |
| ctx, data::ParseScalarArgument<std::string>(ctx, "group", &group)); |
| bool eof = false; |
| OP_REQUIRES_OK(ctx, data::ParseScalarArgument<bool>(ctx, "eof", &eof)); |
| int64 timeout = -1; |
| OP_REQUIRES_OK(ctx, |
| data::ParseScalarArgument<int64>(ctx, "timeout", &timeout)); |
| OP_REQUIRES(ctx, (timeout > 0), |
| errors::InvalidArgument( |
| "Timeout value should be large than 0, got ", timeout)); |
| *output = new Dataset(ctx, std::move(topics), servers, group, eof, timeout); |
| } |
| |
| private: |
| class Dataset : public DatasetBase { |
| public: |
| Dataset(OpKernelContext* ctx, std::vector<string> topics, |
| const string& servers, const string& group, const bool eof, |
| const int64 timeout) |
| : DatasetBase(DatasetContext(ctx)), |
| topics_(std::move(topics)), |
| servers_(servers), |
| group_(group), |
| eof_(eof), |
| timeout_(timeout) {} |
| |
| std::unique_ptr<IteratorBase> MakeIteratorInternal( |
| const string& prefix) const override { |
| return std::unique_ptr<IteratorBase>( |
| new Iterator({this, strings::StrCat(prefix, "::Kafka")})); |
| } |
| |
| const DataTypeVector& output_dtypes() const override { |
| static DataTypeVector* dtypes = new DataTypeVector({DT_STRING}); |
| return *dtypes; |
| } |
| |
| const std::vector<PartialTensorShape>& output_shapes() const override { |
| static std::vector<PartialTensorShape>* shapes = |
| new std::vector<PartialTensorShape>({{}}); |
| return *shapes; |
| } |
| |
| string DebugString() const override { return "KafkaDatasetOp::Dataset"; } |
| |
| protected: |
| Status AsGraphDefInternal(SerializationContext* ctx, |
| DatasetGraphDefBuilder* b, |
| Node** output) const override { |
| Node* topics = nullptr; |
| TF_RETURN_IF_ERROR(b->AddVector(topics_, &topics)); |
| Node* servers = nullptr; |
| TF_RETURN_IF_ERROR(b->AddScalar(servers_, &servers)); |
| Node* group = nullptr; |
| TF_RETURN_IF_ERROR(b->AddScalar(group_, &group)); |
| Node* eof = nullptr; |
| TF_RETURN_IF_ERROR(b->AddScalar(eof_, &eof)); |
| Node* timeout = nullptr; |
| TF_RETURN_IF_ERROR(b->AddScalar(timeout_, &timeout)); |
| TF_RETURN_IF_ERROR( |
| b->AddDataset(this, {topics, servers, group, eof, timeout}, output)); |
| return Status::OK(); |
| } |
| |
| private: |
| class Iterator : public DatasetIterator<Dataset> { |
| public: |
| explicit Iterator(const Params& params) |
| : DatasetIterator<Dataset>(params) {} |
| |
| Status GetNextInternal(IteratorContext* ctx, |
| std::vector<Tensor>* out_tensors, |
| bool* end_of_sequence) override { |
| mutex_lock l(mu_); |
| do { |
| // We are currently processing a topic, so try to read the next line. |
| if (consumer_.get()) { |
| while (true) { |
| if (limit_ >= 0 && |
| (topic_partition_->offset() >= limit_ || offset_ >= limit_)) { |
| // EOF current topic |
| break; |
| } |
| std::unique_ptr<RdKafka::Message> message( |
| consumer_->consume(dataset()->timeout_)); |
| if (message->err() == RdKafka::ERR_NO_ERROR) { |
| // Produce the line as output. |
| Tensor line_tensor(cpu_allocator(), DT_STRING, {}); |
| line_tensor.scalar<tstring>()().assign( |
| static_cast<const char*>(message->payload()), |
| message->len()); |
| out_tensors->emplace_back(std::move(line_tensor)); |
| *end_of_sequence = false; |
| // Sync offset |
| offset_ = message->offset(); |
| return Status::OK(); |
| } |
| |
| if (message->err() == RdKafka::ERR__PARTITION_EOF && |
| dataset()->eof_) { |
| // EOF current topic |
| break; |
| } |
| if (message->err() != RdKafka::ERR__TIMED_OUT) { |
| return errors::Internal("Failed to consume:", |
| message->errstr()); |
| } |
| message.reset(nullptr); |
| consumer_->poll(0); |
| } |
| |
| // We have reached the end of the current topic, so maybe |
| // move on to next topic. |
| ResetStreamsLocked(); |
| ++current_topic_index_; |
| } |
| |
| // Iteration ends when there are no more topic to process. |
| if (current_topic_index_ == dataset()->topics_.size()) { |
| *end_of_sequence = true; |
| return Status::OK(); |
| } |
| |
| TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); |
| } while (true); |
| } |
| |
| protected: |
| Status SaveInternal(IteratorStateWriter* writer) override { |
| mutex_lock l(mu_); |
| TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_topic_index"), |
| current_topic_index_)); |
| |
| // `consumer_` is empty if |
| // 1. GetNext has not been called even once. |
| // 2. All topics have been read and iterator has been exhausted. |
| if (consumer_.get()) { |
| TF_RETURN_IF_ERROR( |
| writer->WriteScalar(full_name("current_pos"), offset_)); |
| } |
| return Status::OK(); |
| } |
| |
| Status RestoreInternal(IteratorContext* ctx, |
| IteratorStateReader* reader) override { |
| mutex_lock l(mu_); |
| ResetStreamsLocked(); |
| int64 current_topic_index; |
| TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_topic_index"), |
| ¤t_topic_index)); |
| current_topic_index_ = size_t(current_topic_index); |
| // The key "current_pos" is written only if the iterator was saved |
| // with an open topic. |
| if (reader->Contains(full_name("current_pos"))) { |
| int64 current_pos; |
| TF_RETURN_IF_ERROR( |
| reader->ReadScalar(full_name("current_pos"), ¤t_pos)); |
| |
| TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); |
| topic_partition_->set_offset(current_pos); |
| if (topic_partition_->offset() != current_pos) { |
| return errors::Internal("Failed to restore to offset ", |
| current_pos); |
| } |
| offset_ = current_pos; |
| } |
| return Status::OK(); |
| } |
| |
| private: |
| // Sets up Kafka streams to read from the topic at |
| // `current_topic_index_`. |
| Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| if (current_topic_index_ >= dataset()->topics_.size()) { |
| return errors::InvalidArgument( |
| "current_topic_index_:", current_topic_index_, |
| " >= topics_.size():", dataset()->topics_.size()); |
| } |
| |
| // Actually move on to next topic. |
| string entry = dataset()->topics_[current_topic_index_]; |
| |
| std::vector<string> parts = str_util::Split(entry, ":"); |
| if (parts.size() < 1) { |
| return errors::InvalidArgument("Invalid parameters: ", entry); |
| } |
| string topic = parts[0]; |
| int32 partition = 0; |
| if (parts.size() > 1) { |
| if (!strings::safe_strto32(parts[1], &partition)) { |
| return errors::InvalidArgument("Invalid parameters: ", entry); |
| } |
| } |
| int64 offset = 0; |
| if (parts.size() > 2) { |
| if (!strings::safe_strto64(parts[2], &offset)) { |
| return errors::InvalidArgument("Invalid parameters: ", entry); |
| } |
| } |
| |
| topic_partition_.reset( |
| RdKafka::TopicPartition::create(topic, partition, offset)); |
| |
| offset_ = topic_partition_->offset(); |
| limit_ = -1; |
| if (parts.size() > 3) { |
| if (!strings::safe_strto64(parts[3], &limit_)) { |
| return errors::InvalidArgument("Invalid parameters: ", entry); |
| } |
| } |
| |
| std::unique_ptr<RdKafka::Conf> conf( |
| RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL)); |
| std::unique_ptr<RdKafka::Conf> topic_conf( |
| RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC)); |
| |
| std::string errstr; |
| |
| RdKafka::Conf::ConfResult result = |
| conf->set("default_topic_conf", topic_conf.get(), errstr); |
| if (result != RdKafka::Conf::CONF_OK) { |
| return errors::Internal("Failed to set default_topic_conf:", errstr); |
| } |
| |
| result = conf->set("bootstrap.servers", dataset()->servers_, errstr); |
| if (result != RdKafka::Conf::CONF_OK) { |
| return errors::Internal("Failed to set bootstrap.servers ", |
| dataset()->servers_, ":", errstr); |
| } |
| result = conf->set("group.id", dataset()->group_, errstr); |
| if (result != RdKafka::Conf::CONF_OK) { |
| return errors::Internal("Failed to set group.id ", dataset()->group_, |
| ":", errstr); |
| } |
| |
| consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr)); |
| if (!consumer_.get()) { |
| return errors::Internal("Failed to create consumer:", errstr); |
| } |
| |
| std::vector<RdKafka::TopicPartition*> partitions; |
| partitions.emplace_back(topic_partition_.get()); |
| RdKafka::ErrorCode err = consumer_->assign(partitions); |
| if (err != RdKafka::ERR_NO_ERROR) { |
| return errors::Internal( |
| "Failed to assign partition [", topic_partition_->topic(), ", ", |
| topic_partition_->partition(), ", ", topic_partition_->offset(), |
| "]:", RdKafka::err2str(err)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| // Resets all Kafka streams. |
| void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| consumer_->unassign(); |
| consumer_->close(); |
| consumer_.reset(nullptr); |
| } |
| |
| mutex mu_; |
| size_t current_topic_index_ GUARDED_BY(mu_) = 0; |
| int64 offset_ GUARDED_BY(mu_) = 0; |
| int64 limit_ GUARDED_BY(mu_) = -1; |
| std::unique_ptr<RdKafka::TopicPartition> topic_partition_ GUARDED_BY(mu_); |
| std::unique_ptr<RdKafka::KafkaConsumer> consumer_ GUARDED_BY(mu_); |
| }; |
| |
| const std::vector<string> topics_; |
| const std::string servers_; |
| const std::string group_; |
| const bool eof_; |
| const int64 timeout_; |
| }; |
| }; |
| |
| REGISTER_KERNEL_BUILDER(Name("KafkaDataset").Device(DEVICE_CPU), |
| KafkaDatasetOp); |
| |
| } // namespace tensorflow |