Add overload to MakeIteratorFromInputElement
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 7d929a6..2192158 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -447,6 +447,16 @@
IteratorContext* ctx, const IteratorBase* parent,
const std::vector<Tensor>& input_element, int64 thread_index,
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
+ std::unique_ptr<IteratorBase>* out_iterator) {
+ return MakeIteratorFromInputElement(
+ ctx, parent, input_element, thread_index, inst_captured_func, prefix,
+ /*node=*/nullptr);
+}
+
+Status MakeIteratorFromInputElement(
+ IteratorContext* ctx, const IteratorBase* parent,
+ const std::vector<Tensor>& input_element, int64 thread_index,
+ const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator,
const std::shared_ptr<model::Node>& node) {
std::vector<Tensor> return_values;
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index fbfcc4c..fc17a31 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -46,6 +46,15 @@
IteratorContext* ctx, const IteratorBase* parent,
const std::vector<Tensor>& input_element, int64 thread_index,
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
+ std::unique_ptr<IteratorBase>* out_iterator);
+
+// Creates an iterator for a dataset which is created by applying the given
+// function to the given input element. Pass non-null `node` to record
+// processing time for modeling Iterator's GetNext() resource usage.
+Status MakeIteratorFromInputElement(
+ IteratorContext* ctx, const IteratorBase* parent,
+ const std::vector<Tensor>& input_element, int64 thread_index,
+ const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator,
const std::shared_ptr<model::Node>& node);