| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/grappler/optimizers/data/rebatch.h" |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "tensorflow/core/framework/attr_value_util.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/node_def_builder.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/versions.pb.h" |
| #include "tensorflow/core/grappler/grappler_item.h" |
| #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" |
| #include "tensorflow/core/grappler/optimizers/data/function_utils.h" |
| #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" |
| #include "tensorflow/core/grappler/utils/functions.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/util/padding.h" |
| |
| namespace tensorflow { |
| namespace grappler { |
| |
| Status RebatchOptimizer::Init( |
| const tensorflow::RewriterConfig_CustomGraphOptimizer* config) { |
| if (!config) |
| return errors::InvalidArgument( |
| "Cannot initialize RebatchOptimizer without config."); |
| |
| num_replicas_ = config->parameter_map().at("num_replicas").i(); |
| use_fallback_ = config->parameter_map().at("use_fallback").b(); |
| return Status::OK(); |
| } |
| |
| namespace { |
| |
| constexpr char kAddOp[] = "Add"; |
| constexpr char kConstOp[] = "Const"; |
| constexpr char kIdentityOp[] = "Identity"; |
| constexpr char kSubOp[] = "Sub"; |
| constexpr char kTruncateDivOp[] = "TruncateDiv"; |
| constexpr char kOutputShapesAttr[] = "output_shapes"; |
| constexpr char kOutputTypesAttr[] = "output_types"; |
| constexpr char kTOutputTypesAttr[] = "Toutput_types"; |
| constexpr char kBatchOp[] = "BatchDataset"; |
| constexpr char kBatchV2Op[] = "BatchDatasetV2"; |
| constexpr char kPaddedBatchOp[] = "PaddedBatchDataset"; |
| constexpr char kPaddedBatchV2Op[] = "PaddedBatchDatasetV2"; |
| constexpr char kMapAndBatchOp[] = "MapAndBatchDataset"; |
| constexpr char kExperimentalMapAndBatchOp[] = "ExperimentalMapAndBatchDataset"; |
| |
| constexpr std::array<const char*, 6> kBatchDatasetOps = { |
| kBatchOp, kBatchV2Op, kMapAndBatchOp, kExperimentalMapAndBatchOp, |
| kPaddedBatchOp, kPaddedBatchV2Op}; |
| |
| constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = { |
| "ConcatenateDataset", |
| "ZipDataset" |
| }; |
| |
| // TODO(rachelim): We might want to be more conservative here and not allow |
| // passthrough for ops like "Map", "ParallelMap" etc which may change the |
| // batch dimension. Furthermore, transformations like "Skip" may change |
| // the semantics of the dataset (since we'd be skipping N minibatches instead |
| // of N batches). |
| constexpr std::array<const char*, 22> kPassThroughOps = { |
| "CacheDataset", |
| "CacheDatasetV2", |
| "ExperimentalScanDataset", |
| "ExperimentalParseExampleDataset", |
| "FilterDataset", |
| "Identity", |
| "MapDataset", |
| "ModelDataset", |
| "OptimizeDataset", |
| "ParallelMapDataset", |
| "ParseExampleDataset", |
| "PrefetchDataset", |
| "ReduceDataset", |
| "RepeatDataset", |
| "ScanDataset", |
| "ShardDataset", |
| "ShuffleAndRepeatDataset", |
| "ShuffleDataset", |
| "ShuffleDatasetV2", |
| "SkipDataset", |
| "TakeDataset", |
| "WindowDataset", |
| }; |
| |
| constexpr std::array<const char*, 5> kFuncDatasetOps = { |
| "ExperimentalGroupByWindowDataset", |
| "FlatMapDataset", |
| "GroupByWindowDataset", |
| "InterleaveDataset", |
| "ParallelInterleaveDatasetV2", |
| }; |
| |
| const std::map<string, const char*>* kFuncDatasetOpFuncs = |
| new std::map<string, const char*>({ |
| {"ExperimentalGroupByWindowDataset", "reduce_func"}, |
| {"FlatMapDataset", "f"}, |
| {"GroupByWindowDataset", "reduce_func"}, |
| {"InterleaveDataset", "f"}, |
| {"ParallelInterleaveDatasetV2", "f"}, |
| }); |
| |
| constexpr std::array<const char*, 9> kSourceDatasetOps = { |
| "FixedLengthRecordDataset", "FixedLengthRecordDatasetV2", |
| "GeneratorDataset", "RangeDataset", |
| "SparseTensorsSliceDataset", "TensorDataset", |
| "TensorSliceDataset", "TextLineDataset", |
| "TFRecordDataset", |
| }; |
| |
| NodeDef MakeBinaryNode(const string& input_x, const string& input_y, |
| const string& op, DataType dtype) { |
| NodeDef node; |
| node.set_op(op); |
| node.add_input(input_x); |
| node.add_input(input_y); |
| AddNodeAttr("T", dtype, &node); |
| |
| return node; |
| } |
| |
| NodeDef* AddBinaryNode(const string& input_x, const string& input_y, |
| const string& op, DataType type, FunctionDef* fdef) { |
| NodeDef* node = fdef->add_node_def(); |
| *node = MakeBinaryNode(input_x, input_y, op, type); |
| function_utils::SetUniqueFunctionNodeName(op, fdef, node); |
| |
| return node; |
| } |
| |
| // Adds a Const node to the FunctionDef. |
| Status AddConstIntNode(gtl::ArraySlice<int32> values, const TensorShape& shape, |
| FunctionDef* fdef, NodeDef** result) { |
| if (shape.dims() > 1) { |
| return errors::InvalidArgument("Cannot add const node with rank > 1"); |
| } |
| *result = fdef->add_node_def(); |
| TensorProto tensor_proto; |
| tensor_proto.set_dtype(DT_INT32); |
| if (shape.dims() == 0) { |
| // Scalar |
| DCHECK_EQ(values.size(), 1); |
| } else { |
| // vector |
| DCHECK_EQ(values.size(), shape.dim_size(0)); |
| tensor_proto.mutable_tensor_shape()->add_dim()->set_size(shape.dim_size(0)); |
| } |
| |
| for (int value : values) { |
| *tensor_proto.mutable_int_val()->Add() = value; |
| } |
| |
| TF_RETURN_IF_ERROR(NodeDefBuilder("", "Const") |
| .Attr("dtype", DT_INT32) |
| .Attr("value", tensor_proto) |
| .Finalize(*result)); |
| function_utils::SetUniqueFunctionNodeName("rebatch/const", fdef, *result); |
| |
| return Status::OK(); |
| } |
| |
| Status AddConstInt64Node(int64 value, FunctionDef* fdef, NodeDef** result) { |
| *result = fdef->add_node_def(); |
| Tensor t(value); |
| TF_RETURN_IF_ERROR(NodeDefBuilder("", "Const") |
| .Attr("dtype", DT_INT64) |
| .Attr("value", t) |
| .Finalize(*result)); |
| function_utils::SetUniqueFunctionNodeName("rebatch/const", fdef, *result); |
| |
| return Status::OK(); |
| } |
| |
| Status AddConstBoolNode(bool value, FunctionDef* fdef, NodeDef** result) { |
| *result = fdef->add_node_def(); |
| Tensor t(value); |
| TF_RETURN_IF_ERROR(NodeDefBuilder("", "Const") |
| .Attr("dtype", DT_BOOL) |
| .Attr("value", t) |
| .Finalize(*result)); |
| function_utils::SetUniqueFunctionNodeName("rebatch/const", fdef, *result); |
| |
| return Status::OK(); |
| } |
| |
| Status AddShapeNode(const NodeDefBuilder::NodeOut& input, FunctionDef* fdef, |
| NodeDef** result) { |
| *result = fdef->add_node_def(); |
| TF_RETURN_IF_ERROR( |
| NodeDefBuilder("", "Shape").Input(input).Finalize(*result)); |
| function_utils::SetUniqueFunctionNodeName("rebatch/shape", fdef, *result); |
| return Status::OK(); |
| } |
| |
| Status AddStridedSliceNode(const NodeDefBuilder::NodeOut& input, |
| const NodeDefBuilder::NodeOut& begin, |
| const NodeDefBuilder::NodeOut& end, |
| const NodeDefBuilder::NodeOut& strides, |
| DataType index, int32 begin_mask, |
| int32 ellipsis_mask, int32 end_mask, |
| int32 new_axis_mask, int32 shrink_axis_mask, |
| FunctionDef* fdef, NodeDef** result) { |
| *result = fdef->add_node_def(); |
| TF_RETURN_IF_ERROR(NodeDefBuilder("", "StridedSlice") |
| .Input(input) |
| .Input(begin) |
| .Input(end) |
| .Input(strides) |
| .Attr("Index", index) |
| .Attr("begin_mask", begin_mask) |
| .Attr("ellipsis_mask", ellipsis_mask) |
| .Attr("end_mask", end_mask) |
| .Attr("new_axis_mask", new_axis_mask) |
| .Attr("shrink_axis_mask", shrink_axis_mask) |
| .Finalize(*result)); |
| function_utils::SetUniqueFunctionNodeName("rebatch/strided_slice", fdef, |
| *result); |
| return Status::OK(); |
| } |
| |
| Status AddConcatNode(gtl::ArraySlice<NodeDefBuilder::NodeOut> values, |
| NodeDefBuilder::NodeOut axis, int32 n, FunctionDef* fdef, |
| NodeDef** result) { |
| *result = fdef->add_node_def(); |
| TF_RETURN_IF_ERROR(NodeDefBuilder("", "ConcatV2") |
| .Input(values) |
| .Input(axis) |
| .Attr("N", n) |
| .Finalize(*result)); |
| function_utils::SetUniqueFunctionNodeName("rebatch/concat", fdef, *result); |
| return Status::OK(); |
| } |
| |
| Status AddReshapeNode(NodeDefBuilder::NodeOut tensor, |
| NodeDefBuilder::NodeOut shape, FunctionDef* fdef, |
| NodeDef** result) { |
| *result = fdef->add_node_def(); |
| TF_RETURN_IF_ERROR(NodeDefBuilder("", "Reshape") |
| .Input(tensor) |
| .Input(shape) |
| .Finalize(*result)); |
| function_utils::SetUniqueFunctionNodeName("rebatch/reshape", fdef, *result); |
| return Status::OK(); |
| } |
| |
| template <std::size_t SIZE> |
| bool IsDatasetNodeOfType(const NodeDef& node, |
| const std::array<const char*, SIZE>& arr) { |
| for (const auto& dataset_op_name : arr) { |
| if (node.op() == dataset_op_name) return true; |
| } |
| return false; |
| } |
| |
| void SetUnknownShapes(int num_components, AttrValue* output_shapes) { |
| for (int i = 0; i < num_components; ++i) { |
| output_shapes->mutable_list()->mutable_shape()->Add()->set_unknown_rank( |
| true); |
| } |
| } |
| |
| Status GetBatchDim(AttrValue output_shapes, int* batch_dim) { |
| const auto& shape_0 = output_shapes.list().shape(0); |
| if (shape_0.unknown_rank() || shape_0.dim(0).size() == -1) { |
| return errors::InvalidArgument( |
| "Cannot use rebatching fallback when 0th dimensions of dataset " |
| "components are not fully known. Component 0 has shape: ", |
| shape_0.ShortDebugString()); |
| } |
| |
| *batch_dim = output_shapes.list().shape(0).dim(0).size(); |
| |
| for (int i = 1; i < output_shapes.list().shape_size(); ++i) { |
| const auto& shape_i = output_shapes.list().shape(i); |
| |
| if (shape_i.unknown_rank() || shape_i.dim(0).size() == -1) { |
| return errors::InvalidArgument( |
| "Cannot use rebatching fallback when 0th dimensions of dataset " |
| "components are not fully known. Component ", |
| i, " has shape: ", shape_i.ShortDebugString()); |
| } |
| if (shape_i.dim(0).size() != *batch_dim) { |
| return errors::InvalidArgument( |
| "Cannot use rebatching fallback when 0th dimensions of dataset " |
| "components don't match. Component ", |
| i, " has batch dimension: ", shape_i.dim(0).size(), |
| " while previous components have batch dimension: ", *batch_dim); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status UpdateOutputShapes(const string& node_name, int64 num_replicas, |
| MutableGraphView* graph) { |
| NodeDef* node = graph->GetNode(node_name); |
| if (node->attr().contains(kOutputShapesAttr)) { |
| AttrValue output_shapes = node->attr().at(kOutputShapesAttr); |
| for (auto& shape : *output_shapes.mutable_list()->mutable_shape()) { |
| if (!shape.unknown_rank() && shape.dim(0).size() != -1) { |
| shape.mutable_dim(0)->set_size(shape.dim(0).size() / num_replicas); |
| } |
| } |
| (*node->mutable_attr())[kOutputShapesAttr] = output_shapes; |
| } |
| return Status::OK(); |
| } |
| |
| // Helper function to get the batch_size input node for a give batch node. |
| int64 GetBatchSizeArgIndex(const NodeDef& batch_node) { |
| if (batch_node.op() == kExperimentalMapAndBatchOp || |
| batch_node.op() == kMapAndBatchOp) { |
| // For MapAndBatch we take the 3rd last input. |
| return batch_node.input_size() - 3; |
| } |
| // For all the batching datasets the batch_size is input number 1 except for |
| // MapAndBatchDataset. |
| return 1; |
| } |
| |
| Status MakeNewBatchSizeNode(const string& global_batch_size_name, |
| int64 num_replicas, FunctionDef* fdef, |
| NodeDef** result) { |
| NodeDef* one_node; |
| TF_RETURN_IF_ERROR(AddConstInt64Node(1, fdef, &one_node)); |
| NodeDef* num_replicas_node; |
| TF_RETURN_IF_ERROR(AddConstInt64Node(num_replicas, fdef, &num_replicas_node)); |
| |
| NodeDef* numerator_node = |
| AddBinaryNode(global_batch_size_name, |
| strings::StrCat(num_replicas_node->name(), ":output:0"), |
| kAddOp, DT_INT64, fdef); |
| numerator_node = AddBinaryNode( |
| strings::StrCat(numerator_node->name(), ":z:0"), |
| strings::StrCat(one_node->name(), ":output:0"), kSubOp, DT_INT64, fdef); |
| |
| *result = |
| AddBinaryNode(strings::StrCat(numerator_node->name(), ":z:0"), |
| strings::StrCat(num_replicas_node->name(), ":output:0"), |
| kTruncateDivOp, DT_INT64, fdef); |
| return Status::OK(); |
| } |
| |
| // Given a "batch" dataset node, we replace the `batch_size` input with a new |
| // input that corresponds to the original input divided by `num_replicas`. |
| Status MutateBatchSize(const NodeDef& node, int64 num_replicas, |
| MutableGraphView* graph) { |
| // For all the batching datasets the batch_size is input number 1 except for |
| // MapAndBatchDataset. |
| int64 batch_size_arg_index = GetBatchSizeArgIndex(node); |
| NodeDef* batch_size_node = |
| graph_utils::GetInputNode(node, *graph, batch_size_arg_index); |
| int64 batch_size; |
| TF_RETURN_IF_ERROR( |
| graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size)); |
| DCHECK_EQ(batch_size % num_replicas, 0); |
| batch_size = batch_size / num_replicas; |
| NodeDef* new_batch_size_node = |
| graph_utils::AddScalarConstNode<int64>(batch_size, graph); |
| // We don't call UpdateFanouts here because CSE elimination might lead to |
| // multiple nodes sharing the same batch size constant node. This is also |
| // why we don't delete batch_size_node as well. |
| TF_RETURN_IF_ERROR(graph->UpdateRegularFaninByPort( |
| node.name(), batch_size_arg_index, {new_batch_size_node->name(), 0})); |
| return Status::OK(); |
| } |
| |
| Status AddFlatMapNode(const string& input_dataset, |
| gtl::ArraySlice<string> other_arguments, |
| gtl::ArraySlice<DataType> t_arguments, |
| const FunctionDef& flat_map_fn, |
| const AttrValue& output_shapes, |
| const DataTypeVector& output_types, |
| FunctionLibraryDefinition* flib, MutableGraphView* graph, |
| NodeDef** result) { |
| TF_RETURN_IF_ERROR(flib->AddFunctionDef(flat_map_fn)); |
| AttrValue f; |
| f.mutable_func()->set_name(flat_map_fn.signature().name()); |
| |
| NodeDef flat_map_node; |
| flat_map_node.set_op("FlatMapDataset"); |
| flat_map_node.add_input(input_dataset); |
| for (const auto& arg : other_arguments) { |
| flat_map_node.add_input(arg); |
| } |
| AddNodeAttr("f", f, &flat_map_node); |
| AddNodeAttr("Targuments", t_arguments, &flat_map_node); |
| AddNodeAttr(kOutputShapesAttr, output_shapes, &flat_map_node); |
| AddNodeAttr(kOutputTypesAttr, output_types, &flat_map_node); |
| |
| graph_utils::SetUniqueGraphNodeName("rebatch/flat_map", graph->graph(), |
| &flat_map_node); |
| *result = graph->AddNode(std::move(flat_map_node)); |
| return Status::OK(); |
| } |
| |
| // def flat_map_fn(*batched_components): |
| // ds = tf.data.Dataset.from_tensor_slices(batched_components) |
| // return ds.batch(minibatch_size, drop_remainder=False) |
| Status CreateFlatMapFnWithBatch(const DataTypeVector& dtypes, |
| int64 num_replicas, FunctionDef* result) { |
| NodeDef* tensor_slice_node = result->add_node_def(); |
| tensor_slice_node->set_op("TensorSliceDataset"); |
| for (int i = 0; i < dtypes.size(); ++i) { |
| auto* input_arg = function_utils::AddFunctionInput( |
| strings::StrCat("args_", i), result, dtypes.at(i)); |
| tensor_slice_node->add_input(input_arg->name()); |
| } |
| AddNodeAttr(kTOutputTypesAttr, dtypes, tensor_slice_node); |
| |
| // The output_shapes attr here doesn't make a difference, since we |
| // set the output_shapes of the external FlatMap node. |
| AttrValue shapes; |
| SetUnknownShapes(dtypes.size(), &shapes); |
| AddNodeAttr(kOutputShapesAttr, shapes, tensor_slice_node); |
| function_utils::SetUniqueFunctionNodeName("rebatch/from_tensor_slices", |
| result, tensor_slice_node); |
| |
| NodeDef* false_node; |
| TF_RETURN_IF_ERROR(AddConstBoolNode(false, result, &false_node)); |
| NodeDef* batch_node = result->add_node_def(); |
| batch_node->set_op(kBatchV2Op); |
| batch_node->add_input( |
| strings::StrCat(tensor_slice_node->name(), ":handle:0")); |
| |
| // `batch_size` input |
| // Here, we capture the original batch size from outside the flat map fn. |
| auto* original_batch_size = |
| function_utils::AddFunctionInput("captured_batch_size", result, DT_INT64); |
| NodeDef* new_batch_size; |
| TF_RETURN_IF_ERROR(MakeNewBatchSizeNode( |
| original_batch_size->name(), num_replicas, result, &new_batch_size)); |
| batch_node->add_input(strings::StrCat(new_batch_size->name(), ":z:0")); |
| |
| // `drop_remainder` input |
| batch_node->add_input(strings::StrCat(false_node->name(), ":output:0")); |
| AddNodeAttr(kOutputTypesAttr, dtypes, batch_node); |
| AddNodeAttr(kOutputShapesAttr, shapes, batch_node); |
| function_utils::SetUniqueFunctionNodeName("rebatch/batch", result, |
| batch_node); |
| function_utils::AddFunctionOutputWithUniqueName( |
| "output", strings::StrCat(batch_node->name(), ":handle:0"), result, |
| DT_VARIANT); |
| // Because TensorSliceDataset is stateful, we set the function to stateful. |
| result->mutable_signature()->set_is_stateful(true); |
| |
| return Status::OK(); |
| } |
| |
| // Rewrite graph to add |
| // `.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x). |
| // batch(minibatch_size, drop_remainder=False))` |
| // after the batch node. This ensures that the sum of the minibatch sizes |
| // in a step adds up to the global batch size. However, since this adds |
| // additional data copies (both from_tensor_slices and batch), we only use |
| // this approach when necessary, i.e. when we need to drop remainder on the |
| // global batch, or when the global batch size does not divide num_replicas |
| // evenly. |
| Status AppendFlatMap(const NodeDef& batch_node, int64 num_replicas, |
| FunctionLibraryDefinition* flib, MutableGraphView* graph) { |
| // `.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x). |
| // batch(minibatch_size, drop_remainder=False))` |
| FunctionDef flat_map_fn; |
| FunctionDefLibrary lib = flib->ToProto(); |
| graph_utils::SetUniqueGraphFunctionName("rebatch/flat_map_fn", &lib, |
| &flat_map_fn); |
| DataTypeVector dtypes; |
| TF_RETURN_IF_ERROR( |
| graph_utils::GetDatasetOutputTypesAttr(batch_node, &dtypes)); |
| TF_RETURN_IF_ERROR( |
| CreateFlatMapFnWithBatch(dtypes, num_replicas, &flat_map_fn)); |
| |
| int64 batch_size_index = GetBatchSizeArgIndex(batch_node); |
| |
| NodeDef* flat_map_node; |
| |
| AttrValue output_shapes = batch_node.attr().at(kOutputShapesAttr); |
| for (auto& shape : *output_shapes.mutable_list()->mutable_shape()) { |
| if (!shape.unknown_rank() && shape.dim(0).size() != -1) { |
| // Because the flat map function uses drop_remainder = False, |
| // the shape might be unknown |
| auto old_dim = shape.dim(0).size(); |
| auto new_dim = old_dim % num_replicas == 0 ? old_dim / num_replicas : -1; |
| shape.mutable_dim(0)->set_size(new_dim); |
| } |
| } |
| |
| TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(batch_node.name(), ":0"), |
| {batch_node.input(batch_size_index)}, |
| {DT_INT64}, flat_map_fn, output_shapes, |
| dtypes, flib, graph, &flat_map_node)); |
| |
| TF_RETURN_IF_ERROR( |
| graph->UpdateFanouts(batch_node.name(), flat_map_node->name())); |
| |
| return Status::OK(); |
| } |
| |
| // There are several things we do here, depending on the values of |
| // batch_size and drop_remainder. |
| // (1) If batch size is known and divisible by num_replicas, and drop_remainder |
| // is known to be False, we mutate the batch size directly. |
| // .batch(global_batch_size) -> .batch(global_batch_size // num_replicas) |
| // (2) Otherwise, we add a flat_map transformation to preserve the global batch |
| // size across the replicas and to preserve the drop remainder behavior. |
| bool ShouldMutateBatchSizeDirectly(const NodeDef& batch_node, |
| int64 num_replicas, |
| MutableGraphView* graph) { |
| int64 batch_size_arg_index = GetBatchSizeArgIndex(batch_node); |
| NodeDef* batch_size_node = |
| graph_utils::GetInputNode(batch_node, *graph, batch_size_arg_index); |
| |
| int64 batch_size; |
| Status s = |
| graph_utils::GetScalarConstNodeValue(*batch_size_node, &batch_size); |
| // If batch size is unknown or indivisible by num replicas, we don't |
| // mutate it directly |
| if (!s.ok() || batch_size % num_replicas != 0) return false; |
| |
| if (batch_node.op() == kBatchOp || batch_node.op() == kPaddedBatchOp) { |
| // These ops don't have a `drop_remainder` input, and behave like |
| // drop_remainder is False. |
| return true; |
| } |
| |
| // drop_remainder is the final input on the other batch nodes. |
| NodeDef* drop_remainder_node = graph_utils::GetInputNode( |
| batch_node, *graph, batch_node.input_size() - 1); |
| bool drop_remainder; |
| s = graph_utils::GetScalarConstNodeValue(*drop_remainder_node, |
| &drop_remainder); |
| return s.ok() && !drop_remainder; |
| } |
| |
| Status RewriteBatchNode(const NodeDef& batch_node, int64 num_replicas, |
| FunctionLibraryDefinition* flib, |
| MutableGraphView* graph) { |
| if (ShouldMutateBatchSizeDirectly(batch_node, num_replicas, graph)) { |
| return MutateBatchSize(batch_node, num_replicas, graph); |
| } |
| return AppendFlatMap(batch_node, num_replicas, flib, graph); |
| } |
| |
| Status OptimizeGraph(const GrapplerItem& item, int64 num_replicas, |
| bool use_fallback, GraphDef* output); |
| |
| // Helper function that starts from a node in the graph and recurses into its |
| // inputs trying to find a BatchDataset type operation to modify. During the |
| // recursion it handles four kinds of cases. |
| // 1. BatchDataset type ops: Mutates the batch_size input node and stops. |
| // 2. Zip / Concatenate dataset ops: Recurses into all inputs to these ops |
| // as they are datasets themselves. |
| // 3. Core dataset ops + Identity op: Recurses into first input parameter. |
| // 4. FlatMap type mapping dataset ops: Recurses into the function definition. |
| Status RecursivelyHandleOp(const NodeDef& node, int64 num_replicas, |
| bool use_fallback, FunctionLibraryDefinition* flib, |
| MutableGraphView* graph) { |
| if (IsDatasetNodeOfType(node, kBatchDatasetOps)) { |
| TF_RETURN_IF_ERROR(RewriteBatchNode(node, num_replicas, flib, graph)); |
| } else if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) { |
| // For all multiple input datasets, all inputs are datasets themselves. |
| for (int i = 0; i < node.input_size(); ++i) { |
| NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i); |
| TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_replicas, |
| use_fallback, flib, graph)); |
| } |
| } else if (IsDatasetNodeOfType(node, kPassThroughOps) || IsRetval(node)) { |
| // For all the dataset ops that are passthrough, or _Retvals added to the |
| // function body graph in place of function outputs, the input dataset is |
| // input 0. |
| NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0); |
| TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_replicas, |
| use_fallback, flib, graph)); |
| } else if (IsDatasetNodeOfType(node, kFuncDatasetOps)) { |
| const string func_name = |
| node.attr().at(kFuncDatasetOpFuncs->at(node.op())).func().name(); |
| const FunctionDef* fdef = flib->Find(func_name); |
| GrapplerFunctionItem f_item; |
| TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem( |
| *fdef, *flib, graph->graph()->versions().producer(), &f_item)); |
| GraphDef optimized_func_graph; |
| TF_RETURN_IF_ERROR(OptimizeGraph(f_item, num_replicas, use_fallback, |
| &optimized_func_graph)); |
| |
| // Function body optimization might have created new specialized |
| // functions for each instantiation context. Add them to the library. |
| for (const FunctionDef& func_def : |
| optimized_func_graph.library().function()) { |
| if (flib->Find(func_def.signature().name()) == nullptr) { |
| TF_RETURN_IF_ERROR(flib->AddFunctionDef(func_def)); |
| } |
| } |
| |
| // Convert optimized graph back to FunctionDef. |
| FunctionDef optimized_func; |
| f_item.SwapFunctionBody(std::move(optimized_func_graph)); |
| TF_RETURN_IF_ERROR(MakeFunctionDef(f_item, *flib, &optimized_func)); |
| |
| // Replace optimized function with a new FunctionDef. |
| TF_RETURN_IF_ERROR(flib->ReplaceFunction(func_name, optimized_func)); |
| } else if (IsDatasetNodeOfType(node, kSourceDatasetOps)) { |
| return errors::InvalidArgument( |
| "Reached a source dataset: ", node.op(), |
| " without encountering a batch transformation."); |
| } else { |
| return errors::InvalidArgument("Encountered an unsupported op: ", |
| node.op()); |
| } |
| // If we've successfully updated the batch size of this node or any nodes |
| // in the dataset tree rooted in this node, we update the output_shapes attr. |
| TF_RETURN_IF_ERROR(UpdateOutputShapes(node.name(), num_replicas, graph)); |
| return Status::OK(); |
| } |
| |
| // Add nodes to the function to reshape arg to shape (-1, new_batch_dim, ...) |
| Status ReshapeComponent(int new_batch_dim, const string& arg, DataType dtype, |
| FunctionDef* fdef, string* result) { |
| // Const with value [0] |
| NodeDef* const_vec_0; |
| TF_RETURN_IF_ERROR(AddConstIntNode({0}, {1}, fdef, &const_vec_0)); |
| |
| // Const with value [1] |
| NodeDef* const_vec_1; |
| TF_RETURN_IF_ERROR(AddConstIntNode({1}, {1}, fdef, &const_vec_1)); |
| |
| // Const with value 0 |
| NodeDef* const_0; |
| TF_RETURN_IF_ERROR(AddConstIntNode({0}, {}, fdef, &const_0)); |
| |
| // Const with value [-1, new_batch_dim] |
| NodeDef* first_two_dims; |
| TF_RETURN_IF_ERROR( |
| AddConstIntNode({-1, new_batch_dim}, {2}, fdef, &first_two_dims)); |
| |
| // shape = tf.shape(arg) |
| NodeDef* shape; |
| TF_RETURN_IF_ERROR(AddShapeNode({arg, 0, dtype}, fdef, &shape)); |
| |
| // later_dimensions = tf.shape(arg)[1:] |
| NodeDef* later_dimensions; |
| TF_RETURN_IF_ERROR(AddStridedSliceNode( |
| {strings::StrCat(shape->name(), ":output"), 0, DT_INT32}, |
| {strings::StrCat(const_vec_1->name(), ":output"), 0, DT_INT32}, |
| {strings::StrCat(const_vec_0->name(), ":output"), 0, DT_INT32}, |
| {strings::StrCat(const_vec_1->name(), ":output"), 0, DT_INT32}, DT_INT32, |
| 0, 0, 1, 0, 0, fdef, &later_dimensions)); |
| |
| // new_shape = tf.concat([pack, later_dimensions], 0) |
| NodeDef* new_shape; |
| TF_RETURN_IF_ERROR(AddConcatNode( |
| {{strings::StrCat(first_two_dims->name(), ":output"), 0, DT_INT32}, |
| {strings::StrCat(later_dimensions->name(), ":output"), 0, DT_INT32}}, |
| {strings::StrCat(const_0->name(), ":output"), 0, DT_INT32}, 2, fdef, |
| &new_shape)); |
| |
| NodeDef* reshape; |
| TF_RETURN_IF_ERROR(AddReshapeNode( |
| {arg, 0, dtype}, |
| {strings::StrCat(new_shape->name(), ":output"), 0, DT_INT32}, fdef, |
| &reshape)); |
| *result = reshape->name(); |
| |
| return Status::OK(); |
| } |
| |
| // def flat_map_fn(*batched_components): |
| // return tf.data.Dataset.from_tensor_slices( |
| // [tf.reshape(c, (-1, new_batch_size, ...)) |
| // for c in batched_components]) |
| Status CreateFlatMapFnWithReshape(int new_batch_dim, |
| const DataTypeVector& types, |
| FunctionDef* result) { |
| std::vector<NodeDefBuilder::NodeOut> tensor_slice_dataset_inputs; |
| |
| // For each component of the dataset, we reshape it from shape |
| // (old_batch_size, ...) to (-1, new_batch_size, ...) |
| // where new_batch_size = (old_batch_size + num_replicas - 1) // num_replicas |
| for (int i = 0; i < types.size(); ++i) { |
| auto* input_arg = function_utils::AddFunctionInput( |
| strings::StrCat("args_", i), result, types.at(i)); |
| |
| string reshape_node_name; |
| TF_RETURN_IF_ERROR(ReshapeComponent(new_batch_dim, input_arg->name(), |
| types.at(i), result, |
| &reshape_node_name)); |
| |
| tensor_slice_dataset_inputs.emplace_back( |
| strings::StrCat(reshape_node_name, ":output"), 0, types.at(i)); |
| } |
| |
| // The output_shapes attr here doesn't make a difference, since we |
| // set the output_shapes of the external FlatMap node. |
| AttrValue shapes; |
| SetUnknownShapes(types.size(), &shapes); |
| |
| NodeDef* tensor_slice_dataset = result->add_node_def(); |
| TF_RETURN_IF_ERROR(NodeDefBuilder("", "TensorSliceDataset") |
| .Input(tensor_slice_dataset_inputs) |
| .Attr("Toutput_types", types) |
| .Attr(kOutputShapesAttr, shapes) |
| .Finalize(tensor_slice_dataset)); |
| function_utils::SetUniqueFunctionNodeName("rebatch/tensor_slice_dataset", |
| result, tensor_slice_dataset); |
| |
| function_utils::AddFunctionOutputWithUniqueName( |
| "output", strings::StrCat(tensor_slice_dataset->name(), ":handle:0"), |
| result, DT_VARIANT); |
| // Because TensorSliceDataset is stateful, we set the function to stateful. |
| result->mutable_signature()->set_is_stateful(true); |
| |
| return Status::OK(); |
| } |
| |
| // We fallback to the following rewrite: |
| // ``` |
| // dataset = ...fetch_node... |
| // def fn(x): |
| // return tf.data.Dataset.from_tensor_slices( |
| // tf.reshape( |
| // x, |
| // tf.concat([[-1, old_batch_dim / num_replicas], tf.shape(x)[1:]], 0) |
| // ) |
| // ) |
| // |
| // dataset = dataset.flat_map(fn) |
| // ``` |
| Status RebatchWithFallback(const NodeDef* fetch_node, int64 num_replicas, |
| FunctionLibraryDefinition* flib, |
| MutableGraphView* graph) { |
| if (IsRetval(*fetch_node) || fetch_node->op() == kIdentityOp) { |
| // Get the last dataset in the pipeline |
| fetch_node = graph_utils::GetInputNode(*fetch_node, *graph, 0); |
| } |
| |
| // Note: Here, we are conservative with only using the fallback when |
| // the output_shapes attr has the 0th dimension defined for every component. |
| // This because the flat_map_fn will fail if the batch does not divide evenly |
| // because of the use of the "Reshape" op. This ensures that the error is |
| // surfaced correctly. |
| AttrValue output_shapes; |
| if (!fetch_node->attr().contains(kOutputShapesAttr)) { |
| return errors::InvalidArgument( |
| "Cannot use rebatching fallback without output_shapes attr. Node: ", |
| fetch_node->name(), " Op: ", fetch_node->op()); |
| } else { |
| output_shapes = fetch_node->attr().at(kOutputShapesAttr); |
| } |
| int batch_dim; |
| TF_RETURN_IF_ERROR(GetBatchDim(output_shapes, &batch_dim)); |
| if (batch_dim % num_replicas != 0) { |
| return errors::InvalidArgument( |
| "Cannot use rebatching fallback when batch dimension doesn't divide " |
| "num_replicas evenly."); |
| } |
| |
| // Create the flat map fn |
| FunctionDef flat_map_fn; |
| FunctionDefLibrary lib = flib->ToProto(); |
| graph_utils::SetUniqueGraphFunctionName("rebatch/flat_map_fn", &lib, |
| &flat_map_fn); |
| |
| // Get types of input arguments from the output types of the final dataset. |
| DataTypeVector output_types; |
| TF_RETURN_IF_ERROR( |
| graph_utils::GetDatasetOutputTypesAttr(*fetch_node, &output_types)); |
| TF_RETURN_IF_ERROR(CreateFlatMapFnWithReshape(batch_dim / num_replicas, |
| output_types, &flat_map_fn)); |
| |
| NodeDef* flat_map_node; |
| TF_RETURN_IF_ERROR(AddFlatMapNode(strings::StrCat(fetch_node->name(), ":0"), |
| {}, {}, flat_map_fn, output_shapes, |
| output_types, flib, graph, &flat_map_node)); |
| TF_RETURN_IF_ERROR( |
| UpdateOutputShapes(flat_map_node->name(), num_replicas, graph)); |
| |
| TF_RETURN_IF_ERROR( |
| graph->UpdateFanouts(fetch_node->name(), flat_map_node->name())); |
| |
| return Status::OK(); |
| } |
| |
| // Helper function that given a GrapplerItem generates a mutated graph def |
| // with the batch size changed. The GrapplerItem could be generated from the |
| // main graph or could be a function graph. |
| Status OptimizeGraph(const GrapplerItem& item, int64 num_replicas, |
| bool use_fallback, GraphDef* output) { |
| *output = item.graph; |
| MutableGraphView graph(output); |
| |
| FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library()); |
| |
| NodeDef* sink_node; |
| TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node)); |
| |
| Status s = RecursivelyHandleOp(*sink_node, num_replicas, use_fallback, &flib, |
| &graph); |
| if (!s.ok()) { |
| if (use_fallback) { |
| VLOG(1) << "Failed to rebatch by rewriting the batch transformation (" |
| << s << "). Using a fallback method instead."; |
| // If RecursivelyHandleOp fails, we reset `graph` to use the original, |
| // graph, since that function may have mutated `graph`. |
| *output = item.graph; |
| graph = MutableGraphView(output); |
| TF_RETURN_IF_ERROR( |
| RebatchWithFallback(sink_node, num_replicas, &flib, &graph)); |
| } else { |
| // Return the error |
| return s; |
| } |
| } |
| *output->mutable_library() = flib.ToProto(); |
| return Status::OK(); |
| } |
| |
| } // anonymous namespace |
| |
| Status RebatchOptimizer::OptimizeAndCollectStats(Cluster* cluster, |
| const GrapplerItem& item, |
| GraphDef* output, |
| OptimizationStats* stats) { |
| *output = item.graph; |
| MutableGraphView graph(output); |
| |
| TF_RETURN_IF_ERROR(OptimizeGraph(item, num_replicas_, use_fallback_, output)); |
| stats->num_changes++; |
| return Status::OK(); |
| } |
| |
| void RebatchOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, |
| const GraphDef& optimize_output, |
| double result) {} |
| |
| REGISTER_GRAPH_OPTIMIZER_AS(RebatchOptimizer, "tf_data_rebatcher"); |
| |
| } // namespace grappler |
| } // namespace tensorflow |