| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ |
| #define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ |
| |
#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>
| |
| #include "tensorflow/c/c_api.h" |
| #include "tensorflow/c/c_api_internal.h" |
| #include "tensorflow/c/eager/c_api.h" |
| #include "tensorflow/c/eager/c_api_experimental.h" |
| #include "tensorflow/core/common_runtime/device_factory.h" |
| #include "tensorflow/core/common_runtime/eager/attr_builder.h" |
| #include "tensorflow/core/common_runtime/eager/context.h" |
| #include "tensorflow/core/common_runtime/eager/eager_executor.h" |
| #include "tensorflow/core/common_runtime/eager/eager_operation.h" |
| #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" |
| #include "tensorflow/core/common_runtime/eager/tensor_handle.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
| #include "tensorflow/core/framework/cancellation.h" |
| #include "tensorflow/core/framework/rendezvous.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/stringpiece.h" |
| #include "tensorflow/core/lib/gtl/inlined_vector.h" |
| #include "tensorflow/core/lib/gtl/map_util.h" |
| #include "tensorflow/core/lib/gtl/stl_util.h" |
| #include "tensorflow/core/lib/monitoring/counter.h" |
| #include "tensorflow/core/lib/monitoring/gauge.h" |
| #include "tensorflow/core/lib/monitoring/sampler.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/thread_annotations.h" |
| #include "tensorflow/core/profiler/lib/profiler_session.h" |
| #include "tensorflow/core/public/version.h" |
| |
| struct TFE_ContextOptions { |
| TF_SessionOptions session_options; |
| // true if async execution is enabled. |
| bool async = false; |
| TFE_ContextDevicePlacementPolicy device_placement_policy{ |
| TFE_DEVICE_PLACEMENT_SILENT}; |
| TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE}; |
| }; |
| |
| struct TFE_Context { |
| TFE_Context(const tensorflow::SessionOptions& opts, |
| TFE_ContextDevicePlacementPolicy default_device_placement_policy, |
| TFE_ContextMirroringPolicy default_mirroring_policy, bool async, |
| const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, |
| tensorflow::Rendezvous* rendezvous, |
| const tensorflow::CustomKernelCreator* custom_kernel_creator) |
| : context(new tensorflow::EagerContext( |
| opts, |
| static_cast<tensorflow::ContextDevicePlacementPolicy>( |
| default_device_placement_policy), |
| static_cast<tensorflow::ContextMirroringPolicy>( |
| default_mirroring_policy), |
| async, device_mgr, device_mgr_owned, rendezvous, |
| custom_kernel_creator)) {} |
| |
| ~TFE_Context() { context->Unref(); } |
| |
| tensorflow::EagerContext* context; |
| }; |
| |
| struct TFE_TensorHandle { |
| explicit TFE_TensorHandle(tensorflow::TensorHandle* h) : handle(h) {} |
| static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t, |
| TF_Status* s) { |
| tensorflow::TensorHandle* handle; |
| s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle); |
| if (!s->status.ok()) { |
| return nullptr; |
| } |
| return new TFE_TensorHandle(handle); |
| } |
| |
| tensorflow::TensorHandle* handle; |
| |
| // Create a symbolic tensor. |
| TFE_TensorHandle(TF_Output t, TF_DataType dtype) |
| : handle(new tensorflow::TensorHandle( |
| tensorflow::OutputGraphNode{t.oper, t.index}, |
| static_cast<tensorflow::DataType>(dtype))) {} |
| }; |
| |
| struct TFE_TensorDebugInfo { |
| explicit TFE_TensorDebugInfo(const std::vector<tensorflow::int64>& dims) |
| : dev_dims(dims) {} |
| |
| // Fully-padded, minor-to-major. |
| std::vector<tensorflow::int64> dev_dims; |
| }; |
| |
| struct TFE_OpInferenceContext { |
| explicit TFE_OpInferenceContext(const tensorflow::OpDef* op_def) |
| : op_def(op_def) {} |
| |
| const tensorflow::OpDef* op_def; // op definition from protobuf |
| int input_arg_idx = 0; // arg definition index for the next input to be added |
| tensorflow::gtl::FlatSet<std::string> attrs; // attributes inferred so far |
| }; |
| |
| struct TFE_Op { |
| TFE_Op(TFE_Context* ctx, const char* op, bool is_function, |
| const tensorflow::AttrTypeMap* t, |
| TFE_OpInferenceContext* inference_ctx) |
| : operation(ctx->context, op, is_function, t), |
| inference_ctx(inference_ctx) {} |
| |
| tensorflow::EagerOperation operation; |
| std::unique_ptr<TFE_OpInferenceContext> inference_ctx; |
| }; |
| |
| struct TFE_ProfilerContext { |
| tensorflow::ProfilerContext profiler_context; |
| }; |
| |
| struct TFE_Profiler { |
| explicit TFE_Profiler(TFE_ProfilerContext* ctx) { |
| profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context); |
| } |
| |
| std::unique_ptr<tensorflow::ProfilerSession> profiler; |
| }; |
| |
| struct TFE_MonitoringCounterCell { |
| tensorflow::monitoring::CounterCell cell; |
| }; |
| |
| template <int NumLabels> |
| struct TFE_MonitoringCounter { |
| template <typename... LabelDesc> |
| TFE_MonitoringCounter(const char* name, const char* description, |
| LabelDesc&&... label) { |
| counter = absl::WrapUnique(tensorflow::monitoring::Counter<NumLabels>::New( |
| name, description, label...)); |
| } |
| |
| std::unique_ptr<tensorflow::monitoring::Counter<NumLabels>> counter; |
| }; |
| |
| struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> { |
| using TFE_MonitoringCounter::TFE_MonitoringCounter; |
| }; |
| struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> { |
| using TFE_MonitoringCounter::TFE_MonitoringCounter; |
| }; |
| struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> { |
| using TFE_MonitoringCounter::TFE_MonitoringCounter; |
| }; |
| |
| struct TFE_MonitoringIntGaugeCell { |
| tensorflow::monitoring::GaugeCell<tensorflow::int64> cell; |
| }; |
| struct TFE_MonitoringStringGaugeCell { |
| tensorflow::monitoring::GaugeCell<tensorflow::string> cell; |
| }; |
| struct TFE_MonitoringBoolGaugeCell { |
| tensorflow::monitoring::GaugeCell<bool> cell; |
| }; |
| |
| template <typename ValueType, int NumLabels> |
| struct TFE_MonitoringGauge { |
| template <typename... LabelDesc> |
| TFE_MonitoringGauge(const char* name, const char* description, |
| LabelDesc&&... label) { |
| gauge = absl::WrapUnique( |
| tensorflow::monitoring::Gauge<ValueType, NumLabels>::New( |
| name, description, label...)); |
| } |
| |
| std::unique_ptr<tensorflow::monitoring::Gauge<ValueType, NumLabels>> gauge; |
| }; |
| |
| struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge<tensorflow::int64, 0> { |
| using TFE_MonitoringGauge::TFE_MonitoringGauge; |
| }; |
| struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge<tensorflow::int64, 1> { |
| using TFE_MonitoringGauge::TFE_MonitoringGauge; |
| }; |
| struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge<tensorflow::int64, 2> { |
| using TFE_MonitoringGauge::TFE_MonitoringGauge; |
| }; |
| |
| struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge<tensorflow::string, 0> { |
| using TFE_MonitoringGauge::TFE_MonitoringGauge; |
| }; |
| struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge<tensorflow::string, 1> { |
| using TFE_MonitoringGauge::TFE_MonitoringGauge; |
| }; |
| struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge<tensorflow::string, 2> { |
| using TFE_MonitoringGauge::TFE_MonitoringGauge; |
| }; |
| |
| struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge<bool, 0> { |
| using TFE_MonitoringGauge::TFE_MonitoringGauge; |
| }; |
| struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge<bool, 1> { |
| using TFE_MonitoringGauge::TFE_MonitoringGauge; |
| }; |
| struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge<bool, 2> { |
| using TFE_MonitoringGauge::TFE_MonitoringGauge; |
| }; |
| |
| struct TFE_MonitoringBuckets { |
| explicit TFE_MonitoringBuckets( |
| std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)> |
| fn) { |
| create_buckets = fn; |
| } |
| |
| std::function<std::unique_ptr<tensorflow::monitoring::Buckets>(void)> |
| create_buckets; |
| }; |
| |
| struct TFE_MonitoringSamplerCell { |
| tensorflow::monitoring::SamplerCell cell; |
| }; |
| |
| template <int NumLabels> |
| struct TFE_MonitoringSampler { |
| template <typename... LabelDesc> |
| TFE_MonitoringSampler( |
| const char* name, |
| std::unique_ptr<tensorflow::monitoring::Buckets> buckets, |
| const char* description, LabelDesc&&... label) { |
| sampler = absl::WrapUnique(tensorflow::monitoring::Sampler<NumLabels>::New( |
| {name, description, label...}, std::move(buckets))); |
| } |
| |
| std::unique_ptr<tensorflow::monitoring::Sampler<NumLabels>> sampler; |
| }; |
| |
| struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> { |
| using TFE_MonitoringSampler::TFE_MonitoringSampler; |
| }; |
| struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> { |
| using TFE_MonitoringSampler::TFE_MonitoringSampler; |
| }; |
| struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> { |
| using TFE_MonitoringSampler::TFE_MonitoringSampler; |
| }; |
| |
| namespace tensorflow { |
| // Set an AttrValue on the op. Doesn't handle the list types. |
| void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, |
| const tensorflow::AttrValue& default_value, |
| const char* attr_name, TF_Status* status); |
| } // namespace tensorflow |
| |
| struct TFE_TraceContext { |
| TF_Graph* const graph; |
| |
| unsigned int node_counter = 0; |
| // Each tensor handle will have its ref count incremented when it's added as a |
| // map key, and decremented when this object is destroyed. |
| std::map<tensorflow::TensorHandle*, TF_Output> input_tensor_map; |
| std::vector<std::pair<tensorflow::TensorHandle*, TF_Output>>* input_tensors = |
| nullptr; |
| |
| explicit TFE_TraceContext(TF_Graph* graph) : graph(graph) {} |
| |
| ~TFE_TraceContext() { |
| delete input_tensors; |
| for (auto input : input_tensor_map) { |
| input.first->Unref(); |
| } |
| } |
| }; |
| |
// C API handle wrapping a tensorflow::CancellationManager held by value.
struct TFE_CancellationManager {
  tensorflow::CancellationManager cancellation_manager;
};
| |
| #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ |