| /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/dtensor/cc/dtensor_device.h" |
| |
| #include <algorithm> |
| #include <cstdint> |
| #include <memory> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/base/attributes.h" |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/memory/memory.h" |
| #include "absl/strings/match.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/str_join.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/strings/strip.h" |
| #include "tensorflow/c/c_api_experimental.h" |
| #include "tensorflow/c/eager/c_api.h" |
| #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" |
| #include "tensorflow/c/eager/tfe_context_internal.h" |
| #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
| #include "tensorflow/c/tf_datatype.h" |
| #include "tensorflow/c/tf_status.h" |
| #include "tensorflow/c/tf_status_helper.h" |
| #include "tensorflow/c/tf_tensor_internal.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/core/common_runtime/device_set.h" |
| #include "tensorflow/core/common_runtime/eager/context.h" |
| #include "tensorflow/core/common_runtime/eager/tensor_handle.h" |
| #include "tensorflow/core/common_runtime/graph_constructor.h" |
| #include "tensorflow/core/common_runtime/shape_refiner.h" |
| #include "tensorflow/core/framework/attr_value.pb.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/function.pb.h" |
| #include "tensorflow/core/framework/graph_to_functiondef.h" |
| #include "tensorflow/core/framework/node_def_builder.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/lib/strings/proto_serialization.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/platform/fingerprint.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/util/dump_graph.h" |
| #include "tensorflow/dtensor/cc/constants.h" |
| #include "tensorflow/dtensor/cc/dstatus.h" |
| #include "tensorflow/dtensor/cc/dtensor_device_util.h" |
| #include "tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h" |
| #include "tensorflow/dtensor/cc/small_constant_optimization.h" |
| #include "tensorflow/dtensor/cc/tensor_layout.h" |
| #include "tensorflow/dtensor/cc/tpu_system_interface.h" |
| #include "tensorflow/dtensor/proto/layout.pb.h" |
| #include "tensorflow/stream_executor/tpu/c_api_decl.h" |
| #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" |
| #include "tensorflow/stream_executor/tpu/tpu_topology.h" |
| |
| namespace tensorflow { |
| namespace dtensor { |
| |
| // TODO(b/189332820): Replace this with a Partitioner stub swapped in by the |
| // Copybara workflow. |
| StatusOr<ExecutionFunctions> ABSL_ATTRIBUTE_WEAK PipeliningPartitionerRun( |
| const absl::flat_hash_map<std::string, const MeshWithParallelDevice*>* |
| device_name_to_mesh_device, |
| FunctionLibraryDefinition* flib_def, DTensorMlirPassRunner* pass_runner, |
| const FunctionDef& fdef, const NameAttrList& eager_attributes, |
| const std::vector<TensorWithLayout*>& inputs, const DeviceSet& device_set, |
| int num_outputs) { |
| // The actual definition is in the pipelining package. |
| return errors::Unimplemented("DTensor pipelining is unavailable."); |
| } |
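
// `PipeliningPartitionerRun` above is declared ABSL_ATTRIBUTE_WEAK, so a
// strong definition with the same signature linked in from the pipelining
// package overrides this stub at link time. A hedged sketch of what such an
// override might look like (the body is illustrative, not the real
// partitioner):
//
//   StatusOr<ExecutionFunctions> PipeliningPartitionerRun(
//       const absl::flat_hash_map<std::string,
//                                 const MeshWithParallelDevice*>*
//           device_name_to_mesh_device,
//       FunctionLibraryDefinition* flib_def,
//       DTensorMlirPassRunner* pass_runner, const FunctionDef& fdef,
//       const NameAttrList& eager_attributes,
//       const std::vector<TensorWithLayout*>& inputs,
//       const DeviceSet& device_set, int num_outputs) {
//     ExecutionFunctions functions;
//     // ... lower `fdef` with `pass_runner`, partition it per sub-mesh, and
//     // populate `functions.function_list` ...
//     return functions;
//   }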
| |
| class DTensorDevice { |
| public: |
| explicit DTensorDevice(absl::string_view name) |
| : name_(name), |
| same_shape_policy_enabled_(false), |
| cancellation_manager_(absl::make_unique<CancellationManager>()) {} |
| |
| void AddMesh(std::unique_ptr<MeshWithParallelDevice> mesh, |
| bool is_host_mesh) { |
| // TODO(b/168730933): Consider passing a cheaper int64_t mesh identifier. |
| if (is_host_mesh) { |
| std::string& tpu_host_mesh = Mesh::tpu_host_mesh(); |
| const std::string new_tpu_host_mesh = mesh->mesh_config().ToString(); |
| if (!tpu_host_mesh.empty()) { |
| // TODO(b/180046115): Add per-TPU-mesh host mesh bookkeeping. |
| LOG(WARNING) |
| << "A new TPU host mesh is overwriting the old TPU host mesh. The " |
| "old TPU mesh cannot be used in sea of donuts mode anymore."; |
| } |
| tpu_host_mesh.assign(new_tpu_host_mesh); |
| } |
| // For idempotency, don't register the same mesh twice. |
| if (!mesh_to_device_map_.insert({mesh->mesh_config(), std::move(mesh)}) |
| .second) |
| return; |
| if (!default_mesh_) { |
| global_default_mesh_ = mesh_to_device_map_.begin()->second.get(); |
| default_mesh_ = global_default_mesh_; |
| } |
| } |
| |
  // Returns the sub-meshes used for pipelining.
  // The key is the name of a composite device.
| StatusOr<absl::flat_hash_map<std::string, const MeshWithParallelDevice*>> |
| PipelineSubMeshes(TFE_Context* context) { |
| absl::flat_hash_map<std::string, const MeshWithParallelDevice*> |
| device_to_mesh; |
| for (const auto& pair : mesh_to_device_map_) { |
| TF_ASSIGN_OR_RETURN(CompositeDevice * device, |
| pair.second->FindOrCreateCompositeDevice(context)); |
| if (device != nullptr) { |
| device_to_mesh[pair.second->composite_device()->name()] = |
| pair.second.get(); |
| } |
| } |
| return device_to_mesh; |
| } |
| |
  // Runs an operation on the DTensorDevice, ignoring the placement of the
  // original op (TFE_OpGetDevice(original_op)). The placement indicates
  // whether the user explicitly placed the op on the DTensor device (vs.
  // having it placed there because an input was placed there), but DTensor
  // does type-based dispatch and so handles these cases identically at the
  // moment.
| void Execute(const TFE_Op* original_op, int* num_outputs, |
| TFE_TensorHandle** outputs, TF_Status* status); |
| |
| void SetDefaultLayout(Layout layout) { default_layout_.emplace(layout); } |
| void ClearDefaultLayout() { default_layout_.reset(); } |
| void SetDefaultMesh(Mesh mesh) { |
| default_mesh_ = mesh_to_device_map_.at(mesh).get(); |
| } |
| void ClearDefaultMesh() { default_mesh_ = global_default_mesh_; } |
| void SetSameShapePolicy(bool enabled) { |
| same_shape_policy_enabled_ = enabled; |
| } |
| |
| Status SetTPUCoreIDs(const std::string& mesh_name, |
| const std::vector<int>& tpu_core_ids) { |
| if (VLOG_IS_ON(1)) { |
| LOG(INFO) << "Setting TPU core IDs for " |
| << (mesh_name.empty() ? "default mesh" : mesh_name) << ": "; |
| for (auto i : tpu_core_ids) { |
| LOG(INFO) << i; |
| } |
| } |
    // Repeatedly setting the mapping under an empty mesh name is fine; this
    // happens when dtensor_initialize_tpu_system is called multiple times,
    // especially in tests. All the set mappings should be the same anyway.
| if (!mesh_name.empty() && Mesh::tpu_core_ids().count(mesh_name) > 0) { |
| return errors::AlreadyExists("Mesh name already in use: ", mesh_name); |
| } |
| Mesh::tpu_core_ids()[mesh_name].assign(tpu_core_ids.begin(), |
| tpu_core_ids.end()); |
| return Status::OK(); |
| } |
| |
| void ClearTPUCoreIDs() { Mesh::tpu_core_ids().clear(); } |
| |
| std::vector<std::vector<int>> TPUCoreIDsToLocations( |
| TFE_Context* context, const std::vector<int>& tpu_core_ids) { |
| TpuSystemInterface* tpu_system = GetPreferredTpuSystem(); |
| if (tpu_system == nullptr) { |
| VLOG(1) << "Calling TPUCoreIDsToLocations on the default TPU system."; |
| std::vector<std::vector<int>> tpu_core_locations; |
| tpu_core_locations.reserve(tpu_core_ids.size()); |
| tpu::TpuPlatformInterface* tpu_platform = |
| tpu::TpuPlatformInterface::GetRegisteredPlatform(); |
| if (tpu_platform == nullptr) { |
| LOG(WARNING) << "No TPU platform is found."; |
| return {{}}; |
| } |
| if (!tpu_platform->Initialized()) { |
| LOG(WARNING) << "TPU platform is not initialized."; |
| return {{}}; |
| } |
| tpu::TpuTopologyExternal tpu_topology = tpu_platform->topology(); |
| |
| for (const int& tpu_core_id : tpu_core_ids) { |
| tpu::TpuCoreLocationExternal core = |
| tpu_topology.CoreForId(TpuCoreTypeEnum::kTensorCore, tpu_core_id); |
| tpu::TpuDimensionsExternal tpu_chip_location = core.chip_coordinates(); |
| tpu_core_locations.push_back({tpu_chip_location.x, tpu_chip_location.y, |
| tpu_chip_location.z, core.index()}); |
| } |
| return tpu_core_locations; |
| } else { |
| VLOG(1) << "Calling TPUCoreIDsToLocations on a preferred TPU system."; |
| return tpu_system->TPUCoreIDsToLocations(context, tpu_core_ids); |
| } |
| } |
| |
| std::vector<int> TPUCoreLocationsToIDs( |
| TFE_Context* context, |
| const std::vector<std::vector<int>>& tpu_core_locations) { |
| TpuSystemInterface* tpu_system = GetPreferredTpuSystem(); |
| if (tpu_system == nullptr) { |
| VLOG(1) << "Calling TPUCoreLocationsToIDs on the default TPU system."; |
| std::vector<int> tpu_core_ids; |
| tpu_core_ids.reserve(tpu_core_locations.size()); |
| tpu::TpuPlatformInterface* tpu_platform = |
| tpu::TpuPlatformInterface::GetRegisteredPlatform(); |
| if (tpu_platform == nullptr) { |
| LOG(WARNING) << "No TPU platform is found."; |
| return {}; |
| } |
| if (!tpu_platform->Initialized()) { |
| LOG(WARNING) << "TPU platform is not initialized."; |
| return {}; |
| } |
| tpu::TpuTopologyExternal tpu_topology = tpu_platform->topology(); |
| |
| for (const std::vector<int>& tpu_core_location : tpu_core_locations) { |
| tpu::TpuCoreLocationExternal core = tpu_topology.Core( |
| TpuCoreTypeEnum::kTensorCore, tpu_core_location[0], |
| tpu_core_location[1], tpu_core_location[2], tpu_core_location[3]); |
| tpu_core_ids.push_back(core.Id()); |
| } |
| return tpu_core_ids; |
| } else { |
| VLOG(1) << "Calling TPUCoreLocationsToIDs on a preferred TPU system."; |
| return tpu_system->TPUCoreLocationsToIDs(context, tpu_core_locations); |
| } |
| } |
| |
| // Waits for ops to finish in ALL meshes as we share the cancellation manager. |
| void AsyncWait(TFE_Context* context, TF_Status* status) { |
| std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> first_bad_status( |
| nullptr, TF_DeleteStatus); |
| |
| for (const auto& pair : mesh_to_device_map_) { |
| std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status( |
| TF_NewStatus(), TF_DeleteStatus); |
| |
| pair.second->parallel_device().AsyncWait(context, |
| async_wait_status.get()); |
| |
| TF_Code error_code = TF_GetCode(async_wait_status.get()); |
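      // Prefer a non-cancellation error as the one to keep: a first-seen
      // TF_CANCELLED status is overwritten by a later, more specific error,
      // since cancellations are usually secondary to the root-cause failure.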
| if (error_code != TF_OK && |
| (first_bad_status == nullptr || |
| TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) { |
| first_bad_status.reset(TF_NewStatus()); |
| TF_SetStatus(first_bad_status.get(), error_code, |
| TF_Message(async_wait_status.get())); |
| } |
| } |
| |
| if (first_bad_status != nullptr) { |
| TF_SetStatus(status, TF_GetCode(first_bad_status.get()), |
| TF_Message(first_bad_status.get())); |
| } |
| |
| // Reset the global function rendezvous, which otherwise stores a failure |
| // state. |
| tensorflow::unwrap(context)->ResetGlobalRendezvousForFunction(); |
| |
    // Reset the cancellation manager on (potential) failure so we don't
    // cancel future ops. This is only safe because we have just cleared
    // pending async nodes, which may have held a reference to the
    // cancellation manager.
| cancellation_manager_ = absl::make_unique<CancellationManager>(); |
| } |
| |
| TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs, |
| TFE_TensorHandle** inputs, |
| const std::string& string_layout, TF_Status* status); |
| |
| std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context, |
| TFE_TensorHandle* input, |
| TF_Status* status); |
| |
  // Returns the layout of the input tensor.
| std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input, |
| TF_Status* status); |
| |
| TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs, |
| TFE_TensorHandle** indices, |
| TFE_TensorHandle** values, |
| TFE_TensorHandle** shapes, |
| const std::string& string_layout, |
| TF_Status* status); |
| |
| bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input, |
| TF_Status* status); |
| |
| private: |
  // If the `operation_name` of an op indicates a custom DTensor op (e.g.
  // CopyToMesh), handles the custom op separately instead of running the
  // default DTensor graph compilation.
| void MaybeHandleDTensorCustomOps( |
| const char* operation_name, const int num_inputs, |
| const TFE_OpAttrs* attributes, TFE_Context* context, |
| TFE_TensorHandle** inputs, int* num_outputs, TFE_TensorHandle** outputs, |
| bool* is_custom_dtensor_op, TF_Status* status); |
| |
  // Copies a non-DTensor eager tensor or a DTensor to the mesh specified by
  // `attributes`.
  // Currently, only copying to a replicated layout on the target mesh is
  // supported.
| void CopyToMesh(TFE_Context* context, int num_inputs, |
| TFE_TensorHandle** inputs, const TFE_OpAttrs* attributes, |
| TFE_TensorHandle** outputs, int* num_outputs, |
| TF_Status* status); |
| |
  // Updates output layouts for eager ops based on the same-shape policy.
| void UpdateOutputLayoutsWithSameShapePolicy( |
| const std::vector<PartialTensorShape>& global_output_shapes, |
| const absl::flat_hash_set<Mesh>& input_meshes, absl::string_view op_name, |
| tensorflow::Graph* graph, std::vector<const Layout*>* output_layouts, |
| TF_Status* status); |
| |
| const ExecutionFunctions* GetCachedFunction(tensorflow::Fprint128 cache_key) { |
| auto iter = function_cache_.find(cache_key); |
| if (iter == function_cache_.end()) { |
| return nullptr; |
| } |
| return &iter->second; |
| } |
| |
| const ExecutionFunctions* AddCachedFunction(tensorflow::Fprint128 cache_key, |
| ExecutionFunctions function) { |
| function_cache_.emplace(cache_key, std::move(function)); |
| return &function_cache_.find(cache_key)->second; |
| } |
| |
| // Takes the description of an operation and makes a function out of it with |
| // the same signature, running DTensor MLIR passes. Registers that function |
| // with `context`. `translated_function_name` is set to the name of the |
| // function. |
| // |
| // The resulting function expects a device ID as its first input. |
| void LowerToSPMDFunction(TFE_Context* context, |
| const std::vector<TensorWithLayout*>& inputs, |
| const DTensorOperation& doperation, |
| const TFE_OpAttrs* attributes, const int num_outputs, |
| const ExecutionFunctions** execution_functions, |
| TF_Status* status); |
| |
| // Execute a given function. |
| void ExecuteFunctionAndWait( |
| TFE_Context* context, const TranslatedFunction* function_ptr, |
| const MeshWithParallelDevice* parallel_device_mesh, |
| const std::vector<parallel_device::ParallelTensor*>& parallel_inputs, |
| const int64_t step_id, const TFE_OpAttrs* attributes, TF_Status* status); |
| |
  // Implements `Execute` for operations which aren't special-cased in
  // `Execute`.
| void ExecuteRegularOperation(TFE_Context* context, |
| const std::vector<TensorWithLayout*>& inputs, |
| const DTensorOperation& doperation, |
| const TFE_OpAttrs* attributes, int* num_outputs, |
| TFE_TensorHandle** outputs, TF_Status* status); |
| |
| // Wraps a TensorWithLayout into a TFE_TensorHandle. |
| TFE_TensorHandle* MakeLayoutTensorHandle(TFE_Context* context, |
| std::unique_ptr<TensorWithLayout> t, |
| TF_Status* status); |
| |
| void RecordInShapeLayoutCache(const TensorWithLayout& tensor); |
| |
| // Returns whether a given mesh is a remote mesh. |
| bool is_remote_mesh(const Mesh& mesh) const; |
| |
| // The name of the device (the custom device) |
| std::string name_; |
| // Mesh configs with matching parallel devices. |
| // |
| // For now we just consider the first entry added to dtensor_device as the |
| // default mesh. Before we reach an agreement on this, we'll leave it as is. |
| absl::flat_hash_map<Mesh, std::unique_ptr<MeshWithParallelDevice>> |
| mesh_to_device_map_; |
  // TODO(hthu): Consider whether we want to preserve the default_mesh
  // semantic. The current default mesh is consistent with default_layout_. If
  // default_layout_ is not set, it is equal to global_default_mesh_.
| const MeshWithParallelDevice* default_mesh_ = nullptr; |
  // The default mesh of a DTensorDevice, which cannot be modified once set.
| const MeshWithParallelDevice* global_default_mesh_ = nullptr; |
| // If the user has specified a default output layout. |
| absl::optional<Layout> default_layout_; |
| |
| // Determines whether tensors with a shape previously associated with only one |
| // layout use that layout if nothing else can be inferred. |
| bool same_shape_policy_enabled_; |
| |
| DTensorMlirPassRunner pass_runner_; |
| |
| struct CachedLayout { |
| // The first layout seen with this shape |
| Layout layout; |
| // Whether the layout is unique for this shape |
| bool is_unique; |
| }; |
| absl::flat_hash_map<int64_t, CachedLayout> shape_layout_cache_; |
| |
| // TODO(b/169348205) Support cache eviction if the cache gets bloated. |
| absl::flat_hash_map<tensorflow::Fprint128, ExecutionFunctions, |
| tensorflow::Fprint128Hasher> |
| function_cache_; |
| |
| // Coordinates cancelling ops across meshes on error. Must outlive any queued |
| // async op launches, so we only reset it after seeing a failure status. |
| std::unique_ptr<CancellationManager> cancellation_manager_; |
| |
  // Maps each function_mesh_fingerprint (based on the set of meshes involved)
  // to the number of times the function has executed. The
  // function_mesh_fingerprint and the counter together are used to generate
  // the step id, which is used for rendezvous creation.
| absl::flat_hash_map<uint64, uint64> func_mesh_fingerprint_to_step_counter_; |
| }; |
| |
| int64_t FingerprintShape(const absl::Span<const int64_t> shape) { |
| int64_t fprint = 0; |
| for (int64_t dim : shape) { |
| fprint = FingerprintCat64(fprint, dim); |
| } |
| return fprint; |
| } |
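
// A hedged worked example of the chaining above: for shape {2, 3},
// FingerprintShape returns FingerprintCat64(FingerprintCat64(0, 2), 3), and
// the empty shape fingerprints to 0.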
| |
| parallel_device::ParallelTensor* MeshWithParallelDevice::DeviceIDs( |
| TFE_Context* context, TF_Status* status) const { |
| if (device_ids_tensor_ == nullptr) { |
    // Global device IDs increase sequentially.
    //
    // This is an assumption of the DTensor software stack: the MLIR pass
    // relies on it to generate mesh coordinates for each core efficiently.
    //
    // The rule for setting local IDs, and the mapping from global IDs to real
    // physical core indices (e.g., on TPU), is unfortunately nontrivial. An
    // identity mapping is possible, but collective operation performance is
    // terrible in most cases.
    //
    // - For an ICI-connected TPU slice, see
    //   go/dtensor-device-assignment-summary for a guide on creating
    //   efficient core assignments that reach peak performance.
    //
    // The global ID to core assignment mapping is bridged by
    // `Mesh::tpu_core_ids()` and consumed by `UpdateTPUCompileMetadata`.
    //
    // - For a DCN-connected topology, different sections of the global IDs
    //   need to be mapped to their real physical cores separately, according
    //   to the runtime requirements. For example, for a 4x32 mesh in which
    //   the outer dimension is connected via DCN and the inner dimension via
    //   ICI, the device assignments for the inner dimension should typically
    //   form their own ring order (not the plain physical core index) within
    //   each sub-mesh, and the outer dimension should be assigned according
    //   to the real physical ring of DCN hosts.
    //
    // Note: changing this assumption requires adjusting the MLIR pass. One
    // possible approach is to take an N-D mapping vector for an N-D mesh and
    // look up the coordinates in MLIR, consulting the tensor layout as well,
    // rather than computing them on the fly.
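    //
    // As a worked example of the consecutiveness check below: global device
    // IDs {4, 5, 6, 7} pass, since ids[i] - i == ids[0] for every i, while
    // {0, 2, 3} fail at i == 1.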
| |
| // LINT.IfChange |
| for (int64_t i = 0; i < mesh_config_.global_device_ids().size(); ++i) { |
| if (mesh_config_.global_device_ids()[i] - i != |
| mesh_config_.global_device_ids()[0]) { |
| TF_SetStatus( |
| status, TF_INTERNAL, |
| absl::StrCat("Global device IDs should be consecutive: ", |
| absl::StrJoin(mesh_config_.global_device_ids(), ", ")) |
| .c_str()); |
| return nullptr; |
| } |
| } |
| // LINT.ThenChange(//tensorflow/dtensor/python/layout.py) |
| |
| // Local device IDs are a subset of global device IDs, arranged in device |
| // ordinal order. |
| std::vector<int32_t> ids; |
| for (int64_t id : mesh_config_.local_device_ids()) { |
| ids.push_back(id); |
| } |
| VLOG(1) << "Parallel device IDs: " << absl::StrJoin(ids, ", "); |
| device_ids_tensor_ = |
| parallel_device_->ScalarsFromSequence<int32_t>(ids, context, status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| } |
| return device_ids_tensor_.get(); |
| } |
| |
| int TensorWithLayoutNumDims(void* data, TF_Status* status) { |
| return reinterpret_cast<TensorWithLayout*>(data)->global_shape().size(); |
| } |
| |
| int64_t TensorWithLayoutDim(void* data, int dim_index, TF_Status* status) { |
| return reinterpret_cast<TensorWithLayout*>(data)->global_shape()[dim_index]; |
| } |
| |
| void TensorWithLayoutDeallocator(void* data) { |
| delete reinterpret_cast<TensorWithLayout*>(data); |
| } |
| |
| TF_Buffer* TensorWithLayoutSummarize(void* data, TF_Status* status) { |
| std::string summary = |
| reinterpret_cast<TensorWithLayout*>(data)->SummarizeValue(); |
| return TF_NewBufferFromString(summary.data(), summary.size()); |
| } |
| |
| TFE_TensorHandle* DTensorDevice::MakeLayoutTensorHandle( |
| TFE_Context* context, std::unique_ptr<TensorWithLayout> t, |
| TF_Status* status) { |
| TF_DataType dtype = t->dtype(); |
| TFE_CustomDeviceTensorHandleMethods handle_methods; |
| handle_methods.num_dims = &TensorWithLayoutNumDims; |
| handle_methods.dim = &TensorWithLayoutDim; |
| handle_methods.deallocator = &TensorWithLayoutDeallocator; |
| handle_methods.summarize = &TensorWithLayoutSummarize; |
| return TFE_NewCustomDeviceTensorHandle(context, name_.c_str(), dtype, |
| /*data=*/t.release(), handle_methods, |
| status); |
| } |
| |
| void DTensorDevice::RecordInShapeLayoutCache(const TensorWithLayout& tensor) { |
| auto existing = shape_layout_cache_.insert( |
| {FingerprintShape(tensor.global_shape()), |
| CachedLayout{tensor.layout(), /*is_unique=*/true}}); |
| |
| if (!existing.second) { |
| // There is an entry already; if the layout doesn't match we should record |
| // the fact that it's not unique. |
| if (tensor.layout() != existing.first->second.layout) { |
| existing.first->second.is_unique = false; |
| } |
| } |
| } |
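
// A hedged sketch of the cache semantics above (shapes and layout strings are
// illustrative):
//
//   RecordInShapeLayoutCache(t1);  // shape {8, 4}, layout "x,unsharded"
//   RecordInShapeLayoutCache(t2);  // shape {8, 4}, layout "unsharded,y"
//
// After the second call, the entry for shape {8, 4} still holds t1's layout
// but is marked non-unique, so UpdateOutputLayoutsWithSameShapePolicy will
// not use it as a hint.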
| |
| bool DTensorDevice::is_remote_mesh(const Mesh& mesh) const { |
  // An empty mesh might be assigned to VarHandleOp during the DTensor MLIR
  // lowering pass. Decide whether the empty mesh is remote based on the
  // current default mesh.
| return mesh.is_remote() || |
| (mesh.IsEmpty() && default_mesh_->mesh_config().is_remote()); |
| } |
| |
| StatusOr<NameAttrList> FetchAttributes(const TFE_OpAttrs* attributes) { |
| // TODO(allenl): Should we just give up on the public C API to save on |
| // serialization/deserialization? We need all of the attributes and to treat |
| // them generically, which isn't going to be pleasant with typed attribute |
| // methods. |
| std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> serialized_attributes( |
| TF_NewBuffer(), TF_DeleteBuffer); |
| |
| TF_Status* status = TF_NewStatus(); |
| TFE_OpAttrsSerialize(attributes, serialized_attributes.get(), status); |
| if (TF_GetCode(status) == TF_OK) { |
| TF_DeleteStatus(status); |
| } else { |
| Status failure_status = StatusFromTF_Status(status); |
| TF_DeleteStatus(status); |
| return failure_status; |
| } |
| |
| NameAttrList name_and_attrs; |
| if (!name_and_attrs.ParseFromArray(serialized_attributes->data, |
| serialized_attributes->length)) { |
| return tensorflow::errors::Unknown("Could not parse attributes"); |
| } |
| return name_and_attrs; |
| } |
| |
| StatusOr<Layout> FetchLayoutFromAttributes(const TFE_OpAttrs* attributes, |
| absl::string_view attribute_name) { |
| // Get attributes. |
| TF_ASSIGN_OR_RETURN(NameAttrList name_and_attrs, FetchAttributes(attributes)); |
| |
| // Get layout string from attributes. |
| absl::string_view layout_str = |
| name_and_attrs.attr().find(std::string(attribute_name))->second.s(); |
| |
| // This would probably be slow at the moment without caching. |
| // We should consider making this faster in the future. |
| return Layout::FromString(string(layout_str)); |
| } |
| |
| std::string DTensorDevice::FetchLayout(TFE_Context* context, |
| TFE_TensorHandle* input, |
| TF_Status* status) { |
| VLOG(1) << "Checking layout..."; |
| const char* input_device = TFE_TensorHandleDeviceName(input, status); |
| if (input_device != name_) { |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, |
| "FetchLayout expects a tensor placed on the layout device."); |
| return {}; |
| } |
| TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
| TFE_TensorHandleDevicePointer(input, status)); |
| if (TF_GetCode(status) != TF_OK) return {}; |
| return t->layout().ToString(); |
| } |
| |
| std::vector<TFE_TensorHandle*> DTensorDevice::Unpack(TFE_Context* context, |
| TFE_TensorHandle* input, |
| TF_Status* status) { |
| std::vector<TFE_TensorHandle*> outputs; |
| |
| const char* input_device = TFE_TensorHandleDeviceName(input, status); |
| if (TF_GetCode(status) != TF_OK) return outputs; |
| if (input_device != name_) { |
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| absl::StrCat( |
| "DTensorUnpack expects a tensor placed on the DTensor device: ", |
| name_, ", but input was placed on device: ", input_device) |
| .c_str()); |
| return outputs; |
| } |
| TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
| TFE_TensorHandleDevicePointer(input, status)); |
| if (TF_GetCode(status) != TF_OK) return outputs; |
| |
| if (is_remote_mesh(t->mesh().mesh_config())) { |
| TF_SetStatus(status, TF_UNIMPLEMENTED, |
| "DTensorUnpack is not supported on a remote mesh."); |
| return outputs; |
| } |
| const int output_size = t->num_tensors(); |
| outputs.resize(output_size); |
| |
| for (int output_index = 0; output_index < output_size; ++output_index) { |
| outputs[output_index] = |
| TFE_TensorHandleCopySharingTensor(t->get_tensor(output_index), status); |
| if (TF_GetCode(status) != TF_OK) { |
| return outputs; |
| } |
| } |
| return outputs; |
| } |
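
// A hedged usage sketch for Unpack (handle names are illustrative): given a
// handle `dtensor_handle` placed on this DTensor device,
//
//   std::vector<TFE_TensorHandle*> components =
//       dtensor_device->Unpack(context, dtensor_handle, status);
//
// returns one handle per component tensor, each sharing the underlying
// per-device buffer; the caller is responsible for deleting the returned
// handles.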
| |
| void DTensorDevice::MaybeHandleDTensorCustomOps( |
| const char* operation_name, const int num_inputs, |
| const TFE_OpAttrs* attributes, TFE_Context* context, |
| TFE_TensorHandle** inputs, int* num_outputs, TFE_TensorHandle** outputs, |
| bool* is_custom_dtensor_op, TF_Status* status) { |
| *is_custom_dtensor_op = true; |
| if (operation_name == std::string("_EagerConst")) { |
    // Op-by-op const has no obvious layout. DTensor skips an SPMD expansion
    // and instead relies on copying the value onto a mesh when it is used
    // later.
| std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op( |
| TFE_NewOp(context, operation_name, status), TFE_DeleteOp); |
| if (TF_GetCode(status) != TF_OK) return; |
| for (int input_index = 0; input_index < num_inputs; ++input_index) { |
| TFE_OpAddInput(op.get(), inputs[input_index], status); |
| if (TF_GetCode(status) != TF_OK) return; |
| } |
| TFE_OpAddAttrs(op.get(), attributes); |
| TFE_Execute(op.get(), outputs, num_outputs, status); |
| return; |
| } |
| if (operation_name == std::string("CopyToMesh")) { |
| CopyToMesh(context, num_inputs, inputs, attributes, outputs, num_outputs, |
| status); |
| return; |
| } |
| |
| *is_custom_dtensor_op = false; |
| } |
| |
| void DTensorDevice::CopyToMesh(TFE_Context* context, int num_inputs, |
| TFE_TensorHandle** inputs, |
| const TFE_OpAttrs* attributes, |
| TFE_TensorHandle** outputs, int* num_outputs, |
| TF_Status* status) { |
| if (num_inputs != 1) { |
| RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
| "DTensor CopyToMesh requires exactly 1 input."); |
| } |
| if (*num_outputs < 1) { |
| RETURN_STATUS(status, TF_INTERNAL, |
| "DTensor CopyToMesh must have output buffer to allocate at " |
| "least 1 output."); |
| } |
| |
| // Assign layout. |
| StatusOr<Layout> target_layout_or = |
| FetchLayoutFromAttributes(attributes, kQualifiedLayoutAttr); |
| if (!target_layout_or.ok()) { |
| RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
| "DTensor CopyToMesh requires valid layout attribute for " |
| "destination DTensor."); |
| } |
| |
| const Layout target_layout = *target_layout_or; |
| const Mesh& target_mesh = target_layout.mesh(); |
| |
| // TODO(b/193443769): Support sharded layout for eager copy to mesh. |
| if (!target_layout.IsFullyReplicated()) { |
| RETURN_STATUS(status, TF_UNIMPLEMENTED, |
| "Target layout of DTensor CopyToMesh must be replicated. " |
| "Consider changing the target layout to replicated layout or " |
| "file a bug to the DTensor team (b/193443769)."); |
| } |
| |
| TFE_TensorHandle* input_tensor = inputs[0]; |
| |
  // Check that, if the input tensor is a DTensor, its layout is replicated.
| const char* input_device = TFE_TensorHandleDeviceName(input_tensor, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| |
| if (name_ == input_device) { |
| // Handle input which is on DTensor device already. |
| TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
| TFE_TensorHandleDevicePointer(input_tensor, status)); |
| if (TF_GetCode(status) != TF_OK) return; |
| |
| if (!t->layout().IsFullyReplicated()) |
| RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
| "Input tensor to CopyToMesh must be replicated DTensor or " |
| "normal eager Tensor."); |
| |
    // If the input to CopyToMesh is a DTensor, use its first local tensor as
    // the input tensor handle to invoke the copy.
| input_tensor = t->get_tensor(0); |
| } |
| |
| auto it = mesh_to_device_map_.find(target_mesh); |
| if (it == mesh_to_device_map_.end()) { |
| RETURN_STATUS( |
| status, TF_INTERNAL, |
| "DTensor CopyToMesh target mesh is not registered. Meshes should be " |
| "automatically registered. Please file a bug. (component id: 833864)"); |
| } |
| |
| const MeshWithParallelDevice* target_parallel_mesh = it->second.get(); |
| |
  // Broadcast the non-DTensor value to a DTensor.
| std::unique_ptr<TensorWithLayout> wrapper = TensorWithLayout::Broadcast( |
| context, input_tensor, *target_parallel_mesh, name_, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| |
| RecordInShapeLayoutCache(*wrapper); |
| *num_outputs = 1; |
| *outputs = MakeLayoutTensorHandle(context, std::move(wrapper), status); |
| } |
| |
| namespace { |
| |
| // Verifies that all components have the same dtype and shape. |
| // The component shape will be set upon success. |
| void VerifyPackTensorShapeAndDtype( |
| std::vector<parallel_device::TensorHandlePtr>& components, |
| std::vector<int64_t>* component_shape, TF_Status* status) { |
| TF_DataType dtype = TFE_TensorHandleDataType(components[0].get()); |
| auto size = TFE_TensorHandleNumDims(components[0].get(), status); |
| if (TF_GetCode(status) != TF_OK) return; |
| component_shape->clear(); |
| component_shape->reserve(size); |
| for (int i = 0; i < size; ++i) { |
| component_shape->push_back( |
| TFE_TensorHandleDim(components[0].get(), i, status)); |
| if (TF_GetCode(status) != TF_OK) return; |
| } |
| |
| // Verify that the TensorHandle's shape and dtype match all of the component |
| // shapes and dtypes. |
| for (const auto& component : components) { |
| for (int i = 0; i < component_shape->size(); ++i) { |
| int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| if (tensor_dim != (*component_shape)[i]) { |
| TF_SetStatus(status, TF_UNIMPLEMENTED, |
| "Components of a PackedTensor must currently all have " |
| "the same shape"); |
| return; |
| } |
| if (TFE_TensorHandleDataType(component.get()) != dtype) { |
| TF_SetStatus(status, TF_INTERNAL, |
| "Components of a PackedTensor must all have " |
| "the same dtype"); |
| return; |
| } |
| } |
| } |
| } |
| |
// Verifies that all TensorHandles have rank `expected_rank` and, if
// `expected_dtype` is non-null, dtype `*expected_dtype`.
| void VerifyTensorRankAndDType(TFE_TensorHandle** tensors, int num_input, |
| int expected_rank, TF_DataType* expected_dtype, |
| TF_Status* status) { |
| for (int i = 0; i < num_input; ++i) { |
| auto actual_rank = TFE_TensorHandleNumDims(tensors[i], status); |
| if (TF_GetCode(status) != TF_OK) |
| RETURN_STATUS(status, TF_INTERNAL, "Error getting rank of tensor."); |
| if (actual_rank != expected_rank) |
| RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
| "Rank of tensor did not match the expected rank."); |
| if (expected_dtype != nullptr && |
| TFE_TensorHandleDataType(tensors[i]) != *expected_dtype) |
| RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
| "Dtype of tensor did not match the expected dtype."); |
| } |
| } |
| |
| } // namespace |
| |
| TFE_TensorHandle* DTensorDevice::Pack(TFE_Context* context, int num_inputs, |
| TFE_TensorHandle** inputs, |
| const std::string& string_layout, |
| TF_Status* status) { |
| if (num_inputs < 1) { |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, |
| "DTensorPack requires 1 or more inputs"); |
| return nullptr; |
| } |
| StatusOr<Layout> target_layout = Layout::FromString(string_layout); |
| if (!target_layout.ok()) { |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, |
| "Failed to parse layout from string layout"); |
| return nullptr; |
| } |
| const Mesh& target_mesh = target_layout->mesh(); |
| const MeshWithParallelDevice* target_parallel_device = |
| mesh_to_device_map_[target_mesh].get(); |
| if (target_parallel_device == nullptr) { |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, |
| absl::StrCat("Required mesh : ", target_mesh.ToString(), |
| "is not registered with DTensor ") |
| .c_str()); |
| return nullptr; |
| } |
| |
| std::unique_ptr<TensorWithLayout> packed_tensor; |
| if (is_remote_mesh(target_parallel_device->mesh_config())) { |
| // Create a dummy output for DTensorPack if inputs are on a remote mesh. |
| TF_DataType dtype = TFE_TensorHandleDataType(inputs[0]); |
| auto size = TFE_TensorHandleNumDims(inputs[0], status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| std::vector<int64_t> component_shape; |
| component_shape.reserve(size); |
| for (int i = 0; i < size; ++i) { |
| component_shape.push_back(TFE_TensorHandleDim(inputs[0], i, status)); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| } |
| packed_tensor = TensorWithLayout::Dummy( |
| component_shape, dtype, *target_parallel_device, *target_layout); |
| |
| } else { |
| auto local_devices = target_parallel_device->mesh_config().local_devices(); |
| |
| if (num_inputs != |
| target_parallel_device->parallel_device().num_underlying_devices()) { |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, |
| absl::StrCat("The dtensor device ", name_, " expected ", |
| local_devices.size(), |
| " inputs to DTensorPack, but got ", num_inputs) |
| .c_str()); |
| return nullptr; |
| } |
| |
| std::vector<parallel_device::TensorHandlePtr> components; |
| components.reserve(num_inputs); |
| for (int i = 0; i < num_inputs; ++i) { |
| TFE_TensorHandle* input = inputs[i]; |
| const char* input_device = TFE_TensorHandleDeviceName(input, status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| if (name_ == input_device) { |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, |
| "Does not support packing a Tensor that is already on " |
| "dtensor device"); |
| return nullptr; |
| } |
| // If `input` is on the target device, this creates a new handle sharing |
| // the underlying data; otherwise, async copies are invoked. |
| components.emplace_back(TFE_TensorHandleCopyToDevice( |
| input, context, local_devices[i].c_str(), status)); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| } |
| |
| std::vector<int64_t> component_shape; |
| VerifyPackTensorShapeAndDtype(components, &component_shape, status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| |
| std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = |
| parallel_device::ParallelTensor::FromTensorHandles( |
| target_parallel_device->parallel_device(), std::move(components), |
| status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| |
| if (target_layout->rank() != component_shape.size()) { |
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| absl::StrCat( |
| "Packed layout should have the same rank as the rank for each " |
| "component. The rank of each component is: ", |
| component_shape.size(), |
| ", while layout has rank: ", target_layout->rank(), |
| "\nLayout: ", target_layout->ToString(), "\n") |
| .c_str()); |
| return nullptr; |
| } |
| |
| packed_tensor = |
| TensorWithLayout::Wrap(std::move(parallel_tensor), |
| *target_parallel_device, *target_layout) |
| .ValueOrDie(); |
| } |
| |
| RecordInShapeLayoutCache(*packed_tensor); |
| TFE_TensorHandle* output = |
| MakeLayoutTensorHandle(context, std::move(packed_tensor), status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| return output; |
| } |
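
// A hedged usage sketch for Pack (names and the layout string are
// illustrative): with one component handle per local device of a two-device
// mesh,
//
//   TFE_TensorHandle* components[2] = {handle_0, handle_1};
//   TFE_TensorHandle* dtensor = dtensor_device->Pack(
//       context, /*num_inputs=*/2, components, layout_string, status);
//
// copies each component to its corresponding local device if needed and
// wraps them in a single DTensor handle; `layout_string` must parse via
// Layout::FromString and have rank equal to the component rank.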
| |
| TFE_TensorHandle* DTensorDevice::SparsePack( |
| TFE_Context* context, int num_inputs, TFE_TensorHandle** indices, |
| TFE_TensorHandle** values, TFE_TensorHandle** shapes, |
| const std::string& string_layout, TF_Status* status) { |
| StatusOr<Layout> target_layout = Layout::FromString(string_layout); |
| if (!target_layout.ok()) { |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, |
| "Failed to parse layout from string layout"); |
| return nullptr; |
| } |
| const Mesh& target_mesh = target_layout->mesh(); |
| const MeshWithParallelDevice* target_parallel_device = |
| mesh_to_device_map_[target_mesh].get(); |
| if (target_parallel_device == nullptr) { |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, |
| absl::StrCat("Required mesh : ", target_mesh.ToString(), |
| "is not registered with DTensor ") |
| .c_str()); |
| return nullptr; |
| } |
| |
| TF_DataType tf_int64 = TF_INT64; |
| // Verify rank and dtype of shapes. |
| VerifyTensorRankAndDType(shapes, num_inputs, /*expected_rank=*/1, |
| /*expected_dtype=*/&tf_int64, status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| |
| // Verify rank and dtype of indices. |
| VerifyTensorRankAndDType(indices, num_inputs, /*expected_rank=*/2, |
| /*expected_dtype=*/&tf_int64, status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| |
| // Verify rank of values. |
| VerifyTensorRankAndDType(values, num_inputs, /*expected_rank=*/1, |
| /*expected_dtype=*/nullptr, status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| |
| // Compute the local shape from a shape tensor. |
| std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> shape_tensor( |
| TFE_TensorHandleResolve(shapes[0], status), TF_DeleteTensor); |
| if (TF_GetCode(status) != TF_OK) { |
| TF_SetStatus( |
| status, TF_GetCode(status), |
| absl::StrCat("Error resolving the tensor handle of shape tensor" |
| ". Original message: ", |
| TF_Message(status)) |
| .c_str()); |
| return nullptr; |
| } |
| int shape_tensor_size = TFE_TensorHandleDim(shapes[0], 0, status); |
| if (TF_GetCode(status) != TF_OK || shape_tensor_size <= 0) { |
| TF_SetStatus(status, TF_GetCode(status), |
| absl::StrCat("Error computing the num dims of shape tensor", |
| TF_Message(status)) |
| .c_str()); |
| return nullptr; |
| } |
| |
| const int64_t* data = |
| static_cast<int64_t*>(TF_TensorData(shape_tensor.get())); |
| std::vector<int64_t> local_shape(data, data + shape_tensor_size); |
| if (local_shape.size() != target_layout->rank()) { |
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| absl::StrCat( |
| "Packed layout should have the same rank as the rank for each " |
| "component. The rank of each component is: ", |
| local_shape.size(), |
| ", while layout has rank: ", target_layout->rank(), |
| "\nLayout: ", target_layout->ToString(), "\n") |
| .c_str()); |
| return nullptr; |
| } |
| |
| // Create the SparseTensorWithLayout. |
| std::unique_ptr<TensorWithLayout> packed_tensor; |
| if (is_remote_mesh(target_parallel_device->mesh_config())) { |
| // Create a dummy SparseTensorWithLayout. |
| packed_tensor = SparseTensorWithLayout::Dummy( |
| local_shape, *target_parallel_device, target_layout.ValueOrDie()); |
| } else { |
    // Parse the indices, values, and dense_shape tensors, put them into
    // parallel tensors, and then pack them into a single
    // SparseTensorWithLayout.
| auto local_devices = target_parallel_device->mesh_config().local_devices(); |
| |
| std::vector<parallel_device::TensorHandlePtr> indices_components; |
| std::vector<parallel_device::TensorHandlePtr> values_components; |
| std::vector<parallel_device::TensorHandlePtr> dense_shapes_components; |
| |
    // Loop over the three component vectors so that indices, values, and
    // shapes are all packed by the same code path.
| std::vector<std::vector<parallel_device::TensorHandlePtr>*> components{ |
| &indices_components, &values_components, &dense_shapes_components}; |
| std::vector<TFE_TensorHandle**> input_vectors{indices, values, shapes}; |
| for (int component_index = 0; component_index < 3; ++component_index) { |
| components[component_index]->reserve(num_inputs); |
| TFE_TensorHandle** inputs = input_vectors[component_index]; |
| for (int i = 0; i < num_inputs; ++i) { |
| const char* input_device = |
| TFE_TensorHandleDeviceName(inputs[i], status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| if (name_ == input_device) { |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, |
| "Does not support packing a Tensor that is already on " |
| "dtensor device."); |
| return nullptr; |
| } |
| |
| components[component_index]->emplace_back(TFE_TensorHandleCopyToDevice( |
| inputs[i], context, local_devices[i].c_str(), status)); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| } |
| } |
| std::unique_ptr<parallel_device::ParallelTensor> parallel_indices_tensor = |
| parallel_device::ParallelTensor::FromTensorHandles( |
| target_parallel_device->parallel_device(), |
| std::move(indices_components), status); |
| |
| std::unique_ptr<parallel_device::ParallelTensor> parallel_values_tensor = |
| parallel_device::ParallelTensor::FromTensorHandles( |
| target_parallel_device->parallel_device(), |
| std::move(values_components), status); |
| |
| std::unique_ptr<parallel_device::ParallelTensor> |
| parallel_dense_shapes_tensor = |
| parallel_device::ParallelTensor::FromTensorHandles( |
| target_parallel_device->parallel_device(), |
| std::move(dense_shapes_components), status); |
| |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| packed_tensor = |
| SparseTensorWithLayout::Wrap(std::move(parallel_indices_tensor), |
| std::move(parallel_values_tensor), |
| std::move(parallel_dense_shapes_tensor), |
| *target_parallel_device, |
| target_layout.ValueOrDie(), local_shape) |
| .ValueOrDie(); |
| } |
| |
| RecordInShapeLayoutCache(*packed_tensor); |
| TFE_TensorHandle* output = |
| MakeLayoutTensorHandle(context, std::move(packed_tensor), status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| return output; |
| } |
| |
| bool DTensorDevice::IsSparseDTensor(TFE_Context* context, |
| TFE_TensorHandle* input, |
| TF_Status* status) { |
| const char* input_device = TFE_TensorHandleDeviceName(input, status); |
| if (input_device != name_) { |
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| "DTensorSparseUnpack expects a tensor placed on the DTensor device."); |
| return false; |
| } |
| TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
| TFE_TensorHandleDevicePointer(input, status)); |
| if (TF_GetCode(status) != TF_OK) return false; |
| return t->tensor_type() == TensorType::kSparse; |
| } |
| |
| void DTensorDevice::UpdateOutputLayoutsWithSameShapePolicy( |
| const std::vector<PartialTensorShape>& global_output_shapes, |
| const absl::flat_hash_set<Mesh>& input_meshes, absl::string_view op_name, |
| tensorflow::Graph* graph, std::vector<const Layout*>* output_layouts, |
| TF_Status* status) { |
| if (!same_shape_policy_enabled_) return; |
| // Simply do not hint if inputs span across multiple meshes. |
| if (input_meshes.size() > 1) return; |
| |
| for (Node* node : graph->op_nodes()) { |
| if (!node->IsRetval()) { |
| continue; |
| } |
| int output_index; |
| RETURN_C_STATUS_IF_NOT_OK( |
| GetNodeAttr(node->attrs(), "index", &output_index), status); |
| if (output_layouts->at(output_index)) { |
| continue; |
| } |
| |
| const auto& global_output_shape = global_output_shapes.at(output_index); |
| const Layout* layout = nullptr; |
    // TODO(b/180022708): This is useful information; we should be
    // able to hint to layout propagation without making it a hard
    // requirement.
    //
    // Special cases at the moment:
    // - Relayout needs an exemption.
    // - VarHandleOp does not need a hint. VarHandleOp has a scalar shape, so
    //   its layout is trivial. On the other hand, downstream systems "think"
    //   a Variable has the same shape as the value it points to, so
    //   providing a layout based on the (scalar) VarHandleOp might confuse
    //   them.
| if (op_name != std::string("Relayout") && |
| op_name != std::string("VarHandleOp")) { |
| // TODO(b/162009702): Support matching between partially-known shapes. |
| if (global_output_shape.IsFullyDefined()) { |
| gtl::InlinedVector<int64, 4> shape_vector( |
| global_output_shape.dim_sizes()); |
| auto layout_iterator = |
| shape_layout_cache_.find(FingerprintShape(shape_vector)); |
| if (layout_iterator != shape_layout_cache_.end() && |
| layout_iterator->second.is_unique) { |
| // We have a cached layout for this shape. Send it to MLIR. |
| layout = &layout_iterator->second.layout; |
| VLOG(3) << op_name << ": found a cached layout for shape " |
| << global_output_shape.DebugString() << ": \"" |
| << layout->ToString() << "\""; |
| if (input_meshes.empty() && |
| layout->mesh() != default_mesh_->mesh_config()) { |
| VLOG(3) << "But we can't infer a input mesh and cached layout: " |
| << "mesh \"" << (layout->mesh().ToString()) << " " |
| << "is different than the default mesh : \"" |
| << default_mesh_->mesh_config().ToString() << "\"\n" |
| << "Not applying the cached layout."; |
| } else if (!input_meshes.empty() && |
| layout->mesh() != *input_meshes.begin()) { |
| VLOG(3) |
| << "But the layout mesh is different than the executing mesh: " |
| << "\"" << (*input_meshes.begin()).ToString() << "\"\n" |
| << "Not applying the cached layout."; |
| } else { |
| (*output_layouts)[output_index] = layout; |
| node->AddAttr(kDefaultLayoutAttr, layout->ToString()); |
| } |
| } else if (layout_iterator == shape_layout_cache_.end()) { |
| VLOG(3) << op_name << ": no cached layout found for shape " |
| << global_output_shape.DebugString(); |
| } else { |
| VLOG(3) << op_name << ": found multiple layouts for shape " |
| << global_output_shape.DebugString(); |
| } |
| } else { |
| VLOG(3) << op_name |
| << ": not applying same-shape-same-layout due to " |
| "not-fully-known shape " |
| << global_output_shape.DebugString(); |
| } |
| } |
| } |
| } |
| |
// Cache key computation should consider all features of an op that affect
// the SPMD lowering. The cache keys of two ops must be different if the
// translated functions are different:
// - op name and attrs,
// - input shapes and layouts,
// - default layouts of outputs.
| tensorflow::Fprint128 CacheKeyForGraph( |
| const DTensorOperation& doperation, const NameAttrList& attributes, |
| const std::vector<TensorWithLayout*>& inputs, |
| const std::vector<const Layout*>& output_layouts) { |
| tensorflow::Fprint128 cache_key = tensorflow::Fingerprint128(doperation.name); |
| std::string serialized; |
| SerializeToStringDeterministic(attributes, &serialized); |
| cache_key = |
| FingerprintCat128(cache_key, tensorflow::Fingerprint128(serialized)); |
| for (const auto* input : inputs) { |
| cache_key = FingerprintCat128(cache_key, input->CacheKey()); |
| } |
| |
| for (int output_index = 0; output_index < output_layouts.size(); |
| ++output_index) { |
| if (output_layouts[output_index]) { |
| cache_key = FingerprintCat128(cache_key, output_index); |
| cache_key = FingerprintCat128( |
| cache_key, |
| tensorflow::Fingerprint128(output_layouts[output_index]->ToString())); |
| } |
| } |
| return cache_key; |
| } |
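
// A hedged illustration of the composition above: two eager calls to the same
// op with identical attributes but differently laid-out inputs (say,
// "x,unsharded" vs. "unsharded,y") produce different input CacheKey()s, hence
// different graph cache keys, and each triggers its own SPMD lowering.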
| |
// From `graph`, which contains the computation for all meshes, extracts the
// computation for the mesh specified in `function`. The returned graph is a
// clone containing ops only for single-mesh execution.
| StatusOr<std::unique_ptr<Graph>> SelectGraphToExecute( |
| const TranslatedFunction& function, const Graph& graph, |
| std::string* stateful_partitioned_call_name) { |
| auto new_graph = absl::make_unique<Graph>(graph.flib_def()); |
| CopyGraph(graph, new_graph.get()); |
| std::vector<Node*> arg_nodes; |
| std::vector<Node*> retval_nodes; |
| for (Node* node : new_graph->nodes()) { |
| if (node->IsArg()) arg_nodes.emplace_back(node); |
| if (node->IsRetval()) retval_nodes.emplace_back(node); |
| } |
| |
| // Remove irrelevant function calls. |
| for (Node* node : new_graph->nodes()) { |
| if (node->op_def().name() != "StatefulPartitionedCall") continue; |
| |
| if (node->name() != function.node_to_execute->name()) { |
| // Remove function call that does not match mesh specification and all |
| // output retval nodes connected to the function call node. |
| std::queue<Node*> nodes_to_remove; |
| nodes_to_remove.push(node); |
| |
| while (!nodes_to_remove.empty()) { |
| Node* n = nodes_to_remove.front(); |
| for (const Edge* out_edge : n->out_edges()) { |
| if (out_edge->IsControlEdge()) continue; |
| Node* out_node = out_edge->dst(); |
| if (!out_node->IsSink()) nodes_to_remove.push(out_node); |
| } |
| if (n->IsRetval()) { |
| auto pos = std::find(retval_nodes.begin(), retval_nodes.end(), n); |
| TF_RET_CHECK(pos != retval_nodes.end()); |
| retval_nodes.erase(pos); |
| } |
| nodes_to_remove.pop(); |
| new_graph->RemoveNode(n); |
| } |
| } |
| } |
| |
| *stateful_partitioned_call_name = function.node_to_execute->name(); |
| VLOG(1) << "Selected call " << *stateful_partitioned_call_name; |
| |
| // Remove unused arg nodes in graph. |
| for (auto it = arg_nodes.begin(); it != arg_nodes.end(); it++) { |
| Node* arg_node = *it; |
| bool arg_unused = true; |
| for (const Edge* e : arg_node->out_edges()) { |
| if (e->dst()->IsOp()) { |
| arg_unused = false; |
| } |
| } |
| if (!arg_unused) continue; |
| |
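    // Remove the unused arg from the graph and from `arg_nodes`.
    // `erase(it--)` passes the current iterator to erase while stepping `it`
    // back one slot, so the loop's `it++` next visits the element that
    // shifted into the erased position.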
| new_graph->RemoveNode(arg_node); |
| arg_nodes.erase(it--); |
| } |
| |
| // Reset index attributes for arg and retval nodes. |
| for (Node* n : new_graph->nodes()) { |
| // Reset arg node index attributes. |
| if (n->IsArg()) { |
| auto pos = std::find(arg_nodes.begin(), arg_nodes.end(), n); |
| TF_RET_CHECK(pos != arg_nodes.end()); |
| const int new_index = std::distance(arg_nodes.begin(), pos); |
| n->AddAttr("index", new_index); |
| } |
| |
| // Reset retval nodes index attributes. |
| if (n->IsRetval()) { |
| auto retval_pos = std::find(retval_nodes.begin(), retval_nodes.end(), n); |
| TF_RET_CHECK(retval_pos != retval_nodes.end()); |
| const int new_index = std::distance(retval_nodes.begin(), retval_pos); |
| n->AddAttr("index", new_index); |
| } |
| } |
| |
| VLOG(4) << tensorflow::DumpGraphToFile("selected_graph_to_execute_", |
| *new_graph); |
| |
| return new_graph; |
| } |
| |
// For each mesh computation in `execution_functions`, adds the processed
// graph to run to the function definition library.
| Status AddExecutionFunctionDefsToFunctionDefLibrary( |
| const absl::flat_hash_set<Node*>& control_ret_nodes, TFE_Context* context, |
| const Graph& graph, ExecutionFunctions* execution_functions) { |
  // Note: we use node names instead of node pointers for comparison because
  // node addresses in the new graph differ from those in the original graph.
| absl::flat_hash_set<std::string> control_ret_names; |
| for (auto* n : control_ret_nodes) { |
| control_ret_names.emplace(n->name()); |
| } |
| for (TranslatedFunction& function : execution_functions->function_list) { |
| std::string selected_call_node_name; |
    // TODO(bfontain): We should just try to call the functions directly
    // rather than wrap them.
    // Construct a graph that executes only the computation for `function`.
| TF_ASSIGN_OR_RETURN( |
| std::unique_ptr<Graph> new_graph, |
| SelectGraphToExecute(function, graph, &selected_call_node_name)); |
| VLOG(4) << tensorflow::DumpGraphToFile("selected_graph_", *new_graph); |
| |
    // Add a unique identifier based on the function we are executing to the
    // function/graph, and convert the graph to a FunctionDef.
| NameAttrList func; |
| TF_RETURN_IF_ERROR( |
| GetNodeAttr(function.node_to_execute->attrs(), "f", &func)); |
| |
| static std::atomic<int64_t> unique_function_number(0); |
| function.translated_function_name = |
| absl::StrCat(func.name(), "_", unique_function_number.fetch_add(1)); |
| auto control_ret_node_names = |
| [&control_ret_names, &selected_call_node_name]( |
| const Node* node) -> absl::optional<std::string> { |
| // Add the stateful partitioned call node as a control return as we need |
| // to process any control deps inside the inner function. |
| if (control_ret_names.contains(node->name()) || |
| node->name() == selected_call_node_name) { |
| return node->name(); |
| } |
| return absl::nullopt; |
| }; |
| |
| tensorflow::FunctionDef to_run; |
| TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef( |
| *new_graph, function.translated_function_name, control_ret_node_names, |
| &to_run)); |
| |
| for (const auto& out : to_run.signature().output_arg()) { |
| function.output_dtypes.emplace_back(static_cast<TF_DataType>(out.type())); |
| } |
| |
| AddDTensorFunctionAttr(to_run); |
| TF_RETURN_IF_ERROR(tensorflow::unwrap(context)->AddFunctionDef(to_run)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| void DTensorDevice::LowerToSPMDFunction( |
| TFE_Context* context, const std::vector<TensorWithLayout*>& inputs, |
| const DTensorOperation& doperation, const TFE_OpAttrs* attributes, |
| const int num_outputs, const ExecutionFunctions** execution_functions, |
| TF_Status* status) { |
| profiler::TraceMe activity( |
| [&] { return "DTensorDevice::LowerToSPMDFunction"; }, |
| profiler::TraceMeLevel::kInfo); |
| FunctionLibraryDefinition* flib_def = |
| tensorflow::unwrap(context)->FuncLibDef(); |
| auto graph(absl::make_unique<tensorflow::Graph>(flib_def)); |
| NameAttrList eager_attributes; |
| ASSIGN_OR_RETURN_C_STATUS(eager_attributes, FetchAttributes(attributes), |
| status); |
| |
| std::vector<PartialTensorShape> global_output_shapes; |
| std::vector<const Layout*> output_layouts; |
| const FunctionDef* function_def = doperation.function_def; |
| if (!function_def) { |
| // Output layouts of an eager op (e.g. fill) must be inferred before cache |
| // key computation, since they might depend on the current DTensorDevice |
| // state. |
| Status s = PrepareGraphForMlir( |
| inputs, doperation, *flib_def, eager_attributes, default_layout_, |
| graph.get(), &global_output_shapes, &output_layouts); |
| RETURN_C_STATUS_IF_NOT_OK(s, status); |
| |
    // Find all meshes that the inputs lie on.
| absl::flat_hash_set<Mesh> input_meshes; |
| for (const TensorWithLayout* tensor : inputs) { |
| if (!tensor->layout().mesh().IsEmpty()) { |
| input_meshes.insert(tensor->layout().mesh()); |
| } |
| } |
| // Currently we only provide layout hints for op-by-op, since |
| // they interact badly with layout propagation. |
| UpdateOutputLayoutsWithSameShapePolicy(global_output_shapes, input_meshes, |
| doperation.name, graph.get(), |
| &output_layouts, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| } |
| |
| const tensorflow::Fprint128 cache_key = |
| CacheKeyForGraph(doperation, eager_attributes, inputs, output_layouts); |
| *execution_functions = GetCachedFunction(cache_key); |
| if (*execution_functions != nullptr) { |
| return; |
| } else if (function_def) { |
| LOG(INFO) << "DTensor cache key lookup missed for " << doperation.name |
| << ". DTensor is (re-)computing its SPMD transformation."; |
| } |
| |
  // The device list includes remote devices when the coordination service is
  // enabled.
| const auto device_list = tensorflow::unwrap(context)->ListAllTfDevices(); |
| DeviceSet device_set; |
| for (const auto device : device_list) device_set.AddDevice(device); |
| |
| if (function_def) { |
| ASSIGN_OR_RETURN_C_STATUS(auto device_name_to_mesh_device, |
| PipelineSubMeshes(context), status); |
| const bool is_pipelining_function = !device_name_to_mesh_device.empty(); |
    // A multi-mesh function used for pipelining takes a different execution
    // path: we call the partitioner to lower and partition the graph into
    // multiple sub-functions to execute (one per sub-mesh).
| if (is_pipelining_function) { |
| ASSIGN_OR_RETURN_C_STATUS( |
| ExecutionFunctions functions, |
| PipeliningPartitionerRun(&device_name_to_mesh_device, flib_def, |
| &pass_runner_, *doperation.function_def, |
| eager_attributes, inputs, device_set, |
| num_outputs), |
| status); |
| *execution_functions = AddCachedFunction(cache_key, std::move(functions)); |
| return; |
| } |
| // Output layouts of a function are inferred by MLIR lowering. They are |
| // not necessary for cache key computation, so run PrepareGraphForMlir after |
    // cache key computation to reduce the overhead of running the same
| // function multiple times. |
| Status s = PrepareGraphForMlir( |
| inputs, doperation, *flib_def, eager_attributes, default_layout_, |
| graph.get(), &global_output_shapes, &output_layouts); |
| RETURN_C_STATUS_IF_NOT_OK(s, status); |
| } |
| |
| TranslatedFunction function; |
| function.output_layouts.reserve(num_outputs); |
| |
| absl::flat_hash_set<Node*> control_ret_nodes; |
| // Run DTensor MLIR passes that convert input graph to SPMD version. |
| { |
| profiler::TraceMe activity([&] { return "DTensorDevice::RunMLIRPasses"; }, |
| profiler::TraceMeLevel::kInfo); |
| RETURN_C_STATUS_IF_NOT_OK( |
| pass_runner_.RunOnGraph(device_set, doperation.is_func(), flib_def, |
| &graph, control_ret_nodes, cache_key), |
| status); |
| } |
| VLOG(4) << tensorflow::DumpGraphToFile("after_mlir_spmd_lowering", *graph, |
| flib_def); |
| if (flib_def->Contains(kLoadEmbeddingFn)) { |
| Status s = InsertFunctionForTPUEmbeddingCheckpoint( |
| status, graph.get(), inputs, kLoadEmbeddingFn); |
| RETURN_C_STATUS_IF_NOT_OK(s, status); |
| } |
| |
  // After the MLIR transformations, exactly one StatefulPartitionedCall op is
  // returned for each mesh cluster in the computation. Identify all functions
  // to execute for each mesh, along with the relevant input and output
  // information.
| ASSIGN_OR_RETURN_C_STATUS( |
| ExecutionFunctions functions, |
| IdentifyAllFunctionsToExecute(*graph.get(), global_output_shapes), |
| status); |
| |
  // In order to ensure that all resource assign operations as well as
  // side-effecting ops are executed, we add identity ops before function
| // with control rets. |
| RETURN_C_STATUS_IF_NOT_OK(MaybeInsertIdentityNodes(function_def, graph.get()), |
| status); |
| |
| VLOG(4) << tensorflow::DumpGraphToFile("after_post_processing_graph", *graph, |
| flib_def); |
| |
| RETURN_C_STATUS_IF_NOT_OK( |
| AddExecutionFunctionDefsToFunctionDefLibrary(control_ret_nodes, context, |
| *graph.get(), &functions), |
| status); |
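  // A single device-id input is prepended to every lowered function's
  // arguments; the input-index mapping in ExecuteRegularOperation accounts
  // for this.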
| functions.num_device_ids = 1; |
| if (function_def) { |
| for (TranslatedFunction& function : functions.function_list) { |
| functions.function_mesh_fingerprint = |
| FingerprintCat64(functions.function_mesh_fingerprint, |
| function.function_mesh.GlobalFingerprint()); |
| } |
| } |
| |
| *execution_functions = AddCachedFunction(cache_key, std::move(functions)); |
| } |
| |
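// Launches a single translated function on its mesh's parallel device, blocks
// until execution finishes, and logs any non-cancellation error.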
| void DTensorDevice::ExecuteFunctionAndWait( |
| TFE_Context* context, const TranslatedFunction* function_ptr, |
| const MeshWithParallelDevice* parallel_device_mesh, |
| const std::vector<parallel_device::ParallelTensor*>& parallel_inputs, |
| const int64_t step_id, const TFE_OpAttrs* attributes, TF_Status* status) { |
| const std::string mesh_str = function_ptr->function_mesh.ToString(); |
| VLOG(4) << "Launching computation for mesh : " << mesh_str; |
| parallel_device_mesh->parallel_device().StartExecute( |
| context, |
| /*inputs=*/parallel_inputs, |
| /*operation_name=*/function_ptr->translated_function_name.c_str(), |
| /*attributes=*/attributes, |
| /*expected_max_outputs=*/function_ptr->local_output_shapes.size(), |
| /*cancellation_manager=*/*cancellation_manager_, |
| /*step_id=*/step_id); |
| |
| VLOG(4) << "Joining computation result from mesh : " << mesh_str; |
| parallel_device_mesh->parallel_device().Join( |
| function_ptr->local_output_shapes, status); |
| VLOG(4) << "Joining status: " << TF_Message(status); |
| if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_CANCELLED) { |
| LOG(ERROR) << "Encountered error while executing function: " |
| << function_ptr->translated_function_name |
| << " for mesh : " << mesh_str |
| << " / error : " << TF_Message(status); |
| } |
| |
| std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status( |
| TF_NewStatus(), TF_DeleteStatus); |
| AsyncWait(context, async_wait_status.get()); |
| TF_Code error_code = TF_GetCode(async_wait_status.get()); |
| if (error_code != TF_OK && error_code != TF_CANCELLED) { |
| LOG(ERROR) << "Async status: " << TF_Message(async_wait_status.get()); |
| } |
| } |
| |
| void DTensorDevice::ExecuteRegularOperation( |
| TFE_Context* context, const std::vector<TensorWithLayout*>& inputs, |
| const DTensorOperation& doperation, const TFE_OpAttrs* attributes, |
| int* num_outputs, TFE_TensorHandle** outputs, TF_Status* status) { |
| const ExecutionFunctions* execution_functions = nullptr; |
| |
| LowerToSPMDFunction(context, inputs, doperation, attributes, *num_outputs, |
| &execution_functions, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| |
| // Update input layouts for resource arguments. |
| for (const TranslatedFunction& function : |
| execution_functions->function_list) { |
| for (const auto& entry : function.resource_input_layouts) { |
      // TODO(hthu): Add a TensorWithLayout in the inputs vector at location 0
      // for DeviceId. The adjustment below is needed because the first arg is
      // always DeviceId, and it isn't mapped to input Tensors.
| const int resource_index_to_update = entry.first - 1; |
| inputs[resource_index_to_update]->UpdateLayout(entry.second, status); |
| if (TF_GetCode(status) != TF_OK) { |
| RETURN_STATUS(status, TF_GetCode(status), |
| absl::StrCat("Attempt to update layout input arg: ", |
| resource_index_to_update, |
| ". Original message: ", TF_Message(status)) |
| .c_str()); |
| } |
| } |
| } |
| |
| int num_global_outputs = 0; |
| |
| // TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh |
| // object. Ideally we'd just use a fingerprinted int64_t as a unique |
| // identifier for a mesh. |
| std::map<std::string, const MeshWithParallelDevice*> |
| function_name_and_mesh_mapping; |
| absl::flat_hash_set<std::string> excluded_fn_names; |
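  // Functions on the EPU mesh and embedding-load functions must run to
  // completion before the main parallel launch, so they are tracked here
  // separately and excluded from the parallel loop below.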
| std::unique_ptr<const TranslatedFunction> epu_fn_ptr, load_embedding_ptr; |
| for (const TranslatedFunction& function : |
| execution_functions->function_list) { |
| StatusOr<Mesh> maybe_converted_mesh = function.function_mesh; |
| if (function.function_mesh.is_epu_mesh()) { |
| maybe_converted_mesh = function.function_mesh.ToDeviceType("CPU"); |
| } |
| |
| if (!maybe_converted_mesh.ok()) { |
| RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
                    absl::StrCat("Failed to convert mesh, got error: ",
| maybe_converted_mesh.status().error_message()) |
| .c_str()); |
| } |
| const Mesh& mesh = *maybe_converted_mesh; |
| // TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh |
| // object. Ideally we'd just use a fingerprinted int64_t as a unique |
| // identifier for a mesh. |
| const MeshWithParallelDevice* parallel_device_mesh = |
| mesh_to_device_map_.contains(mesh) ? mesh_to_device_map_[mesh].get() |
| : default_mesh_; |
| if (parallel_device_mesh == nullptr) { |
| RETURN_STATUS(status, TF_INTERNAL, |
| "required mesh is not registered with DTensor device"); |
| } |
| function_name_and_mesh_mapping[function.translated_function_name] = |
| parallel_device_mesh; |
| |
| if (function.function_mesh.is_epu_mesh()) { |
| if (epu_fn_ptr != nullptr) { |
        RETURN_STATUS(status, TF_INTERNAL,
                      "More than one function is defined on the EPU mesh.");
| } |
| epu_fn_ptr = std::make_unique<const TranslatedFunction>(function); |
| excluded_fn_names.insert(function.translated_function_name); |
| } |
| if (absl::StartsWith(function.translated_function_name, kLoadEmbeddingFn)) { |
| if (load_embedding_ptr != nullptr) { |
        RETURN_STATUS(status, TF_INTERNAL,
                      "More than one load-embedding function is defined.");
| } |
| load_embedding_ptr = std::make_unique<const TranslatedFunction>(function); |
| excluded_fn_names.insert(function.translated_function_name); |
| } |
| } |
| |
| // Compute the step_id based on the function_mesh_fingerprint and the |
| // corresponding function execution counter. |
| uint64 function_mesh_fingerprint = |
| execution_functions->function_mesh_fingerprint; |
| if (func_mesh_fingerprint_to_step_counter_.contains( |
| function_mesh_fingerprint)) { |
| func_mesh_fingerprint_to_step_counter_.at(function_mesh_fingerprint)++; |
| } else { |
| func_mesh_fingerprint_to_step_counter_.insert( |
| {function_mesh_fingerprint, 0}); |
| } |
| const uint64 step_id = FingerprintCat64( |
| function_mesh_fingerprint, |
| func_mesh_fingerprint_to_step_counter_.at(function_mesh_fingerprint)); |
| |
| // Execute excluded functions in sequence. |
| if (epu_fn_ptr != nullptr) { |
| ExecuteFunctionAndWait( |
| context, |
| /*function_ptr=*/epu_fn_ptr.get(), |
| /*parallel_device_mesh=*/ |
| function_name_and_mesh_mapping[epu_fn_ptr->translated_function_name], |
| /*parallel_inputs=*/{}, /*step_id=*/step_id, /*attributes=*/attributes, |
| /*status=*/status); |
| } |
| |
| if (load_embedding_ptr != nullptr) { |
| StatusOr<std::vector<parallel_device::ParallelTensor*>> parallel_inputs = |
| PrepareEmbeddingInputs(inputs); |
| if (!parallel_inputs.ok()) { |
| RETURN_STATUS(status, TF_INTERNAL, |
| parallel_inputs.status().error_message().c_str()); |
| } |
| ExecuteFunctionAndWait( |
| context, |
| /*function_ptr=*/load_embedding_ptr.get(), |
| /*parallel_device_mesh=*/ |
| function_name_and_mesh_mapping[load_embedding_ptr |
| ->translated_function_name], |
| /*parallel_inputs=*/*parallel_inputs, /*step_id=*/step_id, |
| /*attributes=*/attributes, /*status=*/status); |
| } |
| |
| // Execute all functions in parallel. |
| for (const TranslatedFunction& function : |
| execution_functions->function_list) { |
| const Mesh& mesh = function.function_mesh; |
| const std::string& translated_function_name = |
| function.translated_function_name; |
| |
| num_global_outputs += function.local_output_shapes.size(); |
| |
| if (is_remote_mesh(mesh) || |
| (excluded_fn_names.find(translated_function_name) != |
| excluded_fn_names.end())) { |
      // Skip execution when a translated function has a remote mesh or is
      // explicitly excluded.
| continue; |
| } |
| |
| const MeshWithParallelDevice* parallel_device_mesh = |
| function_name_and_mesh_mapping[translated_function_name]; |
| |
| std::vector<parallel_device::ParallelTensor*> parallel_inputs; |
| parallel_inputs.reserve(inputs.size() + 1); |
| auto input_mapping = function.input_index_map; |
| std::sort(input_mapping.begin(), input_mapping.end()); |
| |
| absl::flat_hash_set<int> skip_input; |
| int offset_from_sparsetensors = 0; |
| std::vector<parallel_device::ParallelTensor*> sparse_parallel_inputs; |
| |
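    // Map each sorted global input index to a local input position. Indices
    // below num_device_ids refer to the implicit device-id argument. A
    // SparseTensor input occupies three consecutive global slots (its
    // component tensors), so the two trailing slots are skipped and later
    // positions shift back by two.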
| for (const int global_index : input_mapping) { |
| if (skip_input.find(global_index) != skip_input.end()) continue; |
| auto input_index = global_index - execution_functions->num_device_ids - |
| offset_from_sparsetensors; |
| |
| if (global_index < execution_functions->num_device_ids) { |
| parallel_inputs.push_back( |
| parallel_device_mesh->DeviceIDs(context, status)); |
| if (TF_GetCode(status) != TF_OK) return; |
| } else if (inputs[input_index]->tensor_type() == TensorType::kSparse) { |
        // Save the SparseTensor component inputs so we can add them in later.
| SparseTensorWithLayout* sparse_input = |
| dynamic_cast<SparseTensorWithLayout*>(inputs[input_index]); |
| sparse_parallel_inputs.insert( |
| sparse_parallel_inputs.end(), |
| {sparse_input->indices(), sparse_input->dense_shapes(), |
| sparse_input->values()}); |
| skip_input.insert({global_index + 1, global_index + 2}); |
| offset_from_sparsetensors += 2; |
| } else { |
| parallel_inputs.push_back(inputs[input_index]->tensor()); |
| } |
| } |
    // Append the SparseTensor components at the end.
| parallel_inputs.insert(parallel_inputs.end(), |
| sparse_parallel_inputs.begin(), |
| sparse_parallel_inputs.end()); |
| |
| VLOG(4) << "Launching computation for mesh : " << mesh.ToString(); |
| parallel_device_mesh->parallel_device().StartExecute( |
| context, parallel_inputs, translated_function_name.c_str(), attributes, |
| /*expected_max_outputs=*/function.local_output_shapes.size(), |
| *cancellation_manager_, /*step_id=*/step_id); |
| } |
| |
| *num_outputs = num_global_outputs; |
| std::vector<std::unique_ptr<TensorWithLayout>> typed_outputs; |
| typed_outputs.resize(num_global_outputs); |
| |
  // Join all mesh computations together.
| // TODO(b/177932563): Expose cancel logic to handle failures. |
| std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> join_status( |
| TF_NewStatus(), TF_DeleteStatus); |
| for (const TranslatedFunction& function : |
| execution_functions->function_list) { |
| // Skip execution for a function when it's excluded. |
| if (excluded_fn_names.contains(function.translated_function_name)) { |
| continue; |
| } |
| const Mesh& mesh = function.function_mesh; |
| // TODO(b/168730933): Lookup is slow as it takes all the devices in the Mesh |
| // object. Ideally we'd just use a fingerprinted int64_t as a unique |
| // identifier for a mesh. |
| const MeshWithParallelDevice* parallel_device_mesh = |
| function_name_and_mesh_mapping[function.translated_function_name]; |
| |
| std::vector<std::unique_ptr<TensorWithLayout>> output_with_layout; |
| output_with_layout.reserve(function.output_index_map.size()); |
| if (is_remote_mesh(mesh)) { |
| // Create dummy outputs on a remote mesh. |
| for (int i = 0; i < function.output_index_map.size(); ++i) { |
| const auto dim_sizes = function.local_output_shapes.at(i).dim_sizes(); |
| std::vector<int64_t> local_shape = |
| std::vector<int64_t>(dim_sizes.begin(), dim_sizes.end()); |
| TF_DataType dtype = |
| static_cast<TF_DataType>(function.output_dtypes.at(i)); |
| auto remote_output = |
| TensorWithLayout::Dummy(local_shape, dtype, *parallel_device_mesh, |
| function.output_layouts[i]); |
| output_with_layout.push_back(std::move(remote_output)); |
| } |
| } else { |
| VLOG(4) << "Joining computation result from mesh : " << mesh.ToString(); |
| auto result = parallel_device_mesh->parallel_device().Join( |
| function.local_output_shapes, status); |
| if (TF_GetCode(join_status.get()) != TF_OK && |
| // Preserve the first failure we see, but only if it is a real failure |
| // and not a cancellation (which was probably triggered by the error |
| // we want to propagate). |
| (TF_GetCode(status) == TF_OK || |
| TF_GetCode(join_status.get()) != TF_CANCELLED)) { |
| continue; |
| } |
| if (TF_GetCode(status) != TF_OK) { |
| if (TF_GetCode(status) != TF_CANCELLED) { |
| LOG(ERROR) << "Encountered error while executing function: " |
| << function.translated_function_name |
| << " for mesh : " << mesh.ToString() |
| << " / error : " << TF_Message(status); |
| } |
| TF_SetStatus(join_status.get(), TF_GetCode(status), TF_Message(status)); |
| continue; |
| } |
| |
| for (int i = 0; i < result->size(); ++i) { |
| ASSIGN_OR_RETURN_C_STATUS( |
| auto local_output, |
| TensorWithLayout::Wrap(std::move((*result)[i]), |
| *parallel_device_mesh, |
| function.output_layouts[i]), |
| status); |
| output_with_layout.push_back(std::move(local_output)); |
| } |
| } |
| |
| for (int i = 0; i < function.output_index_map.size(); ++i) { |
| // TODO(b/162744844): Generalize this pattern so that the extraction is |
| // not special cased. |
| if (function.shape_output_metadata.find(i) != |
| function.shape_output_metadata.end()) { |
| output_with_layout[i]->set_input_layout_for_shape_op_result( |
| function.shape_output_metadata.at(i)); |
| } |
| |
| RecordInShapeLayoutCache(*output_with_layout[i]); |
| typed_outputs[function.output_index_map[i]] = |
| std::move(output_with_layout[i]); |
| } |
| } |
| if (TF_GetCode(join_status.get()) != TF_OK) { |
| std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> async_wait_status( |
| TF_NewStatus(), TF_DeleteStatus); |
| AsyncWait(context, async_wait_status.get()); |
| TF_Code error_code = TF_GetCode(async_wait_status.get()); |
| if (error_code != TF_OK && error_code != TF_CANCELLED) { |
| // Ignore the AsyncWait() status return since we already have a bad status |
| // to propagate. We've just canceled a bunch of operations, so we expect |
| // cancellation status returns. We'll log anything else just to be safe. |
| LOG(ERROR) << "Error executing " << doperation.name << " " |
| << TF_Message(async_wait_status.get()); |
| } |
| |
| TF_SetStatus(status, TF_GetCode(join_status.get()), |
| TF_Message(join_status.get())); |
| return; |
| } |
| if (VLOG_IS_ON(2)) { |
| LOG(INFO) << "Executed " << doperation.name << ", got " |
| << typed_outputs.size() << " outputs:"; |
| for (const std::unique_ptr<TensorWithLayout>& output : typed_outputs) { |
| LOG(INFO) << " " << output->DebugString(); |
| } |
| } |
| if (doperation.name == std::string("VarHandleOp")) { |
| // For new variables, set the dereferenced shape/dtype so we can pass it in |
| // as _handle_dtype and _handle_shape in the future. |
| // |
| // Note that VarHandleOps generated by `tf.Variable` objects are always run |
| // eagerly, which is almost all of the op's usage in TF2. Theoretically a |
| // user could run it in a tf.function via tf.raw_ops.VarHandleOp, return it |
| // from that function, and add it as an input to another, and it would |
| // currently be missing handle information. |
| if (typed_outputs.size() != 1) { |
| RETURN_STATUS(status, TF_INTERNAL, |
| "Expected one output from VarHandleOp"); |
| } |
| NameAttrList name_and_attrs; |
| ASSIGN_OR_RETURN_C_STATUS(name_and_attrs, FetchAttributes(attributes), |
| status); |
| |
| typed_outputs[0]->UpdateShapeAndDType( |
| name_and_attrs.attr().at("shape").shape(), |
| name_and_attrs.attr().at("dtype").type(), status); |
| if (TF_GetCode(status) != TF_OK) return; |
| } |
| |
| for (int i = 0; i < *num_outputs; ++i) { |
| outputs[i] = |
| MakeLayoutTensorHandle(context, std::move(typed_outputs[i]), status); |
| if (TF_GetCode(status) != TF_OK) return; |
| } |
| } |
| |
| void DTensorDevice::Execute(const TFE_Op* original_op, int* num_outputs, |
| TFE_TensorHandle** outputs, TF_Status* status) { |
| TFE_Context* context = TFE_OpGetContext(original_op, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| const char* operation_name = TFE_OpGetName(original_op, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op); |
| int num_inputs = TFE_OpGetFlatInputCount(original_op, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| std::vector<TFE_TensorHandle*> inputs_vector; |
| inputs_vector.reserve(num_inputs); |
| for (int input_index = 0; input_index < num_inputs; ++input_index) { |
| TFE_TensorHandle* input = |
| TFE_OpGetFlatInput(original_op, input_index, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| inputs_vector.push_back(input); |
| } |
| TFE_TensorHandle** inputs = inputs_vector.data(); |
| |
| DTensorOperation dtensor_operation{}; |
| dtensor_operation.name = operation_name; |
| { |
| dtensor_operation.function_def = |
| tensorflow::unwrap(context)->FindFunctionDef(operation_name); |
| } |
| |
| // First handle DTensor-specific virtual operations. |
| bool is_op_handled = false; |
| MaybeHandleDTensorCustomOps(operation_name, num_inputs, attributes, context, |
| inputs, num_outputs, outputs, &is_op_handled, |
| status); |
| if (is_op_handled) return; |
| |
| // This isn't a special op, so we'll defer to TFE_Execute to actually execute |
| // it, but we'll also run DTensor MLIR passes and propagate the layout. |
| std::vector<TensorWithLayout*> typed_inputs; |
| std::vector<std::unique_ptr<TensorWithLayout>> inputs_with_no_layout; |
| |
  // Record the meshes of all inputs that are already on the DTensor device.
  // If we can identify a single unique mesh, that mesh is used as the mesh to
  // broadcast non-DTensor inputs to.
| absl::flat_hash_set<Mesh> input_meshes; |
| std::vector<int> not_on_device_input_indices; |
| |
| typed_inputs.resize(num_inputs); |
| for (int j = 0; j < num_inputs; ++j) { |
| TFE_TensorHandle* input = inputs[j]; |
| const char* input_device = TFE_TensorHandleDeviceName(input, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| if (name_ != input_device) { |
| not_on_device_input_indices.push_back(j); |
| continue; |
| } |
| // Handle input which is on DTensor device already. |
| TensorWithLayout* t = reinterpret_cast<TensorWithLayout*>( |
| TFE_TensorHandleDevicePointer(input, status)); |
| if (TF_GetCode(status) != TF_OK) return; |
| |
    // VarHandleOp runs on an empty mesh, which isn't registered with the
    // device.
| if (!t->layout().mesh().IsEmpty()) { |
| input_meshes.insert(t->layout().mesh()); |
| } |
    // Try to extract the input into a constant node for small, fully
    // replicated tensors. This is especially useful when executing in eager
    // mode and allows certain ops to work; e.g. for reduce, the
    // reduction_indices from BroadcastGradientArgs would be lifted as a
    // constant that allows proper computation.
| if (!dtensor_operation.is_func() && !t->const_value().has_value() && |
| t->layout().IsFullyReplicated()) { |
| absl::optional<NodeDef> maybe_const = |
| ExtractSmallTensorValue(context, input, t->layout(), status); |
| if (TF_GetCode(status) != TF_OK) return; |
| if (maybe_const.has_value()) { |
| // If we extracted a constant value from the tensor, check if this |
| // value was the output from `tf.shape`. In this case, we need to |
| // forward the kShapeOpInputLayout attribute to the new node def. This |
| // is needed for layout propagation when running in op-by-op mode. |
| // |
| // TODO(b/162747667): Improve the presentation for Shape input Op |
| // layout. |
| if (t->shape_metadata_layout().has_value()) { |
| AddNodeAttr(kShapeOpInputLayout, |
| {t->shape_metadata_layout()->ToString()}, |
| &(*maybe_const)); |
| } |
| t->set_const_value(maybe_const.value()); |
| } |
| } |
| typed_inputs[j] = t; |
| } |
| |
| // If a unique mesh is identified across all inputs, we use that mesh as the |
  // mesh to broadcast to. Otherwise we fall back to the default mesh.
| const MeshWithParallelDevice* broadcast_mesh = |
| input_meshes.size() == 1 |
| ? mesh_to_device_map_[*input_meshes.begin()].get() |
| : default_mesh_; |
| if (!broadcast_mesh) { |
    RETURN_STATUS(status, TF_INVALID_ARGUMENT,
                  "No mesh has been registered to DTensor. Use copy_to_mesh to "
                  "explicitly specify a mesh instead.");
| } |
| for (int not_on_device_input_index : not_on_device_input_indices) { |
| TFE_TensorHandle* input = inputs[not_on_device_input_index]; |
    // DTensor creation should be explicit, with some exceptions for usability
    // (scalars/shapes/slice specs/etc.). Here we do some trivial validation to
    // enforce this rule.
| int num_dims = TFE_TensorHandleNumDims(input, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| int64_t num_elements = TFE_TensorHandleNumElements(input, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| TF_DataType dtype = TFE_TensorHandleDataType(input); |
| const bool small_int_tensor = num_elements < kSmallTensorThreshold && |
| (dtype == TF_INT32 || dtype == TF_INT64); |
| if (!(num_dims == 0 || dtype == TF_STRING || small_int_tensor)) { |
| std::vector<int64_t> tensor_shape(TensorShapeAsVector(input, status)); |
| if (TF_GetCode(status) != TF_OK) return; |
| RETURN_STATUS( |
| status, TF_UNIMPLEMENTED, |
| absl::StrCat( |
| "The op/function ", operation_name, |
| " got a regular tensor for input ", not_on_device_input_index, |
| " (shape ", ShapeToDebugString(tensor_shape), |
| ") but was expecting a DTensor. Currently only scalars and " |
| "small integer/string tensors are auto-broadcast to " |
| "DTensors. For other tensors, please use copy_to_mesh to " |
| "make a DTensor explicitly; note that this may be slow if it " |
| "happens frequently.") |
| .c_str()); |
| } |
| // Construct temporary TensorWithLayout objects for inputs that didn't |
| // have any to start. These are owned by the `inputs_with_no_layout` |
| // vector, whereas the input `TFE_TensorHandle`s maintain ownership for |
    // inputs that already had layouts (and therefore had TensorWithLayout
| // objects). |
| std::unique_ptr<TensorWithLayout> wrapper = TensorWithLayout::Broadcast( |
| context, input, *broadcast_mesh, name_, status); |
| if (TF_GetCode(status) != TF_OK) return; |
| |
| if (!ShouldFoldInputArgument(dtensor_operation.is_func(), |
| dtensor_operation.name, |
| /*input_index=*/not_on_device_input_index)) { |
| wrapper->reset_const_value(); |
| } |
| typed_inputs[not_on_device_input_index] = wrapper.get(); |
| inputs_with_no_layout.emplace_back(wrapper.release()); |
| } |
| |
| ExecuteRegularOperation(context, typed_inputs, dtensor_operation, attributes, |
| num_outputs, outputs, status); |
| } |
| |
| void ExecuteOnDTensorDevice(const TFE_Op* original_op, int* num_outputs, |
| TFE_TensorHandle** outputs, TF_Status* status, |
| void* device_info) { |
| DTensorDevice* dev = reinterpret_cast<DTensorDevice*>(device_info); |
| dev->Execute(original_op, num_outputs, outputs, status); |
| } |
| |
| void DeleteDTensorDevice(void* device_info) { |
| delete static_cast<DTensorDevice*>(device_info); |
| } |
| |
| TFE_TensorHandle* CopyToDTensorDevice(TFE_Context* context, |
| TFE_TensorHandle* tensor, |
| TF_Status* status, void* device_info) { |
| TF_SetStatus(status, TF_UNIMPLEMENTED, |
| "Trying to copy a tensor on to a DTensor mesh without a layout " |
| "(use the CopyToMesh op for now)."); |
| return nullptr; |
| } |
| |
| TFE_TensorHandle* CopyFromDTensorDevice(TFE_Context* context, |
| TFE_TensorHandle* tensor, |
| const char* target_device_name, |
| TF_Status* status, void* device_info) { |
| TensorWithLayout* typed_input = reinterpret_cast<TensorWithLayout*>( |
| TFE_TensorHandleDevicePointer(tensor, status)); |
  if (!typed_input->layout().IsFullyReplicated()) {
| TF_SetStatus(status, TF_UNIMPLEMENTED, |
| "Trying to copy a non-replicated DTensor is not supported."); |
| return nullptr; |
| } |
| if (typed_input->tensor()->dtype() == TF_RESOURCE) { |
| TF_SetStatus(status, TF_UNIMPLEMENTED, |
| "Trying to copy a DTensor resource handle is not supported."); |
| return nullptr; |
| } |
| DTensorDevice* dev = reinterpret_cast<DTensorDevice*>(device_info); |
| // Since operations are executed asynchronously, the operation which should |
| // produce the tensor we're trying to copy off the DTensor device may be |
| // canceled due to a failure on another device. If so, we want to report the |
| // failure that caused the cancellation, not the cancellation itself. This |
| // requires blocking waiting for other devices to flush their execution |
| // queues. |
  // Note that we only need to sync the threads on the parallel_device()
  // directly; a context-level sync might cause unintentional deadlocks when
  // grabbing locks on other threads.
| dev->AsyncWait(context, status); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| return TFE_TensorHandleCopySharingTensor(typed_input->get_tensor(0), status); |
| } |
| |
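// A minimal registration sketch, assuming an existing TFE_Context* context
// and TF_Status* status (the device name below is only illustrative):
//
//   TFE_CustomDevice device;
//   void* device_info;
//   const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
//   AllocateDTensorDevice(name, &device, &device_info);
//   TFE_RegisterCustomDevice(context, device, name, device_info, status);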
| void AllocateDTensorDevice(absl::string_view device_name, |
| TFE_CustomDevice* device, void** device_info) { |
| device->copy_tensor_to_device = &CopyToDTensorDevice; |
| device->copy_tensor_from_device = &CopyFromDTensorDevice; |
| device->delete_device = &DeleteDTensorDevice; |
| device->execute = &ExecuteOnDTensorDevice; |
| *device_info = new DTensorDevice(device_name); |
| } |
| |
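// `serialized_mesh` is parsed by Mesh::FromString. An illustrative example,
// assuming the format name|dims|global ids|local ids|local devices:
//
//   "mesh|x=2|0,1|0,1|/job:localhost/task:0/device:CPU:0,"
//   "/job:localhost/task:0/device:CPU:1"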
| void AddMesh(const std::string& serialized_mesh, void* device_info, |
| bool is_async, bool is_host_mesh, TF_Status* status) { |
| auto mesh_config_or_status = Mesh::FromString(serialized_mesh); |
| if (!mesh_config_or_status.ok()) { |
| TF_SetStatus(status, TF_INTERNAL, |
| absl::StrCat("Failed to parse mesh config. ", |
| mesh_config_or_status.status().error_message()) |
| .c_str()); |
| return; |
| } |
| auto mesh_config = mesh_config_or_status.ValueOrDie(); |
| std::vector<std::string> underlying_devices; |
| underlying_devices.insert(underlying_devices.end(), |
| mesh_config.local_devices().begin(), |
| mesh_config.local_devices().end()); |
  // DTensor uses a multi-client setup, which doesn't use remote eager, so we
  // can enable eager async execution in the ParallelDevice.
| std::unique_ptr<tensorflow::parallel_device::ParallelDevice> parallel( |
| new tensorflow::parallel_device::ParallelDevice(underlying_devices, |
| is_async)); |
| |
| std::string composite_device_name; |
| if (absl::StartsWith(mesh_config.name(), kPipelineMeshNamePrefix)) { |
| composite_device_name = std::string( |
| absl::StripPrefix(mesh_config.name(), kPipelineMeshNamePrefix)); |
| } |
| |
| auto mesh = absl::make_unique<MeshWithParallelDevice>( |
| std::move(mesh_config), std::move(parallel), composite_device_name); |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| device->AddMesh(std::move(mesh), is_host_mesh); |
| } |
| |
| void ExperimentalSetDefaultLayout(const std::string& serialized_layout, |
| void* device_info, TF_Status* status) { |
| StatusOr<Layout> layout = Layout::FromString(serialized_layout); |
| if (!layout.ok()) { |
| RETURN_STATUS(status, TF_INTERNAL, layout.status().error_message().c_str()); |
| } |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| device->SetDefaultLayout(layout.ValueOrDie()); |
| } |
| |
| void ExperimentalClearDefaultLayout(void* device_info, TF_Status* status) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| device->ClearDefaultLayout(); |
| } |
| |
| void ExperimentalSetDefaultMesh(const std::string& serialized_mesh, |
| void* device_info, TF_Status* status) { |
| StatusOr<Mesh> mesh = Mesh::FromString(serialized_mesh); |
| if (!mesh.ok()) { |
| RETURN_STATUS(status, TF_INTERNAL, mesh.status().error_message().c_str()); |
| } |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| device->SetDefaultMesh(mesh.ValueOrDie()); |
| } |
| |
| void ExperimentalClearDefaultMesh(void* device_info, TF_Status* status) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| device->ClearDefaultMesh(); |
| } |
| |
| void SetSameShapePolicy(void* device_info, bool enabled) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| device->SetSameShapePolicy(enabled); |
| } |
| |
| void SetTPUCoreIDs(const std::string& mesh_name, |
| const std::vector<int>& tpu_core_ids, void* device_info, |
| TF_Status* status) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| RETURN_C_STATUS_IF_NOT_OK(device->SetTPUCoreIDs(mesh_name, tpu_core_ids), |
| status); |
| } |
| |
| void ClearTPUCoreIDs(void* device_info) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| device->ClearTPUCoreIDs(); |
| } |
| |
| std::vector<std::vector<int>> TPUCoreIDsToLocations( |
| TFE_Context* context, const std::vector<int>& tpu_core_ids, |
| void* device_info) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| return device->TPUCoreIDsToLocations(context, tpu_core_ids); |
| } |
| |
| std::vector<int> TPUCoreLocationsToIDs( |
| TFE_Context* context, |
| const std::vector<std::vector<int>>& tpu_core_locations, |
| void* device_info) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| return device->TPUCoreLocationsToIDs(context, tpu_core_locations); |
| } |
| |
| TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs, |
| TFE_TensorHandle** inputs, |
| const std::string& string_layout, void* device_info, |
| TF_Status* status) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| return device->Pack(context, num_inputs, inputs, string_layout, status); |
| } |
| |
| std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context, |
| TFE_TensorHandle* input, |
| void* device_info, TF_Status* status) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| return device->Unpack(context, input, status); |
| } |
| |
| std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input, |
| void* device_info, TF_Status* status) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| return device->FetchLayout(context, input, status); |
| } |
| |
| TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs, |
| TFE_TensorHandle** indices, |
| TFE_TensorHandle** values, |
| TFE_TensorHandle** shapes, |
| const std::string& string_layout, |
| void* device_info, TF_Status* status) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| return device->SparsePack(context, num_inputs, indices, values, shapes, |
| string_layout, status); |
| } |
| |
| bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input, |
| void* device_info, TF_Status* status) { |
| DTensorDevice* device = reinterpret_cast<DTensorDevice*>(device_info); |
| return device->IsSparseDTensor(context, input, status); |
| } |
| } // namespace dtensor |
| } // namespace tensorflow |