blob: 5fd4633c6a61eb94ed6e4ac6a592975a763fd93e [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
#define TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
#include <memory>
#include <vector>
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// Forward declarations of types in the `tensorflow` namespace; no definitions
// for them are provided in this header.
class Device;
class OpKernelContext;
class ResourceMgr;
namespace data {
// Forward declarations: both classes are defined further down in this header
// but are referred to before their definitions (for example by
// `MakeIteratorFromInputElement()` and by each other).
class CapturedFunction;
class InstantiatedCapturedFunction;
// Declaration only; the definition is not part of this header.
//
// NOTE(review): judging from the name and signature, this applies
// `inst_captured_func` to `input_element` and returns the resulting iterator
// through `*out_iterator`, reporting failure via the returned `Status`. The
// exact roles of `thread_index` and `prefix` are not visible here -- confirm
// against the definition before relying on them.
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func,
StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator);
// `InstantiatedCapturedFunction` encapsulates all the runtime support needed
// to execute a tensorflow function.
//
// While `CapturedFunction` encapsulates constant attributes of the function,
// such as its name and captured arguments, `InstantiatedCapturedFunction`
// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function
// handle.
//
// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute
// functions outside of the normal `OpKernel::Compute()` context.
//
// The constructor is private and `CapturedFunction` is a friend, so instances
// are obtained through `CapturedFunction::Instantiate()` rather than being
// constructed directly. Copying and assignment are disallowed.
class InstantiatedCapturedFunction {
public:
// Declared here, defined out of line.
~InstantiatedCapturedFunction();
// Runs the instantiated captured function on `args`, returning the results
// through `rets`. This method takes ownership of the tensors in `args`
// (hence the rvalue reference), in order to be able to deallocate them as
// early as possible. Use `RunWithBorrowedArgs()` if the caller needs to
// retain ownership of the `args`.
Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. `args` is taken by const reference, so the caller
// keeps ownership of the tensors. Prefer to use `Run()` or `RunAsync()` when
// possible.
Status RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible. This can be useful for calling a captured
// function in cases where an `IteratorContext*` is not available
// (such as a destructor).
//
// Note that, unlike the other `Run*` methods, this one takes no
// `IteratorContext*` and is not declared `const`.
Status RunInstantiated(const std::vector<Tensor>& args,
std::vector<Tensor>* rets);
// Asynchronously runs the captured function on the given `args`, stores
// the results in `*rets`, and calls the given `done` callback when the
// function returns. This method takes ownership of the tensors in `args`,
// in order to be able to deallocate them as early as possible.
//
// NOTE(review): the purpose of `prefix` is not documented in this header --
// confirm against the definition.
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done,
const string& prefix) const;
private:
// Private: only this class and its friend `CapturedFunction` can create
// instances. The parameters mirror the data members declared below.
InstantiatedCapturedFunction(
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
DataTypeVector ret_types,
std::function<void(std::function<void()>)> runner,
CancellationManager* cancellation_manager,
CapturedFunction* captured_func);
// Determines whether a rendezvous object should be created when running the
// instantiated function.
bool ShouldCreateRendezvous() const;
// Grants `CapturedFunction` access to the private constructor above.
friend class CapturedFunction;
// Runtime state used to execute the function: the function library runtime,
// the function handle, and the return types.
// NOTE(review): the raw pointers in this class appear to be non-owning;
// confirm the required lifetimes against the definitions.
FunctionLibraryRuntime* const lib_;
const FunctionLibraryRuntime::Handle f_handle_;
const DataTypeVector ret_types_;
// Callable that accepts a closure (`std::function<void()>`).
// NOTE(review): presumably schedules the closure for execution -- confirm.
std::function<void(std::function<void()>)> captured_runner_;
CancellationManager* cancellation_manager_;
// The `CapturedFunction` this instance was created for (pointer is const, the
// pointee is not).
CapturedFunction* const captured_func_;
TF_DISALLOW_COPY_AND_ASSIGN(InstantiatedCapturedFunction);
};
// Short-circuit information for a function, exposed through
// `FunctionMetadata::short_circuit_info()` and
// `CapturedFunction::short_circuit_info()`.
//
// NOTE(review): based on the accessor documentation in this header, `indices`
// presumably maps each function output index to the input index it
// corresponds to (empty when not every output corresponds to an input), and
// `can_move` presumably holds a per-output flag saying whether the input may
// be moved instead of copied -- confirm against the code that populates it.
struct ShortCircuitInfo {
std::vector<int> indices;
std::vector<bool> can_move;
};
// Metadata shared across all captures of the same function.
//
// The constructor is private; instances are created through the static
// `Create()` factories, which return them as `std::shared_ptr`.
class FunctionMetadata {
public:
// Options supplied when creating a `FunctionMetadata`.
struct Params {
// Whether the function is a multi-device function. Defaults to false.
bool is_multi_device_function = false;
// Whether to use inter-op parallelism for execution of the function.
// Defaults to true.
bool use_inter_op_parallelism = true;
};
// Creates a new instance of the `FunctionMetadata` class, fetching function
// from a context argument. The new instance is returned through
// `out_metadata`.
// NOTE(review): `func_name` presumably names the attribute of `ctx` that
// holds the function -- confirm against the definition.
static Status Create(tensorflow::OpKernelConstruction* ctx,
const string& func_name, Params params,
std::shared_ptr<FunctionMetadata>* out_metadata);
// Creates a new instance of the `FunctionMetadata` class, using the provided
// function. `func` is taken by rvalue reference. The new instance is returned
// through `out_metadata`.
static Status Create(tensorflow::OpKernelConstruction* ctx,
NameAttrList&& func, Params params,
std::shared_ptr<FunctionMetadata>* out_metadata);
// Returns the `NameAttrList` describing the function, as passed to the
// constructor.
const NameAttrList& func() const { return func_; }
// Indicates whether the function is a multi-device function.
bool is_multi_device_function() const { return is_multi_device_function_; }
// Returns a borrowed pointer to the function library that contains the
// transitive closure of definitions used by the function. The library is
// owned by this object; the pointer is null for as long as `lib_def_` has
// not been populated (it is initialized to `nullptr`).
const FunctionLibraryDefinition* lib_def() const { return lib_def_.get(); }
// Returns short-circuit information.
const ShortCircuitInfo& short_circuit_info() const {
return short_circuit_info_;
}
// Indicates whether to use inter-op parallelism for execution of the
// function.
bool use_inter_op_parallelism() const { return use_inter_op_parallelism_; }
private:
// Takes ownership of `func` and copies the two flags out of `params`.
// `lib_def_` and `short_circuit_info_` keep their default values here.
FunctionMetadata(NameAttrList&& func, Params params)
: func_(std::move(func)),
is_multi_device_function_(params.is_multi_device_function),
use_inter_op_parallelism_(params.use_inter_op_parallelism) {}
// NOTE(review): defined out of line. Being non-const, it can modify this
// object; presumably it checks whether the function can really be run as a
// multi-device function -- confirm against the definition.
void ValidateMultiDevice();
// NOTE: the constructor's initializer list follows this declaration order.
NameAttrList func_;
bool is_multi_device_function_ = false;
std::unique_ptr<FunctionLibraryDefinition> lib_def_ = nullptr;
ShortCircuitInfo short_circuit_info_;
bool use_inter_op_parallelism_ = true;
};
// A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured"
// arguments that it closed over in the user program.
//
// The constructor is private; instances are created through the static
// `Create()` factories. Function-level properties live in a shared
// `FunctionMetadata` and are exposed here through forwarding accessors.
// Copying and assignment are disallowed.
class CapturedFunction {
public:
// Creates a new instance using a list of named attributes, fetching captured
// inputs from a context argument. The new instance is returned through
// `out_function`.
// NOTE(review): `argument_name` presumably names the input of `ctx` that
// holds the captured inputs -- confirm against the definition.
static Status Create(OpKernelContext* ctx,
const std::shared_ptr<const FunctionMetadata> metadata,
const string& argument_name,
std::unique_ptr<CapturedFunction>* out_function);
// Creates a new instance using a list of named attributes, using provided
// captured inputs. `captured_inputs` is taken by rvalue reference. The new
// instance is returned through `out_function`.
static Status Create(OpKernelContext* ctx,
const std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor>&& captured_inputs,
std::unique_ptr<CapturedFunction>* out_function);
// Adds the definition of this captured function into the given graph,
// returning its captured inputs and types through the respective output
// arguments (`other_arguments` and `other_arguments_types`).
Status AddToGraph(SerializationContext* ctx,
DatasetBase::DatasetGraphDefBuilder* b,
std::vector<Node*>* other_arguments,
DataTypeVector* other_arguments_types) const;
// Instantiates this function for use in the given context, providing an
// InstantiatedCapturedFunction that can be used to execute functions.
// This method is not declared `const`.
Status Instantiate(IteratorContext* ctx,
std::unique_ptr<InstantiatedCapturedFunction>*
instantiated_captured_function);
// Determines whether the captured function is stateful.
//
// TODO(jsimsa): Remove this method once all users of `CapturedFunction`
// migrate to `CheckExternalState`.
bool IsStateful() const;
// `Status`-returning counterpart of `IsStateful()`, which it is intended to
// replace (see the TODO on `IsStateful()`).
// NOTE(review): presumably returns a non-OK status when the function is
// stateful -- confirm against the definition.
Status CheckExternalState() const;
// Returns the additional captured inputs that will be passed to the function.
const std::vector<Tensor>& captured_inputs() const {
return captured_inputs_;
}
// The accessors below forward to the shared `FunctionMetadata`. They
// dereference `metadata_` without a null check.
//
// Returns the `NameAttrList` describing the function.
const NameAttrList& func() const { return metadata_->func(); }
// Indicates whether the function is multi-device.
bool is_multi_device_function() const {
return metadata_->is_multi_device_function();
}
// Returns the transitive set of function definition required to instantiate
// this function. The pointer is borrowed from the metadata.
const FunctionLibraryDefinition* lib_def() const {
return metadata_->lib_def();
}
// Returns the short-circuit information of the function. If every function
// output corresponds to one of its inputs, it holds the mapping from output
// indices to input indices. Otherwise, the mapping is empty.
const ShortCircuitInfo& short_circuit_info() const {
return metadata_->short_circuit_info();
}
// Indicates whether the function should use inter op parallelism.
bool use_inter_op_parallelism() const {
return metadata_->use_inter_op_parallelism();
}
private:
// Private: use the `Create()` factories instead.
CapturedFunction(const std::shared_ptr<const FunctionMetadata> metadata,
std::vector<Tensor> captured_inputs);
// Shared, read-only metadata of the function.
const std::shared_ptr<const FunctionMetadata> metadata_;
// Captured inputs, returned by `captured_inputs()`; immutable after
// construction.
const std::vector<Tensor> captured_inputs_;
TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
};
} // namespace data
// Makes `data::CapturedFunction` available as `tensorflow::CapturedFunction`.
// TODO(b/114112161): Remove this alias when all users have moved over to the
// `tensorflow::data` namespace.
using data::CapturedFunction;
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_