| /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/dtensor/cc/dtensor_device_util.h" |
| |
| #include <cstddef> |
| #include <string> |
| #include <utility> |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/strings/str_cat.h" |
| #include "tensorflow/c/eager/c_api_internal.h" |
| #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
| #include "tensorflow/c/tf_status.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/core/common_runtime/graph_constructor.h" |
| #include "tensorflow/core/common_runtime/shape_refiner.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/lib/strings/proto_serialization.h" |
| #include "tensorflow/core/platform/errors.h" |
| #include "tensorflow/core/public/version.h" |
| #include "tensorflow/dtensor/cc/constants.h" |
| #include "tensorflow/dtensor/cc/dstatus.h" |
| #include "tensorflow/dtensor/cc/small_constant_optimization.h" |
| |
| namespace tensorflow { |
| namespace dtensor { |
| namespace { |
| // Represents an input node during graph construction. |
| // When executing a Function, `output` is used to align graph inputs |
| // with the inputs to the function call. |
| struct FunctionArgument { |
| Node* node; |
| NodeDefBuilder::NodeOut output; |
| }; |
| |
| bool LayoutsAreCompatible(absl::optional<Layout> first_layout, |
| absl::optional<Layout> second_layout) { |
| if (!first_layout.has_value() && !second_layout.has_value()) { |
| return true; |
| } |
| if (!first_layout.has_value() || !second_layout.has_value()) { |
| return false; |
| } |
| return first_layout.value() == second_layout.value(); |
| } |
| |
| // Parse a pair of attribute of (indices, layouts) into a map. |
| Status ParseAttrMap(const Node& node, absl::string_view indices_attr, |
| absl::string_view layout_attr, |
| std::map<int, Layout>* indices_layout_map) { |
| std::vector<std::string> layouts; |
| if (!TryGetNodeAttr(node.attrs(), layout_attr, &layouts)) { |
| return Status::OK(); |
| } |
| const TensorProto* indices; |
| if (!TryGetNodeAttr(node.attrs(), indices_attr, &indices)) { |
| return errors::Internal( |
| "Arg indices must be set when setting inferred resource layouts."); |
| } |
| if (indices->int_val_size() != layouts.size()) { |
| return errors::Internal( |
| "Arg indices for inferred resource argument must match the " |
| "size of inferred resource layout."); |
| } |
| for (int i = 0; i < indices->int_val_size(); ++i) { |
| const auto arg_index = indices->int_val(i); |
| const auto& arg_layout = layouts[i]; |
| indices_layout_map->emplace( |
| arg_index, |
| tensorflow::dtensor::Layout::FromString(arg_layout).ValueOrDie()); |
| } |
| return Status::OK(); |
| } |
| |
| Status ParseResourceArgumentLayouts( |
| const Node& node, std::map<int, Layout>* inferred_resource_input_layouts) { |
| return ParseAttrMap(node, kNewResourceLayoutIndices, kNewResourceArgLayouts, |
| inferred_resource_input_layouts); |
| } |
| |
| Status ParseShapeInputLayouts(const Node& node, |
| std::map<int, Layout>* shape_output_metadata) { |
| return ParseAttrMap(node, kShapeOpInputLayoutIndices, kShapeOpInputLayout, |
| shape_output_metadata); |
| } |
| |
| // Gets the layout attached to a specific node at a given index, ignoring any |
| // Identity ops. |
| StatusOr<Layout> GetLayoutThroughIdentityOps(Node* op, int output_index) { |
| while (op->op_def().name() == "Identity" || |
| op->op_def().name() == "IdentityN") { |
| const Edge* edge; |
| TF_RETURN_IF_ERROR(op->input_edge(output_index, &edge)); |
| op = edge->src(); |
| output_index = edge->src_output(); |
| } |
| const auto serialized_layouts = op->attrs().Find(kLayoutAttr); |
| |
| if (!serialized_layouts) { |
| return errors::InvalidArgument( |
| op->op_def().name(), " doesn't contain attribute : ", kLayoutAttr); |
| } |
| |
| // We assume that there is one layout for each output. |
| if (serialized_layouts->list().s_size() != op->num_outputs()) { |
| return errors::InvalidArgument( |
| "Number of outputs to ", op->op_def().name(), |
| " does not match number of layouts attached"); |
| } |
| |
| return Layout::FromString(serialized_layouts->list().s(output_index)); |
| } |
| |
| } // namespace |
| |
| tensorflow::Fprint128 TensorWithLayout::CacheKey() const { |
| tensorflow::Fprint128 f = tensorflow::Fingerprint128(layout_.ToString()); |
| // Use exact shape to compute the key. |
| for (const int64_t dim : local_shape()) { |
| f = FingerprintCat128(f, dim); |
| } |
| if (const_value_.has_value()) { |
| std::string serialized; |
| SerializeToStringDeterministic(const_value_.value(), &serialized); |
| f = FingerprintCat128(f, tensorflow::Fingerprint128(serialized)); |
| } |
| return f; |
| } |
| |
// Broadcasts a regular eager tensor onto `mesh` with a fully replicated
// layout: one copy of the value is placed on each local device of the mesh.
// On failure, sets `status` and returns nullptr. Rejects tensors already on
// the DTensor device and TF_RESOURCE handles (non-DTensor variables).
std::unique_ptr<TensorWithLayout> TensorWithLayout::Broadcast(
    TFE_Context* context, TFE_TensorHandle* tensor,
    const MeshWithParallelDevice& mesh, const std::string& dtensor_device_name,
    TF_Status* status) {
  const char* input_device = TFE_TensorHandleDeviceName(tensor, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  // A handle already placed on the DTensor device is not an eager tensor;
  // re-broadcasting it would be a layering error.
  if (dtensor_device_name == input_device) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "Input to Broadcast must be eager tensor.");
    return nullptr;
  }

  if (TFE_TensorHandleDataType(tensor) == TF_RESOURCE) {
    std::string error_message =
        "Using a non-DTensor variable with DTensor is not supported. If you "
        "are using a scope-based API, create variables inside the DTensor "
        "scope.\n";

    // Resolve the Tensor as resource handle and try to get the stack_trace and
    // Summaries out of it, to make the error message actionable.
    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tf_tensor(
        TFE_TensorHandleResolve(tensor, status), TF_DeleteTensor);
    Tensor t;
    Status convert_status = TF_TensorToTensor(tf_tensor.get(), &t);
    if (convert_status.ok() && t.dtype() == DataType::DT_RESOURCE) {
      ResourceHandle r = t.flat<ResourceHandle>()(0);
      absl::StrAppend(
          &error_message, "Offending variable summary: ", r.SummarizeValue(),
          "\nStack trace: ", DefinitionLocationMsg(r.definition_stack_trace()));
    }
    TF_SetStatus(status, TF_INVALID_ARGUMENT, error_message.c_str());
    return nullptr;
  }

  // Remote meshes have no local devices to copy to; return a dummy tensor
  // carrying only shape/dtype/layout metadata (plus the constant value when
  // the tensor is small enough to extract).
  if (mesh.mesh_config().is_remote()) {
    TF_DataType dtype = TFE_TensorHandleDataType(tensor);
    std::vector<int64_t> shape(TensorShapeAsVector(tensor, status));
    if (TF_GetCode(status) != TF_OK) return nullptr;
    auto layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), shape.size());

    auto ret = TensorWithLayout::Dummy(shape, dtype, mesh, layout);
    absl::optional<NodeDef> const_value =
        ExtractSmallTensorValue(context, tensor, layout, status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
    if (const_value) {
      ret->set_const_value(const_value.value());
    }
    return ret;
  }

  // Broadcast tensor value to local devices.
  const Mesh& target_mesh = mesh.mesh_config();
  absl::Span<const std::string> local_devices = target_mesh.local_devices();
  const int num_local_devices = local_devices.size();

  std::vector<parallel_device::TensorHandlePtr> components;
  components.reserve(num_local_devices);
  for (int i = 0; i < num_local_devices; ++i) {
    // Create tensor copies on each local device specified by `target_mesh`.
    components.emplace_back(TFE_TensorHandleCopyToDevice(
        tensor, context, local_devices[i].c_str(), status));
    if (TF_GetCode(status) != TF_OK) {
      TF_SetStatus(
          status, TF_INTERNAL,
          absl::StrCat(
              "Unable to copy tensor value for broadcast. Original message: ",
              TF_Message(status))
              .c_str());
      return nullptr;
    }
  }

  std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor =
      parallel_device::ParallelTensor::FromTensorHandles(
          mesh.parallel_device(), std::move(components), status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  const std::vector<int64_t>* shape;
  Status s = parallel_tensor->Shape(&shape);
  if (!s.ok()) {
    // Translate the Status into the C-API TF_Status expected by callers.
    TF_SetStatus(status, static_cast<TF_Code>(s.code()),
                 s.error_message().c_str());
    return nullptr;
  }
  size_t num_dims = shape->size();

  const Layout layout = Layout::ReplicatedOnMesh(mesh.mesh_config(), num_dims);
  // Keep the constant value (if small) so later passes can fold it.
  absl::optional<NodeDef> const_value =
      ExtractSmallTensorValue(context, tensor, layout, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  std::unique_ptr<TensorWithLayout> result(new TensorWithLayout(
      std::move(parallel_tensor), mesh, std::move(layout), *shape,
      /*dtype=*/absl::nullopt, std::move(const_value)));
  return result;
}
| |
| StatusOr<std::unique_ptr<TensorWithLayout>> TensorWithLayout::Wrap( |
| std::unique_ptr<parallel_device::ParallelTensor> tensor, |
| const MeshWithParallelDevice& mesh, const Layout& layout) { |
| const std::vector<int64_t>* shape; |
| TF_RETURN_IF_ERROR(tensor->Shape(&shape)); |
| |
| if (tensor->dtype() != TF_RESOURCE) { |
| return std::unique_ptr<TensorWithLayout>( |
| new TensorWithLayout(std::move(tensor), mesh, layout, *shape)); |
| } else { |
| return std::unique_ptr<TensorWithLayout>( |
| new ResourceHandleWithLayout(std::move(tensor), mesh, layout, *shape)); |
| } |
| } |
| |
| std::unique_ptr<TensorWithLayout> TensorWithLayout::Dummy( |
| const std::vector<int64_t>& local_shape, const TF_DataType dtype, |
| const MeshWithParallelDevice& mesh, const Layout& layout) { |
| if (dtype != TF_RESOURCE) { |
| return std::unique_ptr<TensorWithLayout>(new TensorWithLayout( |
| /*tensor=*/nullptr, mesh, layout, local_shape, dtype)); |
| } else { |
| return std::unique_ptr<TensorWithLayout>(new ResourceHandleWithLayout( |
| /*tensor=*/nullptr, mesh, layout, local_shape)); |
| } |
| } |
| |
| std::string TensorWithLayout::SummarizeValue() const { |
| std::string value_summary; |
| Status status; |
| if (layout().IsFullyReplicated()) { |
| status = |
| tensorflow::unwrap(tensor()->tensor(0))->SummarizeValue(value_summary); |
| } else { |
| // Note that this just prints the local values for sharded tensors. We could |
| // instead run a collective here to relayout to replicated. |
| status = tensor()->SummarizeValue(value_summary); |
| } |
| if (!status.ok()) { |
| value_summary = "<error computing value>"; |
| } |
| return absl::StrCat(value_summary, ", layout=\"", layout().ToString(), "\""); |
| } |
| |
| std::string TensorWithLayout::DebugString() const { |
| auto dtype = static_cast<DataType>(tensor()->dtype()); |
| |
| const auto& shape_vector = global_shape(); |
| return absl::StrCat("DTensor(", SummarizeValue(), |
| ", shape=", ShapeToDebugString(shape_vector), |
| ", type=", DataTypeString(dtype), ")"); |
| } |
| |
| void ResourceHandleWithLayout::EncodeAttributes( |
| tensorflow::NodeDefBuilder& builder) const { |
| // If set, attach shape and dtype to the given node def. |
| if (dereferenced_shape().has_value()) { |
| builder.Attr("_handle_shapes", {*dereferenced_shape()}); |
| } |
| if (dereferenced_dtype().has_value()) { |
| builder.Attr("_handle_dtypes", {*dereferenced_dtype()}); |
| } |
| } |
| |
| tensorflow::Fprint128 ResourceHandleWithLayout::CacheKey() const { |
| tensorflow::Fprint128 f = tensorflow::Fingerprint128(layout().ToString()); |
| if (dereferenced_shape().has_value()) { |
| std::string serialized; |
| SerializeToStringDeterministic(dereferenced_shape().value(), &serialized); |
| f = FingerprintCat128(f, tensorflow::Fingerprint128(serialized)); |
| } |
| if (dereferenced_dtype().has_value()) { |
| f = FingerprintCat128(f, dereferenced_dtype().value()); |
| } |
| return f; |
| } |
| |
// Records `new_layout` as the layout of the value behind this resource
// handle. Sets TF_INVALID_ARGUMENT on `status` when a different layout has
// already been recorded; updating to a compatible layout is a no-op success.
void ResourceHandleWithLayout::UpdateLayout(const Layout& new_layout,
                                            TF_Status* status) {
  // Only set the value for deferenced layout if the incoming layout is not
  // empty. This is still hacky as we use empty layout as placeholder for
  // eagerly placed VarHandleOp.
  if (!dereferenced_layout_.has_value() && new_layout.IsEmpty()) return;
  if (dereferenced_layout_.has_value() &&
      !LayoutsAreCompatible(dereferenced_layout_, new_layout)) {
    // TODO(xiejw, allenl): Consider allowing variables to switch layouts.
    RETURN_STATUS(status, TF_INVALID_ARGUMENT,
                  "Attempted to overwrite an existing Layout.");
  }
  // Re-emplacing an equal layout is harmless; only incompatible overwrites
  // were rejected above.
  dereferenced_layout_.emplace(new_layout);
}
| |
| void ResourceHandleWithLayout::UpdateAttrs(const EmbeddingResourceAttrs& attrs, |
| TF_Status* status) { |
| if (attrs_.has_value()) { |
| RETURN_STATUS(status, TF_INVALID_ARGUMENT, |
| "Attepted to overwrite an existing embedding resource " |
| "attribute."); |
| } |
| attrs_.emplace(attrs); |
| } |
| |
| StatusOr<std::unique_ptr<TensorWithLayout>> SparseTensorWithLayout::Wrap( |
| std::unique_ptr<parallel_device::ParallelTensor> indices_tensor, |
| std::unique_ptr<parallel_device::ParallelTensor> values_tensor, |
| std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor, |
| const MeshWithParallelDevice& mesh, const Layout& layout, |
| std::vector<int64_t> local_shape) { |
| return std::unique_ptr<TensorWithLayout>(new SparseTensorWithLayout( |
| std::move(indices_tensor), std::move(values_tensor), |
| std::move(shapes_tensor), mesh, layout, local_shape)); |
| } |
| |
| std::string SparseTensorWithLayout::SummarizeValue() const { |
| std::string indices_summary; |
| std::string values_summary; |
| std::string dense_shapes_summary; |
| |
| Status indices_status; |
| Status values_status; |
| Status dense_shapes_status; |
| |
| if (layout().IsFullyReplicated()) { |
| indices_status = tensorflow::unwrap(indices_->tensor(0)) |
| ->SummarizeValue(indices_summary); |
| values_status = |
| tensorflow::unwrap(values_->tensor(0))->SummarizeValue(values_summary); |
| dense_shapes_status = tensorflow::unwrap(dense_shapes_->tensor(0)) |
| ->SummarizeValue(dense_shapes_summary); |
| } else { |
| indices_status = indices_->SummarizeValue(indices_summary); |
| values_status = values_->SummarizeValue(values_summary); |
| dense_shapes_status = dense_shapes_->SummarizeValue(dense_shapes_summary); |
| } |
| |
| if (!indices_status.ok()) |
| values_summary = "<error computing summary for indices>"; |
| if (!values_status.ok()) |
| indices_summary = "<error computing summary for values>"; |
| if (!dense_shapes_status.ok()) |
| indices_summary = "<error computing summary for dense_shapes>"; |
| |
| return absl::StrCat("indices: ", indices_summary, ", ", |
| "values: ", values_summary, ", ", |
| "dense_shapes: ", dense_shapes_summary, ", layout=\"", |
| layout().ToString(), "\""); |
| } |
| |
| std::string SparseTensorWithLayout::DebugString() const { |
| auto dtype = static_cast<DataType>(values_->dtype()); |
| |
| const auto& shape_vector = global_shape(); |
| return absl::StrCat("DTensor(", SummarizeValue(), |
| ", shape=", ShapeToDebugString(shape_vector), |
| ", type=", DataTypeString(dtype), ")"); |
| } |
| |
| TF_DataType SparseTensorWithLayout::dtype() const { |
| if (dtype_.has_value()) { |
| return dtype_.value(); |
| } else { |
| return values_->dtype(); |
| } |
| } |
| |
| TFE_TensorHandle* SparseTensorWithLayout::get_tensor(size_t index) const { |
| int num_sparse_tensors = num_tensors() / 3; |
| if (index < num_sparse_tensors) { |
| return indices()->tensor(index); |
| } else if (index < 2 * num_sparse_tensors) { |
| return values()->tensor(index % num_sparse_tensors); |
| } else { |
| return dense_shapes()->tensor(index % num_sparse_tensors); |
| } |
| } |
| |
| std::vector<int64_t> TensorShapeAsVector(TFE_TensorHandle* tensor, |
| TF_Status* status) { |
| std::vector<int64_t> shape(TFE_TensorHandleNumDims(tensor, status)); |
| if (TF_GetCode(status) != TF_OK) return {}; |
| for (int i = 0; i < shape.size(); ++i) { |
| shape[i] = TFE_TensorHandleDim(tensor, i, status); |
| if (TF_GetCode(status) != TF_OK) return {}; |
| } |
| return shape; |
| } |
| |
// Builds a tensorflow::Graph around `doperation` so it can be handed to the
// DTensor MLIR pipeline: a leading device-id _Arg, one _Arg per input (or a
// Const node for small constant inputs), the operation itself (emitted as a
// StatefulPartitionedCall when `doperation` is a function), and one _Retval
// per output. Shape inference runs as the graph is built; each output's
// inferred global shape is appended to `global_output_shapes` and any
// user-requested default layout to `output_layouts` (nullptr when none).
Status PrepareGraphForMlir(
    const std::vector<TensorWithLayout*>& inputs,
    const DTensorOperation& doperation,
    const tensorflow::FunctionLibraryDefinition& flib_def,
    const NameAttrList& attributes,
    const absl::optional<Layout>& default_layout, tensorflow::Graph* graph,
    std::vector<PartialTensorShape>* global_output_shapes,
    std::vector<const Layout*>* output_layouts) {
  // We run shape inference on the graph to find output shapes, which may
  // determine default layouts.
  ShapeRefiner shape_refiner(TF_GRAPH_DEF_VERSION, &flib_def);
  shape_refiner.set_function_library_for_shape_inference(&flib_def);
  tensorflow::Status status;
  {
    // We include an _Arg node for the device ID, but this isn't used by the
    // initial function. It will be provided a value, though, so it's available
    // for use in rewrites.
    tensorflow::NodeDefBuilder builder("device_id", "_Arg");
    tensorflow::PartialTensorShape partial_shape;
    // Scalar shape: rank 0, no dimensions.
    TF_RETURN_IF_ERROR(tensorflow::PartialTensorShape::MakePartialShape(
        static_cast<int*>(nullptr), 0, &partial_shape));
    tensorflow::NodeDef arg_node_def;
    TF_RETURN_IF_ERROR(builder.Attr("shape", partial_shape)
                           .Attr("T", tensorflow::DT_INT32)
                           .Attr("index", 0)
                           .Finalize(&arg_node_def, /*consume=*/true));
    tensorflow::Node* arg_node = graph->AddNode(arg_node_def, &status);
    TF_RETURN_IF_ERROR(status);
    graph->AddControlEdge(graph->source_node(), arg_node);
    TF_RETURN_IF_ERROR(shape_refiner.AddNode(arg_node));
  }
  std::vector<FunctionArgument> graph_op_inputs;
  graph_op_inputs.reserve(inputs.size());
  for (int i = 0; i < inputs.size(); ++i) {
    const TensorWithLayout* input = inputs[i];
    // TODO(allenl): This will block until async execution is complete, which
    // will be slow. We should find a non-blocking way of fetching the shape,
    // at least pre-cache.
    // The shape passed into MLIR transformation represents the global shape of
    // the tensor. Ideally, the local shape on each parallel device should not
    // be consulted at all and we should use the shape on our input tensor
    // directly.
    const auto& shape = input->global_shape();
    std::vector<tensorflow::int64> cast_shape(shape.begin(), shape.end());
    tensorflow::PartialTensorShape partial_shape;
    // For resource tensors, `shape` attribute should not be specified as shape
    // of resource tensors is specified by resource shape subtype -- not the
    // shape attribute.
    auto* resource = dynamic_cast<const ResourceHandleWithLayout*>(input);
    if (!resource) {
      TF_RETURN_IF_ERROR(tensorflow::PartialTensorShape::MakePartialShape(
          cast_shape.data(), cast_shape.size(), &partial_shape));
    }

    tensorflow::NodeDef arg_node_def;
    auto dtype = static_cast<tensorflow::DataType>(input->dtype());
    tensorflow::NodeDefBuilder builder(absl::StrCat("op_input_", i), "_Arg");

    // Delegate TensorWithLayout to encode attributes if applicable.
    input->EncodeAttributes(builder);

    TF_RETURN_IF_ERROR(
        builder.Attr("shape", partial_shape)
            .Attr("T", dtype)
            .Attr("index", i + 1)  // Indices are offset by 1 for device_id
            .Attr(kLayoutAttr, input->layout().ToString())
            .Attr(kMeshAttr, input->mesh().mesh_config().ToString())
            .Finalize(&arg_node_def, /*consume=*/true));
    Node* arg_node = graph->AddNode(arg_node_def, &status);
    TF_RETURN_IF_ERROR(status);
    TF_RETURN_IF_ERROR(shape_refiner.AddNode(arg_node));

    // Seed the refiner with the known (partial) global shape for this arg.
    shape_inference::InferenceContext* inference_context =
        shape_refiner.GetContext(arg_node);
    shape_inference::ShapeHandle shape_handle;
    TF_RETURN_IF_ERROR(inference_context->MakeShapeFromPartialTensorShape(
        partial_shape, &shape_handle));
    TF_RETURN_IF_ERROR(shape_refiner.SetShape(arg_node, 0, shape_handle));

    // Small constants are converted into constant graph nodes, instead of being
    // passed in as input arguments. This provides more information to the SPMD
    // and layout propagation passes.
    if (!input->const_value().has_value()) {
      graph_op_inputs.push_back(FunctionArgument{
          arg_node, NodeDefBuilder::NodeOut{arg_node->name(), i, dtype}});
      graph->AddControlEdge(graph->source_node(), arg_node);
    } else {
      // TODO(xiejw): Refactor the TensorWithLayout representation to avoid
      // special code here.
      NodeDef const_node = input->const_value().value();
      const_node.set_name(absl::StrCat("input_", i, "_const_value"));
      Node* const_value_n = graph->AddNode(const_node, &status);
      TF_RETURN_IF_ERROR(status);
      TF_RETURN_IF_ERROR(shape_refiner.AddNode(const_value_n));
      graph_op_inputs.push_back(FunctionArgument{
          const_value_n, tensorflow::NodeDefBuilder::NodeOut{
                             const_value_n->name(), i, dtype}});
    }
  }

  tensorflow::NodeDef op_node_def;
  const FunctionDef* function_def = doperation.function_def;
  if (function_def) {
    // Functions execute as a StatefulPartitionedCall whose `f` attr names the
    // function and whose Tin/Tout mirror its signature.
    AttrValue func_attr;
    func_attr.mutable_func()->set_name(doperation.name);
    std::vector<tensorflow::NodeDefBuilder::NodeOut> func_inputs;
    std::vector<tensorflow::DataType> inputs_types;
    for (const auto& in : graph_op_inputs) {
      func_inputs.emplace_back(in.output);
      inputs_types.emplace_back(in.output.data_type);
    }

    std::vector<tensorflow::DataType> output_types;
    for (const auto& out : function_def->signature().output_arg())
      output_types.emplace_back(out.type());

    TF_RETURN_IF_ERROR(
        NodeDefBuilder("eager_operation", "StatefulPartitionedCall")
            .Attr("Tin", inputs_types)
            .Attr("Tout", output_types)
            .Attr("f", func_attr)
            .Input(func_inputs)
            .Finalize(&op_node_def, true));
  } else {
    // Single ops are emitted directly under their own op name.
    op_node_def.set_op(doperation.name);
    op_node_def.set_name("eager_operation");
  }

  // Carry over the caller-supplied attributes onto the operation node.
  op_node_def.mutable_attr()->insert(attributes.attr().begin(),
                                     attributes.attr().end());

  tensorflow::Node* op_node = graph->AddNode(op_node_def, &status);
  TF_RETURN_IF_ERROR(status);

  for (int i = 0; i < graph_op_inputs.size(); ++i) {
    graph->AddEdge(graph_op_inputs[i].node, 0, op_node, i);
  }
  TF_RETURN_IF_ERROR(shape_refiner.AddNode(op_node));

  output_layouts->clear();
  output_layouts->reserve(op_node->num_outputs());
  global_output_shapes->reserve(op_node->num_outputs());
  for (int output_index = 0; output_index < op_node->num_outputs();
       ++output_index) {
    // One _Retval per operation output, wired through the sink node.
    tensorflow::NodeDefBuilder builder(absl::StrCat("op_output_", output_index),
                                       "_Retval");
    tensorflow::NodeDef ret_node_def;
    tensorflow::DataType output_type = op_node->output_type(output_index);

    TF_RETURN_IF_ERROR(builder.Attr("T", output_type)
                           .Attr("index", output_index)
                           .Input("eager_operation", output_index, output_type)
                           .Finalize(&ret_node_def, /*consume=*/true));
    tensorflow::Node* ret_node = graph->AddNode(ret_node_def, &status);
    TF_RETURN_IF_ERROR(status);
    graph->AddEdge(op_node, output_index, ret_node, 0);
    graph->AddControlEdge(ret_node, graph->sink_node());

    // Record the refiner's inferred shape for this output as the global shape.
    shape_inference::InferenceContext* inference_context =
        shape_refiner.GetContext(op_node);
    shape_inference::ShapeHandle output_shape_handle =
        inference_context->output(output_index);
    TensorShapeProto output_shape_proto;
    inference_context->ShapeHandleToProto(output_shape_handle,
                                          &output_shape_proto);
    PartialTensorShape global_output_shape(output_shape_proto);
    VLOG(3) << "Inferred shape for operation '" << doperation.name
            << "':" << global_output_shape.DebugString();
    global_output_shapes->push_back(global_output_shape);

    const Layout* layout = nullptr;
    if (default_layout.has_value() && output_index == 0) {
      // Record the user's requested output layout. The scope currently only
      // covers the first output of an op.
      layout = &default_layout.value();
      ret_node->AddAttr(kDefaultLayoutAttr, layout->ToString());
    }
    output_layouts->push_back(layout);
  }
  return Status::OK();
}
| |
// Returns set of functions to run to execute DTensor computation.
//
// Scans `graph` (as produced by the MLIR transformation) for
// StatefulPartitionedCall nodes, one per mesh, and for each builds a
// TranslatedFunction recording: the mesh to execute on, input/output index
// maps between the local function and the global computation, per-output
// layouts, and local output shapes derived from `global_output_shapes`.
// Fails when the transformed graph contains no function to execute.
StatusOr<ExecutionFunctions> IdentifyAllFunctionsToExecute(
    const tensorflow::Graph& graph,
    const std::vector<PartialTensorShape>& global_output_shapes) {
  ExecutionFunctions execution_functions;
  execution_functions.function_list = std::vector<TranslatedFunction>();
  for (Node* node : graph.nodes()) {
    if (node->op_def().name() != "StatefulPartitionedCall") continue;
    // Extract mesh to execute the function.
    std::string serialized_mesh;
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kMeshAttr, &serialized_mesh));
    Mesh mesh;
    TF_ASSIGN_OR_RETURN(mesh, Mesh::FromString(serialized_mesh));

    TranslatedFunction function;
    function.function_mesh = std::move(mesh);
    function.node_to_execute = node;

    // Identify input arg information.
    TF_RETURN_IF_ERROR(
        ParseResourceArgumentLayouts(*node, &function.resource_input_layouts));

    TF_RETURN_IF_ERROR(
        ParseShapeInputLayouts(*node, &function.shape_output_metadata));

    function.input_index_map.resize(node->num_inputs());
    // Identity mapping between local mesh function input index and global
    // input index.
    for (int in_index = 0; in_index < node->num_inputs(); ++in_index) {
      Node* input_node;

      TF_RETURN_IF_ERROR(node->input_node(in_index, &input_node));
      if (!input_node->IsArg())
        return errors::InvalidArgument(
            "Input node to mesh computation must be arg node.");

      // The _Arg's "index" attribute is the position in the global input list.
      int global_index;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(input_node->attrs(), "index", &global_index));
      function.input_index_map[in_index] = global_index;
    }

    // Identify output mappings and layouts for each outputs.
    // Keyed by global output index so iteration below is in ascending order.
    std::map<int, const Edge*> output_edges;
    for (const Edge* out_edge : node->out_edges()) {
      if (out_edge->IsControlEdge()) continue;

      // Follow any chain of Identity nodes down to the _Retval that consumes
      // this output.
      const Node* retval_or_identity_node = out_edge->dst();
      while (retval_or_identity_node->IsIdentity()) {
        retval_or_identity_node =
            *(retval_or_identity_node->out_nodes().begin());
      }

      TF_RET_CHECK(retval_or_identity_node->IsRetval());
      int global_index;
      TF_RETURN_IF_ERROR(GetNodeAttr(retval_or_identity_node->attrs(), "index",
                                     &global_index));
      output_edges[global_index] = out_edge;
    }

    for (auto it = output_edges.begin(); it != output_edges.end(); it++) {
      const int global_index = it->first;
      function.output_index_map.emplace_back(global_index);

      const Edge* retval_edge = it->second;
      const int output_index = retval_edge->src_output();

      // Add output layout and shape information.
      TF_ASSIGN_OR_RETURN(
          const Layout output_layout,
          GetLayoutThroughIdentityOps(retval_edge->src(), output_index));

      function.output_layouts.emplace_back(output_layout);
      // The local shape is the per-device slice of the global shape under
      // this output's layout.
      function.local_output_shapes.emplace_back(
          output_layout.LocalShapeFromGlobalShape(
              global_output_shapes[global_index]));
    }

    execution_functions.function_list.emplace_back(std::move(function));
  }

  if (execution_functions.function_list.empty()) {
    return errors::InvalidArgument(
        "MLIR transformed graph does not have any functions to execute for "
        "mesh.");
  }

  return execution_functions;
}
| |
| // For functions with control outputs, add identity nodes between |
| // StatefulPartitionedCall and _Retvals, in order to preserve control output |
| // dependencies after StatefulPartitionedCall is inlined at runtime. |
| // Consider calling this in PrepareGraphForMlir, once the identity nodes won't |
| // be dropped during MLIR lowering. |
| // TODO(b/171265131): fix the underlying issue to avoid inserting identity |
| // nodes. |
| Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph) { |
| if (function_def == nullptr || function_def->control_ret().empty()) { |
| return Status::OK(); |
| } |
| tensorflow::Status status; |
| for (Node* n : graph->nodes()) { |
| if (!n->IsRetval()) { |
| continue; |
| } |
| const Edge* edge; |
| TF_RETURN_IF_ERROR(n->input_edge(0, &edge)); |
| int ret_index; |
| TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &ret_index)); |
| tensorflow::NodeDefBuilder identity_builder( |
| absl::StrCat("op_output_identity_", ret_index), "Identity"); |
| tensorflow::NodeDef ret_identity_node_def; |
| tensorflow::DataType output_type = n->input_type(0); |
| TF_RETURN_IF_ERROR( |
| identity_builder.Attr("T", output_type) |
| .Input(edge->src()->name(), edge->src_output(), output_type) |
| .Finalize(&ret_identity_node_def, /*consume=*/true)); |
| Node* ret_identity_node = graph->AddNode(ret_identity_node_def, &status); |
| TF_RETURN_IF_ERROR(status); |
| // Delete the edge between StatefulPartitionedCall and _Retval. |
| graph->RemoveEdge(edge); |
| // Add an edge between StatefulPartitionedCall and Identity. |
| graph->AddEdge(edge->src(), edge->src_output(), ret_identity_node, 0); |
| graph->AddControlEdge(edge->src(), ret_identity_node); |
| // Add an edge between Identity and _Retval. |
| graph->AddEdge(ret_identity_node, 0, n, 0); |
| } |
| return Status::OK(); |
| } |
| |
| void AddDTensorFunctionAttr(FunctionDef& function_def) { |
| // Do not xla compile function returned by DTensor MLIR graph transformation |
| // as it already returns compiled graph. |
| AttrValue xla_must_compile_val; |
| xla_must_compile_val.set_b(false); |
| function_def.mutable_attr()->insert( |
| {"_XlaMustCompile", xla_must_compile_val}); |
| |
| // Explicitly place function outputs on the default function device to avoid |
| // redundant host <-> device copies (Placer may place outputs on the host |
| // CPU). |
| AttrValue outputs_on_op_device; |
| outputs_on_op_device.set_b(true); |
| function_def.mutable_attr()->insert( |
| {"_OutputsOnOpDevice", outputs_on_op_device}); |
| } |
| |
// Collects the TPU-embedding resource inputs from `inputs` and returns their
// parallel tensors ordered by embedding table id (table 0 first). Fails when
// no input carries embedding resource attributes.
StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs(
    const std::vector<TensorWithLayout*>& inputs) {
  // Maps embedding table id -> position of that table's resource in `inputs`.
  absl::flat_hash_map<int64_t, int64_t> table_and_input_index;
  for (int64_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i]->tensor_type() != kResource) continue;
    auto* resource = dynamic_cast<ResourceHandleWithLayout*>(inputs[i]);
    if (resource == nullptr) {
      return errors::Internal("Failed to cast a resource handle");
    }

    // Only resources tagged with embedding attrs (see
    // GetTPUEmbeddingInputNodes) participate.
    const absl::optional<EmbeddingResourceAttrs>& resource_attrs =
        resource->attrs();
    if (resource_attrs.has_value()) {
      table_and_input_index.insert({resource_attrs->table_id, i});
    }
  }

  // Check if there is no embedding resource input found.
  if (table_and_input_index.empty()) {
    return errors::Internal("There are no TPU embedding resource input found.");
  }
  const size_t num_tables = table_and_input_index.size();
  std::vector<parallel_device::ParallelTensor*> parallel_inputs;
  parallel_inputs.reserve(num_tables);

  // Assure parallel inputs has numeric order as table ids.
  // NOTE(review): this assumes table ids are exactly 0..num_tables-1 with no
  // gaps; a missing id would default-insert 0 here — confirm with callers.
  for (int64_t table_id = 0; table_id < num_tables; ++table_id) {
    const int64_t input_index = table_and_input_index[table_id];
    parallel_inputs.push_back(inputs[input_index]->tensor());
  }
  return parallel_inputs;
}
| |
| StatusOr<std::map<int64_t, Node*>> GetTPUEmbeddingInputNodes( |
| TF_Status* s, const Graph& graph, |
| const std::vector<TensorWithLayout*>& inputs) { |
| std::map<int64_t, Node*> table_id_node_map; |
| for (Node* node : graph.nodes()) { |
| if (!node->IsArg()) continue; |
| |
| const int64_t& arg_id = node->attrs().Find("index")->i(); |
| const AttrValue* embedding_attr = |
| node->attrs().Find("_tpu_embedding_table_id"); |
| |
| if (embedding_attr == nullptr) continue; |
| |
| // Offset due to device id. |
| const int64_t table_id = embedding_attr->i(); |
| EmbeddingResourceAttrs embedding_attrs; |
| embedding_attrs.table_id = table_id; |
| inputs[arg_id - 1]->UpdateAttrs(embedding_attrs, s); |
| if (!s->status.ok()) { |
| return errors::Internal( |
| "Failed to set embedding resource attrs. \n Got error: ", |
| s->status.error_message()); |
| } |
| table_id_node_map.insert({table_id, node}); |
| } |
| return table_id_node_map; |
| } |
| |
| StatusOr<std::string> ValidateResourceMeshConsistency( |
| const std::vector<TensorWithLayout*>& inputs) { |
| std::string mesh_str; |
| for (TensorWithLayout* inp : inputs) { |
| if (inp->tensor_type() != kResource) continue; |
| |
| auto* resource = dynamic_cast<ResourceHandleWithLayout*>(inp); |
| if (!resource || !resource->attrs().has_value()) continue; |
| const std::string& input_mesh_str = inp->layout().mesh().ToString(); |
| if (mesh_str.empty()) { |
| mesh_str = input_mesh_str; |
| } else if (mesh_str != input_mesh_str) { |
| return errors::Internal(absl::StrCat( |
| "All inputs of embedding resource must be on same mesh. but get : ", |
| mesh_str, " != ", input_mesh_str)); |
| } |
| } |
| VLOG(1) << "Resource input mesh is : " << mesh_str; |
| return mesh_str; |
| } |
| |
| Status InsertFunctionForTPUEmbeddingCheckpoint( |
| TF_Status* status, Graph* graph, |
| const std::vector<TensorWithLayout*>& inputs, |
| const std::string& checkpoint_fn_name) { |
| if (checkpoint_fn_name != kLoadEmbeddingFn && |
| checkpoint_fn_name != kRetrieveEmbeddingFn) { |
| return errors::InvalidArgument(absl::StrCat( |
| "Found wrong function name: ", checkpoint_fn_name, |
| " \n expects : ", kLoadEmbeddingFn, " or ", kRetrieveEmbeddingFn)); |
| } |
| |
| StatusOr<std::map<int64_t, Node*>> table_id_node_map = |
| GetTPUEmbeddingInputNodes(status, *graph, inputs); |
| if (!table_id_node_map.ok()) { |
| return errors::Internal(table_id_node_map.status().error_message()); |
| } |
| |
| StatusOr<std::string> mesh_str = ValidateResourceMeshConsistency(inputs); |
| |
| const int64_t& num_tables = table_id_node_map->size(); |
| NodeDef func_node_def; |
| std::vector<NodeDefBuilder::NodeOut> func_inputs; |
| std::vector<DataType> input_types, output_types; |
| |
| func_inputs.reserve(num_tables); |
| input_types.reserve(num_tables); |
| |
| for (int i = 0; i < num_tables; ++i) { |
| auto node_ptr = table_id_node_map->find(i); |
| if (node_ptr == table_id_node_map->end()) { |
| return errors::Internal( |
| absl::StrCat("Embedding table id ", i, " is not found.")); |
| } |
| const std::string& node_name = node_ptr->second->name(); |
| func_inputs.push_back({node_name, i, DT_RESOURCE}); |
| input_types.push_back(DT_RESOURCE); |
| } |
| |
| AttrValue mesh_attr; |
| *mesh_attr.mutable_s() = *mesh_str; |
| NameAttrList func_attr; |
| func_attr.set_name(checkpoint_fn_name); |
| TF_RETURN_IF_ERROR( |
| NodeDefBuilder(checkpoint_fn_name, "StatefulPartitionedCall") |
| .Attr("Tin", input_types) |
| .Attr("Tout", output_types) |
| .Attr("f", func_attr) |
| .Attr(kMeshAttr, mesh_attr) |
| .Attr("config", mesh_attr) |
| .Input(func_inputs) |
| .Finalize(&func_node_def, true)); |
| |
| TF_ASSIGN_OR_RETURN(Node * func_node, graph->AddNode(func_node_def)); |
| for (int i = 0; i < num_tables; ++i) { |
| Node* node = table_id_node_map->find(i)->second; |
| graph->AddEdge(node, 0, func_node, i); |
| } |
| |
| return Status::OK(); |
| } |
| |
| } // namespace dtensor |
| } // namespace tensorflow |