/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
| |
#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
| |
| namespace tensorflow { |
| namespace parallel_device { |
| |
| // Functor for making unique_ptrs slightly more ergonomic. Using |
| // decltype(delete_fn) in the unique_ptr's second template argument requires |
| // passing a function pointer to delete_fn when constructing the unique_ptr. |
| class TensorHandleDeleter { |
| public: |
| void operator()(TFE_TensorHandle* to_delete) const { |
| TFE_DeleteTensorHandle(to_delete); |
| } |
| }; |
| |
| using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>; |
| |
| class ExecutorDeleter { |
| public: |
| void operator()(TFE_Executor* to_delete) const { |
| TFE_DeleteExecutor(to_delete); |
| } |
| }; |
| |
| using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>; |
| |
class ParallelTensor;

// A non-owning reference (per the "Unowned" suffix) to either a parallel
// tensor with per-device components or a plain single-device tensor handle.
using MaybeParallelTensorUnowned =
    absl::variant<ParallelTensor*, TFE_TensorHandle*>;
| |
| // Forwards operations to `devices`, maintaining ParallelTensor with components |
| // placed on each underlying device. |
// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
class ParallelDevice {
 public:
  // Constructs a ParallelDevice that replicates operations onto each device
  // named in `devices`.
  explicit ParallelDevice(const std::vector<std::string>& devices);

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
  // Sets a bad status and returns a nullptr if `tensor` is already on the
  // ParallelDevice, or if the individual copies fail.
  std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
                                                       TFE_TensorHandle* tensor,
                                                       TF_Status* status) const;

  // A parallel tensor with scalar integers numbering component devices.
  std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
                                            TF_Status* status) const;

  // The number of devices operations run on.
  size_t num_underlying_devices() const { return underlying_devices_.size(); }

  // Takes a description of a single operation being executed on the
  // ParallelDevice, and in turn runs one operation per component device with
  // its corresponding inputs from the input ParallelTensors (or
  // implicitly-mirrored tensors on other devices). Wraps the resulting
  // per-device and per-output TFE_TensorHandles into one ParallelTensor per
  // output of the original operation.
  //
  // Attributes are forwarded to executed operations unmodified.
  //
  // The returned optional has a value if and only if `status` evaluates to
  // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
  // if sanity checks on dtypes/metadata fail.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
      TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
      const char* operation_name, const TFE_OpAttrs* attributes,
      int expected_max_outputs, TF_Status* status) const;

 private:
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
  // A sequence of TFE_Executors, one per device, for executing operations in
  // parallel. NOTE(review): presumably indexed in parallel with
  // `underlying_devices_` — confirm against the .cc definition.
  const std::vector<ExecutorPtr> executors_;
};
| |
| // Contains a tuple of tensors, one on each of the `underlying_devices_` of the |
| // ParallelDevice. |
| class ParallelTensor { |
| public: |
| // Construct a ParallelTensor from TensorHandles placed on the component |
| // devices of a ParallelDevice. |
| static std::unique_ptr<ParallelTensor> FromTensorHandles( |
| const ParallelDevice& parallel_device, |
| std::vector<TensorHandlePtr> components, TF_Status* status); |
| |
| size_t num_tensors() const { return tensors_.size(); } |
| TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); } |
| |
| // A generalization of the shapes of the underlying tensors. |
| const std::vector<int64_t>& shape() const { return shape_; } |
| TF_DataType dtype() const { return dtype_; } |
| |
| private: |
| ParallelTensor(const ParallelDevice& device, |
| std::vector<TensorHandlePtr> tensors, |
| std::vector<int64_t> shape, const TF_DataType dtype) |
| : device_(device), |
| tensors_(std::move(tensors)), |
| shape_(std::move(shape)), |
| dtype_(dtype) {} |
| |
| const ParallelDevice& device_; |
| const std::vector<TensorHandlePtr> tensors_; |
| const std::vector<int64_t> shape_; |
| const TF_DataType dtype_; |
| }; |
| |
| } // namespace parallel_device |
| } // namespace tensorflow |
| |
| #endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_ |