// Source-viewer residue (not part of the header): blob: fb7eb41706392b9beb052c26ca9fffc063e4026d [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
class EagerOperation {
public:
EagerOperation(tensorflow::EagerContext* ctx, const char* op,
bool is_function, const tensorflow::AttrTypeMap* t)
: ctx_(ctx),
name_(op),
attrs_(op),
attr_types_(t),
device_(nullptr),
is_function_(is_function),
executor_(ctx ? ctx->Executor() : nullptr) {}
~EagerOperation() {
for (tensorflow::TensorHandle* h : inputs_) {
h->Unref();
}
}
bool is_function() const { return is_function_; }
tensorflow::EagerContext* EagerContext() { return ctx_; }
tensorflow::AttrBuilder* MutableAttrs() { return &attrs_; }
const tensorflow::AttrBuilder& Attrs() const { return attrs_; }
const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>& Inputs()
const {
return inputs_;
}
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>*
MutableInputs() {
return &inputs_;
}
void AddInput(tensorflow::TensorHandle* h);
void UpdateInput(int i, tensorflow::TensorHandle* h);
void ConsumeInput(tensorflow::TensorHandle* h);
const tensorflow::string& Name() const { return name_; }
const tensorflow::AttrTypeMap* AttrTypes() const { return attr_types_; }
tensorflow::Device* Device() const { return device_; }
void SetDevice(tensorflow::Device* device) {
device_ = device;
device_name_ = device->parsed_name();
}
const DeviceNameUtils::ParsedName& GetDeviceName() const {
return device_name_;
}
tensorflow::Status SetDeviceName(const char* device);
void SetUseXla(bool use_xla) { use_xla_ = use_xla; }
CancellationManager* GetCancellationManager() const {
return cancellation_manager_;
}
void SetCancellationManager(CancellationManager* cancellation_manager) {
cancellation_manager_ = cancellation_manager;
}
EagerExecutor* Executor() { return executor_; }
string DebugString() const;
private:
tensorflow::EagerContext* ctx_; // Must outlive the EagerOperation.
const tensorflow::string name_;
tensorflow::AttrBuilder attrs_;
const tensorflow::AttrTypeMap* attr_types_;
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
tensorflow::Device* device_;
DeviceNameUtils::ParsedName device_name_;
bool use_xla_ = false;
const bool is_function_;
CancellationManager* cancellation_manager_ = nullptr; // Not owned.
EagerExecutor* const executor_; // Not owned.
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_