| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_ |
| #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_ |
| |
| #include "absl/container/inlined_vector.h" |
| #include "absl/types/optional.h" |
| #include "absl/types/variant.h" |
| #include "tensorflow/core/common_runtime/eager/attr_builder.h" |
| #include "tensorflow/core/common_runtime/eager/context.h" |
| #include "tensorflow/core/common_runtime/eager/eager_executor.h" |
| #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" |
| #include "tensorflow/core/common_runtime/eager/tensor_handle.h" |
| #include "tensorflow/core/framework/cancellation.h" |
| #include "tensorflow/core/util/device_name_utils.h" |
| |
| namespace tensorflow { |
| |
| class EagerOperation { |
| public: |
| explicit EagerOperation(tensorflow::EagerContext* ctx) : ctx_(*ctx) {} |
| ~EagerOperation() { |
| for (TensorHandle* h : inputs_) { |
| h->Unref(); |
| } |
| } |
| |
| // An EagerOperation object can be reused for a different op by calling |
| // Clear(), and then Reset(...) with the same arguments that would have |
| // been provided to the constructor. |
| void Clear() { |
| for (TensorHandle* h : inputs_) { |
| h->Unref(); |
| } |
| inputs_.clear(); |
| ClearInferenceState(); |
| } |
| |
| Status Reset(const char* op, const char* raw_device_name, bool remote, |
| EagerExecutor* executor, |
| const absl::optional<EagerRemoteFunctionParams> |
| remote_func_params = absl::nullopt); |
| |
| bool is_function() const { return is_function_; } |
| bool colocation_exempt() const { return colocation_exempt_; } |
| |
| tensorflow::EagerContext& EagerContext() { return ctx_; } |
| const tensorflow::EagerContext& EagerContext() const { return ctx_; } |
| |
| AttrBuilder* MutableAttrs() { return &attrs_; } |
| const AttrBuilder& Attrs() const { return attrs_; } |
| const tensorflow::OpDef* OpDef() const { return op_def_; } |
| |
| const absl::InlinedVector<TensorHandle*, 4>& Inputs() const { |
| return inputs_; |
| } |
| absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; } |
| |
| void AddInput(TensorHandle* h); |
| void UpdateInput(int i, TensorHandle* h); |
| |
| const string& Name() const { return attrs_.op_name(); } |
| const AttrTypeMap* AttrTypes() const { return attr_types_; } |
| |
| // Like TensorHandles, EagerOperations may be placed either on a virtual |
| // CustomDevice or on a physical Device. |
| absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> Device() const { |
| return device_; |
| } |
| |
| void SetDevice(tensorflow::Device* device) { |
| device_ = device; |
| raw_device_name_.clear(); |
| device_name_ = device->name(); |
| device_parsed_name_ = device->parsed_name(); |
| } |
| |
| void SetDevice(tensorflow::CustomDevice* device) { |
| device_ = device; |
| raw_device_name_.clear(); |
| device_name_ = device->name(); |
| DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_); |
| } |
| |
| const string& GetDeviceName() const { return device_name_; } |
| const DeviceNameUtils::ParsedName& GetDeviceParsedName() const { |
| return device_parsed_name_; |
| } |
| Status SetDeviceName(const char* device); |
| |
| // Indicates whether the op is assigned to a device that is local to the |
| // current host. |
| bool IsLocal() const; |
| |
| void SetUseXla(bool use_xla) { use_xla_ = use_xla; } |
| |
| CancellationManager* GetCancellationManager() const { |
| return cancellation_manager_; |
| } |
| void SetCancellationManager(CancellationManager* cancellation_manager) { |
| cancellation_manager_ = cancellation_manager; |
| } |
| |
| EagerExecutor& Executor() { return *executor_; } |
| |
| string DebugString() const; |
| |
| const absl::optional<EagerRemoteFunctionParams>& remote_func_params() const { |
| return remote_func_params_; |
| } |
| |
| // Op name recorded for memory debugging purpose. |
| const char* op_name() const { return op_name_; } |
| const char* op_name_ = nullptr; |
| |
| Status MaybeInferSingleInputAttrs(TensorHandle* handle); |
| Status InferInputListAttrs(int num_inputs); |
| |
| private: |
| void ClearInferenceState() { |
| op_def_ = nullptr; |
| inference_arg_idx_ = 0; |
| inference_attrs_.clear_no_resize(); |
| } |
| void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def, |
| const DataType dtype, int num_inputs); |
| void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, |
| const std::vector<DataType>& dtypes); |
| |
| tensorflow::EagerContext& ctx_; |
| AttrBuilder attrs_; |
| const AttrTypeMap* attr_types_; |
| absl::InlinedVector<TensorHandle*, 4> inputs_; |
| absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> device_; |
| string raw_device_name_; |
| string device_name_; |
| DeviceNameUtils::ParsedName device_parsed_name_; |
| bool use_xla_ = false; |
| bool is_function_; // Conceptually const, but can't be because of Reset |
| bool colocation_exempt_; |
| CancellationManager* cancellation_manager_ = nullptr; // Not owned. |
| EagerExecutor* executor_; // Not owned. |
| absl::optional<EagerRemoteFunctionParams> remote_func_params_; |
| |
| // Inference information |
| const tensorflow::OpDef* op_def_; // op definition from protobuf |
| int inference_arg_idx_; // arg definition index for the next input to be |
| // added |
| gtl::FlatSet<std::string> inference_attrs_; // attributes inferred so far |
| }; |
| |
| inline void EagerOperation::AddInput(TensorHandle* h) { |
| h->Ref(); |
| inputs_.push_back(h); |
| attrs_.NumInputs(static_cast<int>(inputs_.size())); |
| } |
| |
| inline void EagerOperation::UpdateInput(int i, TensorHandle* h) { |
| TensorHandle** slot = &inputs_[i]; |
| TensorHandle* existing = *slot; |
| if (existing != h) { |
| h->Ref(); |
| existing->Unref(); |
| *slot = h; // Update inputs_[i] to h |
| } |
| } |
| } // namespace tensorflow |
| |
| #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_ |