| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h" |
| |
| #include <fstream> |
| #include <memory> |
| #include <unordered_map> |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/strings/str_format.h" |
| #include "tensorflow/core/framework/attr_value_util.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/grappler/clusters/cluster.h" |
| #include "tensorflow/core/grappler/costs/virtual_placer.h" |
| #include "tensorflow/core/grappler/devices.h" |
| #include "tensorflow/core/grappler/grappler_item.h" |
| #include "tensorflow/core/grappler/mutable_graph_view.h" |
| #include "tensorflow/core/grappler/op_types.h" |
| #include "tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h" |
| #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" |
| #include "tensorflow/core/grappler/utils.h" |
| #include "tensorflow/core/lib/io/path.h" |
| #include "tensorflow/core/lib/strings/numbers.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/status.h" |
| #include "tensorflow/core/util/env_var.h" |
| |
| namespace tensorflow { |
| namespace grappler { |
| namespace { |
| |
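| // Returns true if the TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU |
| // environment variable is set to a true value (e.g. "true" or "1"). When |
| // enabled, GetDevices() below injects a virtual GPU device so that this |
| // rewrite can be exercised on machines without a physical GPU. |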
| bool ShouldSimulateGpu() { |
| bool ret = false; |
| TF_CHECK_OK( |
| ReadBoolFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU", |
| /*default_val=*/false, &ret)); |
| return ret; |
| } |
| |
| #if GOOGLE_CUDA |
| const std::pair<int, int> kMinGPUArch = {7, 0}; |
| #else |
| const std::pair<int, int> kMinGPUArch = {0, 0}; |
| #endif |
| |
| const char kSuffix[] = "AutoMixedPrecision"; |
| const char kCastToFp16[] = "CastToFp16"; |
| const char kCastToBf16[] = "CastToBf16"; |
| const char kCastToFp32[] = "CastToFp32"; |
| |
| #if GOOGLE_CUDA |
| // Returns the GPU architecture (compute capability) as a (major, minor) pair. |
| std::pair<int, int> GetDeviceGPUArch( |
| const DeviceProperties& device_properties) { |
| if (device_properties.type() != "GPU") return {0, 0}; |
| string arch_str = device_properties.environment().at("architecture"); |
| std::vector<string> split_arch_str = str_util::Split(arch_str, '.'); |
| if (split_arch_str.empty()) { |
| return {0, 0}; |
| } |
| |
| int major, minor; |
| if (!strings::safe_strto32(split_arch_str[0], &major)) { |
| return {0, 0}; |
| } |
| |
| if (split_arch_str.size() > 1) { |
| if (strings::safe_strto32(split_arch_str[1], &minor)) { |
| return {major, minor}; |
| } else { |
| return {0, 0}; |
| } |
| } else { |
| return {major, 0}; |
| } |
| } |
| #endif |
| |
| // Returns true if the device supports fast FP16 math. |
| // For CUDA, returns true if the GPU's compute capability is >= kMinGPUArch. |
| // For ROCm, returns true if the gfx arch string of the detected AMD GPU is in |
| // the list of FP16-capable architectures. On builds without GPU support, |
| // returns true only when GPU simulation is enabled. |
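| // Example: a compute capability of 7.5 (Turing) satisfies kMinGPUArch = {7, 0} |
| // because std::pair comparison is lexicographic: {7, 5} >= {7, 0}. |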
| bool HasFastFP16Support(const DeviceProperties& props) { |
| #if GOOGLE_CUDA |
| return GetDeviceGPUArch(props) >= kMinGPUArch; |
| #elif TENSORFLOW_USE_ROCM |
| absl::flat_hash_set<std::string> FP16SupportedDevices = {{"gfx906"}, |
| {"gfx908"}}; |
| std::string gcnArchName = props.environment().at("architecture"); |
| std::vector<std::string> gpu_arch = absl::StrSplit(gcnArchName, ":"); |
| return !gpu_arch.empty() && FP16SupportedDevices.contains(gpu_arch[0]); |
| #endif |
| return ShouldSimulateGpu(); |
| } |
| |
| // Instances of this class represent unique type attribute identifiers within a |
| // node. It handles regular type attributes, list type attributes (where |
| // type_index is set to the index in the type list), and fixed types. |
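| // Example: TypeAttrId("T") identifies a plain `T: type` attr; |
| // TypeAttrId("Tcomponents", 2) identifies the third entry of a |
| // `Tcomponents: list(type)` attr; and TypeAttrId(DT_INT32) represents an |
| // input/output whose type is fixed to int32 in the op definition. |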
| struct TypeAttrId { |
| static constexpr int kSingleType = -1; |
| |
| explicit TypeAttrId(const string& _attr_name, int _type_index = kSingleType) |
| : attr_name(_attr_name), |
| type_index(_type_index), |
| fixed_type(DT_INVALID) {} |
| |
| explicit TypeAttrId(DataType _fixed_type) |
| : attr_name(), type_index(kSingleType), fixed_type(_fixed_type) {} |
| |
| bool operator==(const TypeAttrId& other) const { |
| return attr_name == other.attr_name && type_index == other.type_index && |
| fixed_type == other.fixed_type; |
| } |
| |
| bool operator<(const TypeAttrId& other) const { |
| return std::make_tuple(attr_name, type_index, fixed_type) < |
| std::make_tuple(other.attr_name, other.type_index, other.fixed_type); |
| } |
| |
| template <typename H> |
| friend H AbslHashValue(H h, const TypeAttrId& ta) { |
| return H::combine(std::move(h), ta.attr_name, ta.type_index, ta.fixed_type); |
| } |
| |
| string DebugString() const { |
| if (!attr_name.empty()) { |
| if (type_index == kSingleType) { |
| return attr_name; |
| } else { |
| return strings::StrCat(attr_name, "[", type_index, "]"); |
| } |
| } else { |
| return tensorflow::DataTypeString(fixed_type); |
| } |
| } |
| |
| string attr_name; |
| // If attr_name is a list(type), this is the index into the list. Otherwise |
| // this is kSingleType. |
| int type_index; |
| DataType fixed_type; |
| }; |
| |
| // Returns the data type of the given type attribute, or DT_INVALID if the type |
| // attribute is invalid. |
| DataType GetDataType(const NodeDef& node, const TypeAttrId& type_attr) { |
| if (type_attr.attr_name.empty()) { |
| return type_attr.fixed_type; |
| } |
| if (!node.attr().count(type_attr.attr_name)) { |
| return DT_INVALID; |
| } |
| const AttrValue& attr_value = node.attr().at(type_attr.attr_name); |
| if (type_attr.type_index == TypeAttrId::kSingleType) { |
| return attr_value.type(); |
| } else { |
| if (type_attr.type_index < 0 || |
| type_attr.type_index >= attr_value.list().type_size()) { |
| return DT_INVALID; |
| } |
| return attr_value.list().type(type_attr.type_index); |
| } |
| } |
| |
| // Sets the data type of the given type attribute. Returns false if the type |
| // attribute is invalid, otherwise true. |
| bool SetDataType(NodeDef* node, const TypeAttrId& type_attr, DataType type) { |
| if (type_attr.attr_name.empty() || !node->attr().count(type_attr.attr_name)) { |
| return false; |
| } |
| AttrValue& attr_value = node->mutable_attr()->at(type_attr.attr_name); |
| if (type_attr.type_index == TypeAttrId::kSingleType) { |
| attr_value.set_type(type); |
| } else { |
| if (type_attr.type_index < 0 || |
| type_attr.type_index >= attr_value.list().type_size()) { |
| return false; |
| } |
| attr_value.mutable_list()->set_type(type_attr.type_index, type); |
| } |
| return true; |
| } |
| |
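| // Returns the (arg_idx, type_idx) pairs contributed by a single ArgDef of the |
| // given node: one pair per list element for a list(type) argument, one pair |
| // (with type_idx = -1) repeated `number_attr` times for a numbered argument, |
| // and a single pair with type_idx = -1 otherwise. |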
| std::vector<std::pair<int, int>> ArgDefIndexes(const NodeDef& node, int arg_idx, |
| const OpDef::ArgDef& arg_def) { |
| std::vector<std::pair<int, int>> argdef_inds; |
| if (!arg_def.type_list_attr().empty()) { |
| int num_types = node.attr().at(arg_def.type_list_attr()).list().type_size(); |
| for (int type_idx = 0; type_idx < num_types; ++type_idx) { |
| argdef_inds.push_back({arg_idx, type_idx}); |
| } |
| } else { |
| int num_repeat = 1; |
| if (node.attr().count(arg_def.number_attr())) { |
| num_repeat = node.attr().at(arg_def.number_attr()).i(); |
| } |
| argdef_inds.insert(argdef_inds.end(), num_repeat, {arg_idx, -1}); |
| } |
| return argdef_inds; |
| } |
| |
| // Returns a pair (arg_index, type_index) for each input to the node, where |
| // arg_index is the index of the input_arg in op_def and type_index is the index |
| // of the type in type_list_attr (only defined for list arguments). |
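| // Example (hypothetical op): for input args `a: T`, `b: N * U`, `c: Tlist` |
| // with N == 2 and Tlist a list(type) attr holding 3 types, this returns |
| // {{0,-1}, {1,-1}, {1,-1}, {2,0}, {2,1}, {2,2}}. OutputPortArgDefIndexes |
| // below is the analogous mapping for outputs. |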
| std::vector<std::pair<int, int>> InputPortArgDefIndexes(const NodeDef& node, |
| const OpDef& op_def) { |
| std::vector<std::pair<int, int>> argdef_inds; |
| argdef_inds.reserve(op_def.input_arg_size()); // Final size may differ. |
| for (int arg_idx = 0; arg_idx < op_def.input_arg_size(); ++arg_idx) { |
| const OpDef::ArgDef& arg_def = op_def.input_arg(arg_idx); |
| auto arg_results = ArgDefIndexes(node, arg_idx, arg_def); |
| argdef_inds.insert(argdef_inds.end(), arg_results.begin(), |
| arg_results.end()); |
| } |
| return argdef_inds; |
| } |
| |
| // Returns a pair (arg_index, type_index) for each output to the node, where |
| // arg_index is the index of the output_arg in op_def and type_index is the |
| // index of the type in type_list_attr (only defined for list arguments). |
| std::vector<std::pair<int, int>> OutputPortArgDefIndexes(const NodeDef& node, |
| const OpDef& op_def) { |
| std::vector<std::pair<int, int>> argdef_inds; |
| argdef_inds.reserve(op_def.output_arg_size()); // Final size may differ. |
| for (int arg_idx = 0; arg_idx < op_def.output_arg_size(); ++arg_idx) { |
| const OpDef::ArgDef& arg_def = op_def.output_arg(arg_idx); |
| auto arg_results = ArgDefIndexes(node, arg_idx, arg_def); |
| argdef_inds.insert(argdef_inds.end(), arg_results.begin(), |
| arg_results.end()); |
| } |
| return argdef_inds; |
| } |
| |
| TypeAttrId GetTypeAttrId(const OpDef::ArgDef& arg_def, int arg_type_index) { |
| if (!arg_def.type_list_attr().empty()) { |
| return TypeAttrId(arg_def.type_list_attr(), arg_type_index); |
| } else if (!arg_def.type_attr().empty()) { |
| return TypeAttrId(arg_def.type_attr()); |
| } else { |
| return TypeAttrId(arg_def.type()); |
| } |
| } |
| |
| std::vector<int> NonControlInputs(const NodeDef& node) { |
| std::vector<int> pos; |
| for (int i = 0; i < node.input_size(); i++) { |
| if (!IsControlInput(node.input(i))) { |
| pos.push_back(i); |
| } |
| } |
| return pos; |
| } |
| |
| // A utility class to lookup node type attributes and type attribute <-> |
| // input/output port mappings. |
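| // Usage sketch (illustrative): |
| //   NodeTypeAttrMap node_type_map; |
| //   TF_RETURN_IF_ERROR(node_type_map.Init(graph)); |
| //   for (const TypeAttrId& ta : node_type_map.GetTypeAttrs(node)) { |
| //     DataType dtype = GetDataType(node, ta); |
| //     ... |
| //   } |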
| class NodeTypeAttrMap { |
| public: |
| NodeTypeAttrMap() {} |
| |
| explicit NodeTypeAttrMap(const GraphDef& graph) { TF_CHECK_OK(Init(graph)); } |
| |
| Status Init(const GraphDef& graph) { |
| if (graph_ != nullptr) { |
| return errors::InvalidArgument("NodeTypeAttrMap is already initialized."); |
| } |
| graph_ = &graph; |
| function_library_ = std::make_unique<FunctionLibraryDefinition>( |
| OpRegistry::Global(), graph.library()); |
| for (const NodeDef& node : graph.node()) { |
| TF_RETURN_IF_ERROR(AddNode(node)); |
| } |
| return OkStatus(); |
| } |
| |
| bool is_initialized() const { return graph_ != nullptr; } |
| |
| // Returns the set of all type attributes in the given node. |
| absl::flat_hash_set<TypeAttrId> GetTypeAttrs(const NodeDef& node) const { |
| DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized"; |
| absl::flat_hash_set<TypeAttrId> type_attrs; |
| const auto iter = type2io_.find(&node); |
| CHECK(iter != type2io_.end()); // Crash Ok |
| for (const auto& key_value : iter->second) { |
| type_attrs.insert(key_value.first); |
| } |
| return type_attrs; |
| } |
| |
| const absl::flat_hash_set<int>& GetInputPorts( |
| const NodeDef& node, const TypeAttrId& type_attr) const { |
| DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized"; |
| return type2io_.at(&node).at(type_attr).first; |
| } |
| |
| const absl::flat_hash_set<int>& GetOutputPorts( |
| const NodeDef& node, const TypeAttrId& type_attr) const { |
| DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized"; |
| return type2io_.at(&node).at(type_attr).second; |
| } |
| |
| TypeAttrId GetInputTypeAttr(const NodeDef& node, int port) const { |
| DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized"; |
| const auto iter = io2type_.find(&node); |
| DCHECK(iter != io2type_.end()) |
| << "Node " << node.name() << " doesn't exist in the graph"; |
| const auto& type_vec = io2type_.at(&node).first; |
| CHECK_GE(port, 0); // Crash Ok |
| CHECK_LT(port, type_vec.size()); // Crash Ok |
| return type_vec[port]; |
| } |
| |
| TypeAttrId GetOutputTypeAttr(const NodeDef& node, int port) const { |
| DCHECK(is_initialized()) << "NodeTypeAttrMap is not initialized"; |
| const auto& type_vec = io2type_.at(&node).second; |
| CHECK_GE(port, 0); // Crash Ok |
| CHECK_LT(port, type_vec.size()); // Crash Ok |
| return type_vec[port]; |
| } |
| |
| private: |
| Status AddNode(const NodeDef& node) { |
| const OpDef* op_def_ptr = nullptr; |
| TF_RETURN_IF_ERROR(function_library_->LookUpOpDef(node.op(), &op_def_ptr)); |
| const OpDef& op_def = *op_def_ptr; |
| auto& type2io_entry = type2io_[&node]; |
| auto& io2type_entry = io2type_[&node]; |
| auto input_arg_inds = InputPortArgDefIndexes(node, op_def); |
| const int num_non_control_inputs = NonControlInputs(node).size(); |
| if (num_non_control_inputs != static_cast<int>(input_arg_inds.size())) { |
| return errors::InvalidArgument( |
| "Expected ", node.op(), " node ", node.name(), " to have ", |
| input_arg_inds.size(), " non-control input(s), but got ", |
| num_non_control_inputs); |
| } |
| // Note that the mappings generated here include inputs/outputs with fixed |
| // types. This makes the mappings complete (all inputs and outputs are |
| // included), and allows the graph rewriter to propagate deny paint |
| // from/through ops with fixed types. |
| io2type_entry.first.reserve(input_arg_inds.size()); |
| for (int i = 0; i < static_cast<int>(input_arg_inds.size()); ++i) { |
| const auto& arg_inds = input_arg_inds[i]; |
| const OpDef::ArgDef& arg_def = op_def.input_arg(arg_inds.first); |
| TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second); |
| if (!type_attr.attr_name.empty() && |
| !node.attr().count(type_attr.attr_name)) { |
| return errors::InvalidArgument("Type attribute ", type_attr.attr_name, |
| " is not present in node ", node.name()); |
| } |
| type2io_entry[type_attr].first.insert(i); |
| io2type_entry.first.push_back(type_attr); |
| } |
| |
| auto output_arg_inds = OutputPortArgDefIndexes(node, op_def); |
| io2type_entry.second.reserve(output_arg_inds.size()); |
| for (int i = 0; i < static_cast<int>(output_arg_inds.size()); ++i) { |
| const auto& arg_inds = output_arg_inds[i]; |
| const OpDef::ArgDef& arg_def = op_def.output_arg(arg_inds.first); |
| TypeAttrId type_attr = GetTypeAttrId(arg_def, arg_inds.second); |
| if (!type_attr.attr_name.empty() && |
| !node.attr().count(type_attr.attr_name)) { |
| return errors::InvalidArgument("Type attribute ", type_attr.attr_name, |
| " is not present in node ", node.name()); |
| } |
| type2io_entry[type_attr].second.insert(i); |
| io2type_entry.second.push_back(type_attr); |
| } |
| |
| // Also ensure that type attributes that aren't associated with any inputs |
| // or outputs (e.g., StackV2's elem_type) are added to the map. |
| for (const auto& attr : node.attr()) { |
| const string& attr_name = attr.first; |
| if (!attr_name.empty() && attr_name[0] == '_') continue; |
| const AttrValue& attr_value = attr.second; |
| const OpDef::AttrDef* attr_def = FindAttr(attr_name, op_def); |
| if (!attr_def) { |
| return errors::InvalidArgument("AttrDef not found for attribute ", |
| attr_name, " of node ", node.name()); |
| } |
| if (attr_def->type() == "type") { |
| type2io_entry[TypeAttrId(attr_name)]; |
| } else if (attr_def->type() == "list(type)") { |
| for (int i = 0; i < attr_value.list().type_size(); ++i) { |
| type2io_entry[TypeAttrId(attr_name, i)]; |
| } |
| } |
| } |
| return OkStatus(); |
| } |
| |
| // WARN: `graph_` must outlive this object (node pointers must remain valid). |
| const GraphDef* graph_ = nullptr; // do not own |
| std::unique_ptr<FunctionLibraryDefinition> function_library_; |
| |
| typedef absl::flat_hash_set<int> IntSet; |
| // Maps a type attr id -> (input port set, output port set) |
| typedef absl::flat_hash_map<TypeAttrId, std::pair<IntSet, IntSet>> Type2IOMap; |
| // Maps a node -> type attr mapping |
| absl::flat_hash_map<const NodeDef*, Type2IOMap> type2io_; |
| // Maps a port -> type attr id |
| typedef std::vector<TypeAttrId> TypeAttrIdVec; |
| // Maps a node -> (input port mapping, output port mapping) |
| absl::flat_hash_map<const NodeDef*, std::pair<TypeAttrIdVec, TypeAttrIdVec>> |
| io2type_; |
| }; |
| |
| struct NodeTypeId { |
| NodeTypeId(const NodeDef* _node, const TypeAttrId& _type_attr) |
| : node(_node), type_attr(_type_attr) {} |
| |
| const NodeDef* node; |
| TypeAttrId type_attr; |
| |
| bool operator==(const NodeTypeId& other) const { |
| return node == other.node && type_attr == other.type_attr; |
| } |
| |
| template <typename H> |
| friend H AbslHashValue(H h, const NodeTypeId& nt) { |
| return H::combine(std::move(h), nt.node, nt.type_attr); |
| } |
| }; |
| |
| struct NodeTypeIdEdge { |
| NodeTypeIdEdge(const NodeTypeId& _src, const NodeTypeId& _dst) |
| : src(_src), dst(_dst) {} |
| NodeTypeId src; |
| NodeTypeId dst; |
| }; |
| |
| // TODO(benbarsdell): Investigate whether the existing GraphTopologyView can be |
| // used instead of this modified version. |
| // This is just like GraphTopologyView but with (NodeDef, TypeAttrId) pairs as |
| // the vertices instead of just NodeDef. |
| // For example, if node A has output A:0 with TypeAttrId 'T', and node B has |
| // input B:0 with TypeAttrId 'U', and input B:0 connects to output A:0, there |
| // will be an edge from (A, T) to (B, U). |
| class GraphTypeTopologyView { |
| public: |
| GraphTypeTopologyView() = default; |
| explicit GraphTypeTopologyView(bool skip_invalid_edges) |
| : skip_invalid_edges_(skip_invalid_edges) {} |
| |
| // Initialize graph topology view from the graph. It's possible to pass |
| // additional edges that do not exist in a graph, but must be respected when |
| // computing graph topology. Example: the TensorFlow runtime allows concurrent |
| // execution of dequeue/enqueue ops from the same queue resource, but we might |
| // want to enforce ordering between them for the purpose of graph analysis. |
| Status InitializeFromGraph(const GraphDef& graph, |
| const NodeTypeAttrMap& node_type_map); |
| |
| Status AddEphemeralEdges(absl::Span<const NodeTypeIdEdge> ephemeral_edges); |
| |
| bool is_initialized() const { return graph_ != nullptr; } |
| int num_nodes() const { return num_nodes_; } |
| const GraphDef* graph() const { return graph_; } |
| |
| // Returns true iff the node exists in the underlying graph. |
| bool HasNode(absl::string_view node_name, const TypeAttrId& type_attr) const; |
| |
| // Finds a node by name or returns `nullptr` if it's not in the graph. |
| const NodeTypeId* GetNode(absl::string_view node_name, |
| const TypeAttrId& type_attr) const; |
| // Returns a node corresponding to the given node index. |
| const NodeTypeId* GetNode(int node_idx) const; |
| |
| // Returns a node index for the given node name, if the name exists in the |
| // underlying graph. Otherwise returns empty optional. |
| const absl::optional<int> GetNodeIndex(absl::string_view node_name, |
| const TypeAttrId& type_attr) const; |
| // Returns a node index for the given node, if the node belongs to the |
| // underlying graph. Otherwise returns empty optional. |
| const absl::optional<int> GetNodeIndex(const NodeTypeId& node) const; |
| |
| // Returns all the node indexes that are in the direct fanin of the given |
| // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector. |
| const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const; |
| // Returns all the node indexes that are in the direct fanout of the given |
| // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector. |
| const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const; |
| |
| private: |
| // The key type used to uniquely identify a type attribute on a node. |
| struct NodeTypeKey : public std::pair<absl::string_view, TypeAttrId> { |
| typedef std::pair<absl::string_view, TypeAttrId> Base; |
| |
| // Inherit the set of constructors. |
| using Base::pair; |
| |
| template <typename H> |
| friend H AbslHashValue(H h, const NodeTypeKey& nt) { |
| return H::combine(std::move(h), nt.first, nt.second); |
| } |
| }; |
| |
| // If true, all invalid edges and inputs (src, dst, or input node not found in |
| // the graph) will be skipped; otherwise initialization will fail with an |
| // error. |
| bool skip_invalid_edges_ = false; |
| |
| // WARN: `graph_` must outlive this object and graph nodes must not be |
| // destructed, because node names captured with absl::string_view. |
| const GraphDef* graph_ = nullptr; // do not own |
| int num_nodes_ = 0; |
| std::vector<NodeTypeId> node_type_attrs_; |
| absl::flat_hash_map<absl::string_view, int> node_name_to_index_; |
| absl::flat_hash_map<NodeTypeKey, int> node_type_name_to_index_; |
| |
| std::vector<absl::InlinedVector<int, 4>> fanins_; |
| std::vector<absl::InlinedVector<int, 2>> fanouts_; |
| |
| // We need a valid reference to return from GetFanin/GetFanout if the |
| // `node_idx` argument is outside of the [0, num_nodes_) range. |
| absl::InlinedVector<int, 4> empty_fanin_; |
| absl::InlinedVector<int, 2> empty_fanout_; |
| }; |
| |
| template <typename T> |
| inline void SortAndRemoveDuplicates(T* v) { |
| std::sort(v->begin(), v->end()); |
| v->erase(std::unique(v->begin(), v->end()), v->end()); |
| } |
| |
| Status GraphTypeTopologyView::InitializeFromGraph( |
| const GraphDef& graph, const NodeTypeAttrMap& node_type_map) { |
| if (graph_ != nullptr) { |
| return errors::InvalidArgument( |
| "GraphTypeTopologyView is already initialized."); |
| } |
| |
| graph_ = &graph; |
| int num_nodedefs = graph.node_size(); |
| node_name_to_index_.rehash(num_nodedefs); |
| |
| // Build maps from name to index. |
| node_type_attrs_.reserve(num_nodedefs); // Only approximate. |
| node_type_name_to_index_.rehash(num_nodedefs); // Only approximate. |
| for (int node_idx = 0; node_idx < num_nodedefs; ++node_idx) { |
| const NodeDef& node = graph.node(node_idx); |
| node_name_to_index_.emplace(node.name(), node_idx); |
| |
| for (const TypeAttrId& type_attr : node_type_map.GetTypeAttrs(node)) { |
| int node_type_idx = node_type_attrs_.size(); |
| node_type_name_to_index_.emplace(NodeTypeKey(node.name(), type_attr), |
| node_type_idx); |
| node_type_attrs_.emplace_back(&node, type_attr); |
| } |
| } |
| num_nodes_ = node_type_attrs_.size(); |
| fanins_.resize(num_nodes_); |
| fanouts_.resize(num_nodes_); |
| |
| // Add graph edges to the adjacency lists. |
| for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) { |
| const NodeTypeId& node_type = node_type_attrs_.at(node_type_idx); |
| auto input_ports = |
| node_type_map.GetInputPorts(*node_type.node, node_type.type_attr); |
| fanins_[node_type_idx].reserve(input_ports.size()); |
| for (int port : input_ports) { |
| const string& input = node_type.node->input(port); |
| TensorId tensor = ParseTensorName(input); |
| const auto it = node_name_to_index_.find(tensor.node()); |
| const bool valid_input = it != node_name_to_index_.end(); |
| |
| if (!valid_input) { |
| const string error_message = absl::StrCat( |
| "Non-existent input ", input, " in node ", node_type.node->name()); |
| if (skip_invalid_edges_) { |
| VLOG(3) << "Skip error: " << error_message; |
| } else { |
| return errors::InvalidArgument(error_message); |
| } |
| } |
| |
| if (valid_input) { |
| const int input_idx = it->second; |
| const NodeDef& input_node = graph_->node(input_idx); |
| TypeAttrId input_type_attr = |
| node_type_map.GetOutputTypeAttr(input_node, tensor.index()); |
| const auto it2 = node_type_name_to_index_.find( |
| NodeTypeKey(input_node.name(), input_type_attr)); |
| if (it2 == node_type_name_to_index_.end()) { |
| if (!skip_invalid_edges_) { |
| return errors::InvalidArgument("Did not find type attr ", |
| input_type_attr.DebugString(), |
| " in node ", input_node.name()); |
| } |
| continue; |
| } |
| int input_node_type_idx = it2->second; |
| fanins_[node_type_idx].push_back(input_node_type_idx); |
| fanouts_[input_node_type_idx].push_back(node_type_idx); |
| } |
| } |
| |
| // Dedup the input list while it's still hot in cache. |
| SortAndRemoveDuplicates(&fanins_[node_type_idx]); |
| } |
| |
| // Dedup outputs for all the graph nodes. |
| for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) { |
| SortAndRemoveDuplicates(&fanouts_[node_type_idx]); |
| } |
| |
| return OkStatus(); |
| } |
| |
| Status GraphTypeTopologyView::AddEphemeralEdges( |
| absl::Span<const NodeTypeIdEdge> ephemeral_edges) { |
| // Add ephemeral edges to the adjacency lists. |
| for (const NodeTypeIdEdge& edge : ephemeral_edges) { |
| const auto src = node_name_to_index_.find(edge.src.node->name()); |
| const bool valid_src = src != node_name_to_index_.end(); |
| |
| if (!valid_src) { |
| const string error_message = |
| absl::StrCat("Non-existent src node: ", edge.src.node->name()); |
| if (skip_invalid_edges_) { |
| VLOG(0) << "Skip error: " << error_message; |
| } else { |
| return errors::InvalidArgument(error_message); |
| } |
| } |
| |
| const auto dst = node_name_to_index_.find(edge.dst.node->name()); |
| const bool valid_dst = dst != node_name_to_index_.end(); |
| |
| if (!valid_dst) { |
| const string error_message = |
| absl::StrCat("Non-existent dst node: ", edge.dst.node->name()); |
| if (skip_invalid_edges_) { |
| VLOG(0) << "Skip error: " << error_message; |
| } else { |
| return errors::InvalidArgument(error_message); |
| } |
| } |
| |
| if (valid_dst && valid_src) { |
| // TODO(benbarsdell): Check for failure. |
| int src_node_type_idx = node_type_name_to_index_.at( |
| NodeTypeKey(edge.src.node->name(), edge.src.type_attr)); |
| int dst_node_type_idx = node_type_name_to_index_.at( |
| NodeTypeKey(edge.dst.node->name(), edge.dst.type_attr)); |
| fanins_[dst_node_type_idx].push_back(src_node_type_idx); |
| fanouts_[src_node_type_idx].push_back(dst_node_type_idx); |
| } |
| } |
| |
| // Dedup inputs and outputs for all the graph nodes. |
| for (int node_type_idx = 0; node_type_idx < num_nodes_; ++node_type_idx) { |
| SortAndRemoveDuplicates(&fanins_[node_type_idx]); |
| SortAndRemoveDuplicates(&fanouts_[node_type_idx]); |
| } |
| |
| return OkStatus(); |
| } |
| |
| bool GraphTypeTopologyView::HasNode(absl::string_view node_name, |
| const TypeAttrId& type_attr) const { |
| DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized"; |
| NodeTypeKey key(node_name, type_attr); |
| const auto it = node_type_name_to_index_.find(key); |
| return it != node_type_name_to_index_.end(); |
| } |
| |
| const NodeTypeId* GraphTypeTopologyView::GetNode( |
| absl::string_view node_name, const TypeAttrId& type_attr) const { |
| DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized"; |
| NodeTypeKey key(node_name, type_attr); |
| const auto it = node_type_name_to_index_.find(key); |
| return it == node_type_name_to_index_.end() |
| ? nullptr |
| : &node_type_attrs_.at(it->second); |
| } |
| |
| const NodeTypeId* GraphTypeTopologyView::GetNode(int node_idx) const { |
| DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized"; |
| DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range"; |
| return &node_type_attrs_.at(node_idx); |
| } |
| |
| const absl::optional<int> GraphTypeTopologyView::GetNodeIndex( |
| absl::string_view node_name, const TypeAttrId& type_attr) const { |
| DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized"; |
| NodeTypeKey key(node_name, type_attr); |
| const auto it = node_type_name_to_index_.find(key); |
| DCHECK(it != node_type_name_to_index_.end()) |
| << "Node doesn't exist in the graph"; |
| return it == node_type_name_to_index_.end() ? absl::nullopt |
| : absl::make_optional(it->second); |
| } |
| |
| const absl::optional<int> GraphTypeTopologyView::GetNodeIndex( |
| const NodeTypeId& node) const { |
| return GetNodeIndex(node.node->name(), node.type_attr); |
| } |
| |
| const absl::InlinedVector<int, 4>& GraphTypeTopologyView::GetFanin( |
| int node_idx) const { |
| DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized"; |
| const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_; |
| DCHECK(is_valid_node_idx) << "node_idx is out of range"; |
| return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_; |
| } |
| |
| const absl::InlinedVector<int, 2>& GraphTypeTopologyView::GetFanout( |
| int node_idx) const { |
| DCHECK(is_initialized()) << "GraphTypeTopologyView is not initialized"; |
| const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_; |
| DCHECK(is_valid_node_idx) << "node_idx is out of range"; |
| return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_; |
| } |
| |
| enum class TypeTraversalDirection { |
| kFollowInputs, |
| kFollowOutputs, |
| kFollowInputsAndOutputs, |
| }; |
| |
| // Encapsulate DFS callbacks that will be called during the graph traversal. |
| // |
| // If non-empty, the `pre_order` and `post_order` functors will be called on |
| // each reachable node (including the `from` nodes) in pre and post order. If |
| // loops are found, the `on_back_edge` functor will be called on the |
| // corresponding back edges. Moreover, the pre and post order will assume that |
| // these back edges will be cut. |
| struct DfsTypeCallbacks { |
| DfsTypeCallbacks() = default; |
| DfsTypeCallbacks(std::function<void(int)> pre, std::function<void(int)> post, |
| std::function<void(int, int)> back_edge) |
| : pre_order(std::move(pre)), |
| post_order(std::move(post)), |
| on_back_edge(std::move(back_edge)) {} |
| |
| static DfsTypeCallbacks PreOrder(std::function<void(int)> pre) { |
| return DfsTypeCallbacks(std::move(pre), nullptr, nullptr); |
| } |
| |
| static DfsTypeCallbacks PostOrder(std::function<void(int)> post) { |
| return DfsTypeCallbacks(nullptr, std::move(post), nullptr); |
| } |
| |
| std::function<void(int)> pre_order; |
| std::function<void(int)> post_order; |
| std::function<void(int, int)> on_back_edge; |
| }; |
| |
| // Encapsulate DFS predicates for traversing the graph. |
| // |
| // The `enter` predicate decides if traversal should enter the node, and the |
| // `advance` predicate decides if the traversal should follow inputs/outputs |
| // from the node. |
| // |
| // If predicates are empty (default initialized), it's assumed that we can enter |
| // into any node and advance from any node respectively. |
| struct DfsTypePredicates { |
| DfsTypePredicates() = default; |
| DfsTypePredicates(std::function<bool(int)> enter, |
| std::function<bool(int)> advance) |
| : enter(std::move(enter)), advance(std::move(advance)) {} |
| |
| static DfsTypePredicates Enter(std::function<bool(int)> enter) { |
| return DfsTypePredicates(std::move(enter), nullptr); |
| } |
| |
| static DfsTypePredicates Advance(std::function<bool(int)> advance) { |
| return DfsTypePredicates(nullptr, std::move(advance)); |
| } |
| |
| std::function<bool(int)> enter; |
| std::function<bool(int)> advance; |
| }; |
| |
| struct DfsStackElem { |
| DfsStackElem(int node, bool children_visited, int src) |
| : node(node), children_visited(children_visited), src(src) {} |
| explicit DfsStackElem(int node) : DfsStackElem(node, false, -1) {} |
| |
| // Index of the node in the graph ∈ [0, num_nodes). |
| int node; |
| // True if all of this node's input/output nodes have been visited (i.e., |
| // pushed onto the stack). |
| bool children_visited; |
| // Index of the node in the graph from which we entered `node`. |
| int src; |
| }; |
| |
| enum class NodeState { kNotVisited, kVisiting, kDone }; |
| |
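| // Traverses the type-attr graph depth-first, starting from the `from` nodes |
| // and following edges in the given direction. Usage sketch (illustrative), |
| // collecting everything reachable downstream of `root`: |
| // |
| //   absl::flat_hash_set<int> reachable; |
| //   DfsTypeTraversal(graph_type_view, {&root}, |
| //                    TypeTraversalDirection::kFollowOutputs, |
| //                    DfsTypePredicates::Enter( |
| //                        [&](int idx) { return !reachable.contains(idx); }), |
| //                    DfsTypeCallbacks::PreOrder( |
| //                        [&](int idx) { reachable.insert(idx); })); |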
| void DfsTypeTraversal(const GraphTypeTopologyView& graph_type_view, |
| const absl::Span<const NodeTypeId* const> from, |
| const TypeTraversalDirection direction, |
| const DfsTypePredicates& predicates, |
| const DfsTypeCallbacks& callbacks) { |
| std::vector<DfsStackElem> stack; |
| stack.reserve(from.size()); |
| |
| for (const NodeTypeId* node : from) { |
| const absl::optional<int> node_idx = graph_type_view.GetNodeIndex(*node); |
| DCHECK(node_idx.has_value()) |
| << "Illegal start node: " << node->node->name(); |
| if (node_idx.has_value()) { |
| stack.emplace_back(node_idx.value()); |
| } |
| } |
| |
| absl::flat_hash_map<int, NodeState> node_state; |
| while (!stack.empty()) { |
| DfsStackElem w = stack.back(); |
| stack.pop_back(); |
| |
| NodeState& state = node_state[w.node]; |
| if (state == NodeState::kDone) continue; |
| |
| // Skip nodes that we should not enter. |
| if (predicates.enter && !predicates.enter(w.node)) { |
| state = NodeState::kDone; |
| continue; |
| } |
| |
| // We've processed all the children of this node. |
| if (w.children_visited) { |
| state = NodeState::kDone; |
| if (callbacks.post_order) { |
| callbacks.post_order(w.node); |
| } |
| continue; |
| } |
| |
| // Loop detected. |
| if (state == NodeState::kVisiting) { |
| if (callbacks.on_back_edge) { |
| callbacks.on_back_edge(w.src, w.node); |
| } |
| continue; |
| } |
| |
| state = NodeState::kVisiting; |
| if (callbacks.pre_order) { |
| callbacks.pre_order(w.node); |
| } |
| |
| // Enqueue the node again with the children_visited flag set to true. |
| stack.emplace_back(w.node, true, w.src); |
| |
| // Check if we can continue traversal from the current node. |
| if (predicates.advance && !predicates.advance(w.node)) { |
| continue; |
| } |
| |
| // Now enqueue the fanin/fanout nodes. |
| if (direction == TypeTraversalDirection::kFollowInputs || |
| direction == TypeTraversalDirection::kFollowInputsAndOutputs) { |
| for (const int fanin : graph_type_view.GetFanin(w.node)) { |
| stack.emplace_back(fanin, false, w.node); |
| } |
| } |
| if (direction == TypeTraversalDirection::kFollowOutputs || |
| direction == TypeTraversalDirection::kFollowInputsAndOutputs) { |
| for (const int fanout : graph_type_view.GetFanout(w.node)) { |
| stack.emplace_back(fanout, false, w.node); |
| } |
| } |
| } |
| } |
| |
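| // Returns the set of types listed in the attr definition's allowed_values, |
| // or all types if the attr is unconstrained. Example: if allowed_values |
| // contains DT_FLOAT (enum value 1) and DT_HALF (enum value 19), the mask is |
| // (1u << 1) | (1u << 19). |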
| DataTypeSet AllowedDataTypes(const OpDef::AttrDef& attr_def) { |
| const auto& allowed_types = attr_def.allowed_values().list().type(); |
| if (allowed_types.empty()) { |
| return AllTypes(); |
| } |
| uint32 dtype_mask = 0; |
| for (int dtype : allowed_types) { |
| dtype_mask |= 1u << dtype; |
| } |
| return DataTypeSet(dtype_mask); |
| } |
| |
| DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) { |
| if (t_attr_id.attr_name.empty()) { |
| return ToSet(t_attr_id.fixed_type); |
| } |
| const OpDef::AttrDef* attr_def = FindAttr(t_attr_id.attr_name, op_def); |
| CHECK(attr_def); // Crash Ok |
| return AllowedDataTypes(*attr_def); |
| } |
| |
| Status ValidateLists(const gtl::FlatSet<string>& allow_list, |
| const gtl::FlatSet<string>& deny_list, |
| const gtl::FlatSet<string>& infer_list, |
| const gtl::FlatSet<string>& clear_list) { |
| std::vector<gtl::FlatSet<string>> lists{allow_list, deny_list, infer_list, |
| clear_list}; |
| std::multiset<string> counts; |
| for (const auto& list : lists) { |
| counts.insert(list.begin(), list.end()); |
| } |
| bool duplicates = false; |
| for (const auto& s : counts) { |
| if (counts.count(s) > 1) { |
| duplicates = true; |
| LOG(ERROR) << "Op present in multiple lists: " << s; |
| } |
| } |
| if (duplicates) { |
| return errors::InvalidArgument("Op lists have conflicting entries"); |
| } else { |
| return OkStatus(); |
| } |
| } |
| |
| bool HasInputOrOutputRefs(const NodeDef& node) { |
| const OpDef* op_def; |
| Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); |
| if (!status.ok()) { |
| return true; |
| } |
| for (const auto& input : op_def->input_arg()) { |
| if (input.is_ref()) { |
| return true; |
| } |
| } |
| for (const auto& output : op_def->output_arg()) { |
| if (output.is_ref()) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| // See TF issue 25977 for why FP16 must not be forced on |
| // SoftmaxCrossEntropyWithLogits (SCEWL). |
| bool CanForceFP16(const NodeDef& node) { |
| return node.op() != "Const" && node.op() != "SoftmaxCrossEntropyWithLogits" && |
| !IsStateful(node) && !HasInputOrOutputRefs(node); |
| } |
| |
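| // Returns the CUDA version of the first GPU device found, encoded as an |
| // integer (e.g. 11050 for CUDA 11.5), or 0 if no GPU reports one. |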
| int GetCudaVersion( |
| const std::unordered_map<string, DeviceProperties>& devices) { |
| for (const auto& device : devices) { |
| const DeviceProperties& device_properties = device.second; |
| if (device_properties.type() == "GPU") { |
| const auto& device_env = device_properties.environment(); |
| auto it = device_env.find("cuda"); |
| if (it != device_env.end()) { |
| string cuda_version_str = it->second; |
| return std::stoi(cuda_version_str); |
| } |
| } |
| } |
| return 0; |
| } |
| |
| int GetCudnnVersion( |
| const std::unordered_map<string, DeviceProperties>& devices) { |
| for (const auto& device : devices) { |
| const DeviceProperties& device_properties = device.second; |
| if (device_properties.type() == "GPU") { |
| const auto& device_env = device_properties.environment(); |
| auto it = device_env.find("cudnn"); |
| if (it != device_env.end()) { |
| string cudnn_version_str = it->second; |
| return std::stoi(cudnn_version_str); |
| } |
| } |
| } |
| return 0; |
| } |
| |
| std::unordered_map<string, DeviceProperties> GetDevices(Cluster* cluster) { |
| if (!ShouldSimulateGpu()) { |
| return cluster->GetDevices(); |
| } |
| |
| bool has_gpu = false; |
| for (const auto& device : cluster->GetDevices()) { |
| const DeviceProperties& device_properties = device.second; |
| if (device_properties.type() == "GPU") { |
| has_gpu = true; |
| break; |
| } |
| } |
| |
| if (has_gpu) { |
| return cluster->GetDevices(); |
| } |
| |
| std::unordered_map<string, DeviceProperties> devices(cluster->GetDevices()); |
| DeviceProperties gpu_device_properties; |
| gpu_device_properties.set_type("GPU"); |
| gpu_device_properties.set_vendor("NVIDIA"); |
| gpu_device_properties.mutable_environment()->insert({"architecture", "8.0"}); |
| gpu_device_properties.mutable_environment()->insert({"cuda", "11050"}); |
| gpu_device_properties.mutable_environment()->insert({"cudnn", "8302"}); |
| devices.emplace("/job:localhost/replica:0/task:0/device:GPU:0", |
| gpu_device_properties); |
| return devices; |
| } |
| |
| class AutoMixedPrecisionImpl { |
| public: |
| // CastType indicates the type of inserted Cast op |
| // FP16: cast to float16 |
| // FP32: cast to float32 |
| // AUTO: cast to a data type that matches the required data type at fanouts |
| enum class CastType { FP16, FP32, AUTO }; |
| AutoMixedPrecisionImpl(Cluster* cluster, |
| const std::unordered_set<string>& nodes_to_preserve, |
| GraphDef* graph, string id, |
| AutoMixedPrecisionMode mode) |
| : devices_(GetDevices(cluster)), |
| virtual_placer_(devices_), |
| nodes_to_preserve_(nodes_to_preserve), |
| graph_(graph), |
| function_library_(OpRegistry::Global(), graph->library()), |
| id_(id), |
| graph_view_(graph), |
| cuda_version_(GetCudaVersion(devices_)), |
| cudnn_version_(GetCudnnVersion(devices_)), |
| num_nonvar_casts_to_f16_(0), |
| mode_(mode), |
| target_dtype_((mode_ == AutoMixedPrecisionMode::CUDA || |
| mode_ == AutoMixedPrecisionMode::CPU) |
| ? DT_HALF |
| : DT_BFLOAT16) {} |
| |
| Status Optimize(); |
| |
| private: |
| typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet; |
| |
| std::unique_ptr<AutoMixedPrecisionLists> get_mixed_precision_lists() const { |
| switch (mode_) { |
| case AutoMixedPrecisionMode::CUDA: |
| return std::make_unique<AutoMixedPrecisionListsCuda>(cuda_version_, |
| cudnn_version_); |
| case AutoMixedPrecisionMode::MKL: |
| return std::make_unique<AutoMixedPrecisionListsMkl>(); |
| case AutoMixedPrecisionMode::CPU: |
| // Note: this is not a typo here. AutoMixedPrecisionListsCuda is used |
| // intentionally to make CPU and GPU have the same fp16 ops. |
| return std::make_unique<AutoMixedPrecisionListsCuda>( |
| /*cuda_version=*/10000, // Hardcode cuda and cudnn version so |
| /*cudnn_version=*/8000); // CPU emulates the same ops on GPU. |
| } |
| } |
| Status PrintDebugLogs(bool preop, size_t timestamp); |
| void LogSkippedNode(const NodeDef& node, const string& device_type) const; |
| bool MustPreserve(const NodeDef& node) const; |
| bool IsOnDevice(const NodeDef& node, const string& device_type) const; |
| bool IsOnSuitableGPUArch(const NodeDef& node) const; |
| bool ShouldProcess(const NodeDef& node) const; |
| bool NodeHasF16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const; |
| bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const; |
| void ConvertBatchNormOpsToV2(); |
| bool SupportsF16(const NodeTypeId& node_type) const; |
| bool SupportsF16DataType(const NodeTypeId& node_type) const; |
| bool IsQuantized(const NodeTypeId& node_type) const; |
| const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const; |
| bool IsSourceOrSinkOp(const string& op) const; |
| void FindFloat32TensorListOpClustersAndDenylistUnsafe( |
| std::vector<absl::flat_hash_set<const NodeDef*>>* clusters, |
| absl::flat_hash_set<int>* deny_set) const; |
| void FindTensorListImplicitFloat32Edges( |
| const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes, |
| std::vector<NodeTypeIdEdge>* implicit_fp32_edges) const; |
| void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const; |
| void RemoveAllowsetWithFp32(absl::flat_hash_set<int>* allow_set) const; |
| void PropagateDenyFwdThroughClearAndInfer( |
| absl::flat_hash_set<int>* deny_set) const; |
| void ForceColorMatchBetweenTensorListOps( |
| const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes, |
| absl::flat_hash_set<int>* allow_set, |
| absl::flat_hash_set<int>* deny_set) const; |
| void AddClearAndInferToAllowIfBetweenAllow( |
| const absl::flat_hash_set<int>& deny_set, |
| absl::flat_hash_set<int>* allow_set) const; |
| void AddInferToAllowIfFollowAllow(const absl::flat_hash_set<int>& deny_set, |
| absl::flat_hash_set<int>* allow_set) const; |
| void PropagateAllowThroughClear(const absl::flat_hash_set<int>& deny_set, |
| absl::flat_hash_set<int>* allow_set) const; |
| Status ForceColorMatchOnRecurrentEdges( |
| absl::flat_hash_set<int>* allow_set) const; |
| void MakeCastsAllowIfAllOutputsAllow( |
| absl::flat_hash_set<int>* allow_set) const; |
| NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, |
| const MutableGraphView::InputPort& dst, bool to_f16, |
| const string& device) const; |
| StatusOr<NodeDef*> InsertCastNodeAtFanout( |
| const absl::flat_hash_set<int>& allow_set, const bool src_is_allow, |
| const CastType& cast_type, MutableGraphView::OutputPort& src); |
| |
| StatusOr<DataType> GetCastToType(const NodeDef* node) const; |
| void CollectOutputPorts( |
| const TypeAttrId& type_attr, NodeDef* node, |
| std::vector<MutableGraphView::OutputPort>& output_ports) const; |
| Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& allow_set); |
| |
| std::unordered_map<string, DeviceProperties> devices_; |
| VirtualPlacer virtual_placer_; |
| std::unordered_set<string> nodes_to_preserve_; |
| GraphDef* graph_; |
| FunctionLibraryDefinition function_library_; |
| string id_; |
| MutableGraphView graph_view_; |
| int cuda_version_; |
| int cudnn_version_; |
| int num_nonvar_casts_to_f16_; |
| NodeTypeAttrMap node_type_map_; |
| GraphTypeTopologyView graph_type_view_; |
| bool force_all_fp16_; |
| bool treat_infer_as_deny_; |
| AutoMixedPrecisionMode mode_; |
| gtl::FlatSet<string> f16_allowlist_; |
| gtl::FlatSet<string> f16_denylist_; |
| gtl::FlatSet<string> f16_inferlist_; |
| gtl::FlatSet<string> f16_clearlist_; |
| absl::flat_hash_set<const NodeDef*> should_process_nodes_; |
| DataType target_dtype_; // Either DT_HALF or DT_BFLOAT16 |
| }; |
| |
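| // Builds a Cast node that converts `src`'s output to f16 or f32 on its way |
| // into `dst`. For example (illustrative names), casting output 1 of node |
| // `res` into input 0 of node `op` with target_dtype_ == DT_HALF yields a |
| // node named "res-1-op-0-CastToFp16-AutoMixedPrecision". |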
| NodeDef AutoMixedPrecisionImpl::BuildCastNode( |
| const MutableGraphView::OutputPort& src, |
| const MutableGraphView::InputPort& dst, bool to_f16, |
| const string& device) const { |
| DataType src_type = to_f16 ? DT_FLOAT : target_dtype_; |
| DataType dst_type = to_f16 ? target_dtype_ : DT_FLOAT; |
| const char* cast_string = !to_f16 ? kCastToFp32 |
| : target_dtype_ == DT_HALF ? kCastToFp16 |
| : kCastToBf16; |
| string name = |
| strings::StrCat(src.node->name(), "-", src.port_id, "-", dst.node->name(), |
| "-", dst.port_id, "-", cast_string, "-", kSuffix); |
| NodeDef node; |
| node.set_name(name); |
| node.set_op("Cast"); |
| node.set_device(device); |
| node.add_input(strings::StrCat(src.node->name(), ":", src.port_id)); |
| (*node.mutable_attr())["SrcT"].set_type(src_type); |
| (*node.mutable_attr())["DstT"].set_type(dst_type); |
| (*node.mutable_attr())["Truncate"].set_b(false); |
| return node; |
| } |
| |
| bool AutoMixedPrecisionImpl::NodeHasF16KernelForTypeAttr( |
| const NodeDef& node, TypeAttrId taid) const { |
| NodeDef node_copy(node); |
| if (node.device().empty()) { |
| string device_name = virtual_placer_.get_canonical_device_name(node); |
| node_copy.set_device(device_name); |
| } |
| if (!SetDataType(&node_copy, taid, target_dtype_)) { |
| return false; |
| } |
| return IsKernelRegisteredForNode(node_copy).ok(); |
| } |
| |
| Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) { |
| string prepend_path; |
| TF_RETURN_IF_ERROR(ReadStringFromEnvVar( |
| "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LOG_PATH", "", &prepend_path)); |
| if (prepend_path.empty()) return OkStatus(); |
| |
| string suffix = |
| strings::StrCat("_", preop ? "preop" : kSuffix, "_", id_, "_", timestamp); |
| |
| string fname = |
| io::JoinPath(prepend_path, strings::StrCat("graphdef", suffix, ".pb")); |
| std::fstream f; |
| f.open(fname.c_str(), std::fstream::out | std::fstream::binary); |
| f << graph_->SerializeAsString(); |
| f.close(); |
| LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization") |
| << " graph as binary to " << fname; |
| |
| fname = io::JoinPath(prepend_path, |
| strings::StrCat("graphdef", suffix, ".pb.txt")); |
| f.open(fname.c_str(), std::fstream::out); |
| f << graph_->DebugString(); |
| f.close(); |
| LOG(INFO) << "Saved " << (preop ? "pre-optimization" : "post-optimization") |
| << " graph as text to " << fname; |
| |
| if (!preop) { |
| fname = io::JoinPath(prepend_path, |
| strings::StrCat("paintbuckets", suffix, ".txt")); |
| f.open(fname.c_str(), std::fstream::out); |
| std::unique_ptr<AutoMixedPrecisionLists> mp_lists = |
| get_mixed_precision_lists(); |
| f << "AllowList:\n"; |
| for (const auto& x : mp_lists->AllowList()) { |
| f << x << "\n"; |
| } |
| f << "\nDenyList:\n"; |
| for (const auto& x : mp_lists->DenyList()) { |
| f << x << "\n"; |
| } |
| f << "\nInferList:\n"; |
| for (const auto& x : mp_lists->InferList()) { |
| f << x << "\n"; |
| } |
| f << "\nClearList:\n"; |
| for (const auto& x : mp_lists->ClearList()) { |
| f << x << "\n"; |
| } |
| f.close(); |
| LOG(INFO) << "Saved paint bucket info to " << fname; |
| } |
| return OkStatus(); |
| } |
| |
| void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node, |
| const string& device_type) const { |
| VLOG(2) << "Skipping " << node.op() << " node " << node.name() |
| << " because it " |
| << (MustPreserve(node) |
| ? "must be preserved" |
| : absl::StrFormat( |
| "is not on the %s, or the %s arch is not suitable", |
| device_type, device_type)); |
| } |
| |
| bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const { |
| return nodes_to_preserve_.count(node.name()); |
| } |
| |
| bool AutoMixedPrecisionImpl::IsOnDevice(const NodeDef& node, |
| const string& device_type) const { |
| string device_name; |
| if (node.device().empty()) { |
| device_name = virtual_placer_.get_canonical_device_name(node); |
| } else { |
| device_name = node.device(); |
| } |
| string device; |
| string not_used; |
| if (DeviceNameUtils::SplitDeviceName(device_name, ¬_used, &device) && |
| absl::StrContains(absl::AsciiStrToLower(device), |
| absl::AsciiStrToLower(device_type))) { |
| return true; |
| } |
| return false; |
| } |
| |
| bool AutoMixedPrecisionImpl::IsOnSuitableGPUArch(const NodeDef& node) const { |
| return HasFastFP16Support(virtual_placer_.get_device(node)); |
| } |
| |
| bool AutoMixedPrecisionImpl::ShouldProcess(const NodeDef& node) const { |
| return should_process_nodes_.count(&node); |
| } |
| |
| bool IsFloat32(const NodeTypeId& node_type) { |
| return GetDataType(*node_type.node, node_type.type_attr) == |
| DataType::DT_FLOAT; |
| } |
| |
| bool IsTensorListOp(const string& op) { |
| return absl::StrContains(op, "TensorList"); |
| } |
| |
| bool IsTensorListReaderOp(const string& op) { |
| static const gtl::FlatSet<string> tensor_list_reader_ops = { |
| "TensorListConcat", "TensorListConcatV2", "TensorListGather", |
| "TensorListGetItem", "TensorListPopBack", "TensorListStack"}; |
| return tensor_list_reader_ops.count(op); |
| } |
| |
| bool IsTensorListWriterOp(const string& op) { |
| static const gtl::FlatSet<string> tensor_list_writer_ops = { |
| "TensorListFromTensor", "TensorListPushBack", |
| "TensorListPushBackBatch", "TensorListScatter", |
| "TensorListScatterV2", "TensorListScatterIntoExistingList", |
| "TensorListSetItem", "TensorListSplit"}; |
| return tensor_list_writer_ops.count(op); |
| } |
| |
| bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const { |
| const OpDef* op_def; |
| Status status = |
| OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def); |
| if (!status.ok()) return false; |
| return AllowedDataTypes(*op_def, node_type.type_attr) |
| .Contains(target_dtype_) && |
| NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr); |
| } |
| |
| bool AutoMixedPrecisionImpl::SupportsF16DataType( |
| const NodeTypeId& node_type) const { |
| const OpDef* op_def; |
| Status status = |
| OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def); |
| if (!status.ok()) return false; |
| return AllowedDataTypes(*op_def, node_type.type_attr).Contains(target_dtype_); |
| } |
| |
| bool AutoMixedPrecisionImpl::IsQuantized(const NodeTypeId& node_type) const { |
| for (const TypeAttrId& type_attr : |
| node_type_map_.GetTypeAttrs(*node_type.node)) { |
| if (DataTypeIsQuantized(GetDataType(*node_type.node, type_attr))) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| // TODO(mconley): Make this change the node's name (to aid debugging). Need to |
| // make sure that doing this won't break anything. |
| void AutoMixedPrecisionImpl::ConvertBatchNormOpsToV2() { |
| for (int node_idx = 0; node_idx < graph_->node_size(); ++node_idx) { |
| NodeDef* node = graph_->mutable_node(node_idx); |
| if (!ShouldProcess(*node)) continue; |
| bool changed = false; |
| if (node->op() == "FusedBatchNorm") { |
| VLOG(2) << "Changing op of " << node->op() << " node " << node->name() |
| << " to FusedBatchNormV2"; |
| node->set_op("FusedBatchNormV2"); |
| changed = true; |
| } else if (node->op() == "FusedBatchNormGrad") { |
| VLOG(2) << "Changing op of " << node->op() << " node " << node->name() |
| << " to FusedBatchNormGradV2"; |
| node->set_op("FusedBatchNormGradV2"); |
| changed = true; |
| } |
| if (changed) { |
| (*node->mutable_attr())["U"].set_type(DT_FLOAT); |
| } |
| } |
| } |
| |
| // A helper function to decide whether to ignore the effect on performance when |
| // rewriting the graph. This can be useful for testing the numerical effects of |
| // reduced precision on systems that have poor mixed precision performance. |
| bool ShouldIgnorePerformance() { |
| static bool is_enabled = [] { |
| bool ret = false; |
| TF_CHECK_OK(ReadBoolFromEnvVar( |
| "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE", |
| /*default_val=*/false, &ret)); |
| return ret; |
| }(); |
| return is_enabled; |
| } |
| |
| Status AutoMixedPrecisionImpl::Optimize() { |
| string optimization_level; |
| TF_RETURN_IF_ERROR(ReadStringFromEnvVar( |
| "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level)); |
| optimization_level = absl::AsciiStrToUpper(optimization_level); |
| force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL"; |
| if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::MKL) { |
| // Many ops do not support bfloat16 on the CPU, so we disallow forcing to |
| // bfloat16. |
| return errors::InvalidArgument( |
| "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to " |
| "UNSAFE_FORCE_ALL when MKL is used"); |
| } |
| |
| treat_infer_as_deny_ = optimization_level == "TREAT_INFER_AS_DENY"; |
| VLOG(2) << "Optimization Level: " << optimization_level; |
| |
| std::unique_ptr<AutoMixedPrecisionLists> mp_lists = |
| get_mixed_precision_lists(); |
| f16_allowlist_ = mp_lists->AllowList(); |
| f16_denylist_ = mp_lists->DenyList(); |
| |
| if (treat_infer_as_deny_) { |
| for (const auto& op : mp_lists->InferList()) { |
| f16_denylist_.insert(op); |
| } |
| } else { |
| f16_inferlist_ = mp_lists->InferList(); |
| } |
| |
| f16_clearlist_ = mp_lists->ClearList(); |
| TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_denylist_, |
| f16_inferlist_, f16_clearlist_)); |
| |
| size_t timestamp = Env::Default()->NowMicros() / 1000; |
| TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp)); |
| |
| VLOG(2) << "Identifying nodes that should be processed"; |
| for (const NodeDef& node : graph_->node()) { |
| bool should_process; |
| string device_type; |
| switch (mode_) { |
| case AutoMixedPrecisionMode::CUDA: |
| device_type = DEVICE_GPU; |
| should_process = |
| !MustPreserve(node) && IsOnDevice(node, device_type) && |
| (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node)); |
| break; |
| case AutoMixedPrecisionMode::MKL: |
| case AutoMixedPrecisionMode::CPU: |
| device_type = DEVICE_CPU; |
| should_process = !MustPreserve(node) && IsOnDevice(node, device_type); |
| break; |
| } |
| if (should_process) { |
| should_process_nodes_.insert(&node); |
| } else { |
| LogSkippedNode(node, device_type); |
| } |
| } |
| |
| VLOG(2) << "Converting FusedBatchNorm* ops to V2"; |
| ConvertBatchNormOpsToV2(); |
| |
| VLOG(2) << "Building node type map for graph"; |
| TF_RETURN_IF_ERROR(node_type_map_.Init(*graph_)); |
| |
| VLOG(2) << "Constructing graph type attribute topology view"; |
| TF_RETURN_IF_ERROR( |
| graph_type_view_.InitializeFromGraph(*graph_, node_type_map_)); |
| |
| absl::flat_hash_set<int> deny_set; |
| |
| std::vector<absl::flat_hash_set<const NodeDef*>> tensor_list_clusters; |
| FindFloat32TensorListOpClustersAndDenylistUnsafe(&tensor_list_clusters, |
| &deny_set); |
| std::vector<NodeTypeIdEdge> ephemeral_edges; |
| for (const auto& cluster : tensor_list_clusters) { |
| VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size(); |
| for (const NodeDef* node : cluster) { |
| VLOG(2) << " Cluster member: " << node->op() << " node " << node->name(); |
| } |
| FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges); |
| } |
| TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges)); |
| |
| // The goal here is to change performance-critical ops to fp16 or bf16, and to |
| // do so with the minimal number of casts, subject to the constraint that the |
| // model's convergence is not affected. This is achieved by first identifying |
| // which nodes should be changed to f16 and then inserting casts at the |
| // boundaries between f16/non-f16 nodes. |
| |
| // The algorithm for deciding which nodes to change to f16 is as follows: |
| // 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set. |
| // This is done under the assumption that allowlist ops are always |
| // numerically-safe in f16 and that they are the most important ops for |
| // improving performance. |
| // 2) Add nodes to the deny_set iff they are numerically-dangerous (aka |
| // "denylist" ops) or they are on a forward path from a denylist node to |
| // a deny/infer node (including the node at the end of the path) through |
| // non-numerically-dangerous ops (aka "inferlist" and "clearlist" ops). |
| // This is done to prevent numerically-dangerous ops and their downstream |
| // effects from being changed to f16, which would risk breaking the |
| // numerical accuracy of the model. |
| // 3) For all remaining nodes that are not considered dangerous (inferlist |
| // and clearlist ops), find those that are between (i.e., both upstream |
| // and downstream of) allow nodes, and add them to the allow_set. |
| // This is done to avoid unnecessary casts between allowlist ops. |
| // 4) For the remaining inferlist nodes, add them to the allow_set if they |
| //    are immediately downstream of an allow_set node. |
| // 5) For all remaining clearlist nodes, add them to the allow_set if they are |
| // connected to a node in the allow_set via other clearlist nodes. |
| // This is done to increase the number of ops in the allow_set without |
| // affecting numerical stability. |
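| // For example (list membership here is illustrative and depends on the |
| // mode), consider the op chain: |
| //   Conv2D (allow) -> BiasAdd (infer) -> Relu (clear) -> Conv2D (allow) |
| //     -> Softmax (deny) |
| // Pass 1 paints both Conv2D nodes ALLOW and pass 2 paints Softmax DENY. |
| // Pass 3 then paints BiasAdd and Relu ALLOW because they lie between ALLOW |
| // nodes, so the final pass inserts casts only at the remaining f16/non-f16 |
| // boundaries (e.g., a cast back to fp32 before Softmax). |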
| |
| absl::flat_hash_set<int> allow_set; |
| VLOG(2) << "Beginning pass 1 to add allowlist ops"; |
| AddAllowlistOps(&allow_set); |
| VLOG(2) << "Finished pass 1"; |
| |
| if (allow_set.empty()) { |
| LOG(INFO) << "No allowlist ops found, nothing to do"; |
| return OkStatus(); |
| } |
| |
| VLOG(2) << "Beginning pass 2 to propagate deny forwards from denylist ops " |
| "through clear/inferlist ops"; |
| PropagateDenyFwdThroughClearAndInfer(&deny_set); |
| VLOG(2) << "Finished pass 2"; |
| |
| VLOG(2) << "Forcing color match between data structure ops"; |
| for (const auto& cluster : tensor_list_clusters) { |
| ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set); |
| } |
| |
| VLOG(2) << "Beginning pass 3 to set clear and infer nodes to allow if they " |
| "are between allow ops"; |
| AddClearAndInferToAllowIfBetweenAllow(deny_set, &allow_set); |
| VLOG(2) << "Finished pass 3"; |
| |
| VLOG(2) << "Beginning pass 4 to add infer list ops to allow if they " |
| "directly follow allow nodes"; |
| AddInferToAllowIfFollowAllow(deny_set, &allow_set); |
| VLOG(2) << "Finished pass 4"; |
| |
| VLOG(2) << "Beginning pass 5 to propagate allow from allow nodes through " |
| "clearlist ops"; |
| PropagateAllowThroughClear(deny_set, &allow_set); |
| VLOG(2) << "Finished pass 5"; |
| |
| VLOG(2) << "Beginning pass 6 to remove some nodes which could not be changed " |
| "to F16" |
| "from allow set"; |
| RemoveAllowsetWithFp32(&allow_set); |
| VLOG(2) << "Finished pass 6"; |
| |
| VLOG(2) << "Forcing color match between data structure ops"; |
| for (const auto& cluster : tensor_list_clusters) { |
| ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &deny_set); |
| } |
| |
| VLOG(2) << "Forcing color match on loop edges"; |
| TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set)); |
| |
| VLOG(2) << "Finding existing casts that can be made allow"; |
| MakeCastsAllowIfAllOutputsAllow(&allow_set); |
| |
| VLOG(2) << "Beginning final pass to change type attributes and insert Cast " |
| "ops at paint boundaries"; |
| TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set)); |
| VLOG(2) << "Finished final pass"; |
| |
| TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp)); |
| |
| return OkStatus(); |
| } |
| |
| // If node is a Tensor List op with a float32 data type attribute then this |
| // returns a pointer to the NodeTypeId representing that type attribute. In |
| // all other cases this returns nullptr. |
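| // For example, for a TensorListPushBack node whose element_dtype attribute |
| // is DT_FLOAT, this returns the NodeTypeId for that element_dtype |
| // attribute. |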
| const NodeTypeId* AutoMixedPrecisionImpl::GetTensorListFloat32NodeTypeId( |
| const NodeDef& node) const { |
| if (!IsTensorListOp(node.op())) return nullptr; |
| for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(node)) { |
| const NodeTypeId* node_type = |
| graph_type_view_.GetNode(node.name(), type_attr); |
| // This assumes that the float32 data type on a Tensor List op is always a |
| // non-fixed type attribute containing a single type, and that this type |
| // attribute represents the dtype of the values in the list. |
| // TODO(benbarsdell): A new Tensor List op could theoretically break these |
| // assumptions. |
| if (node_type && node_type->type_attr.fixed_type == DT_INVALID && |
| node_type->type_attr.type_index == TypeAttrId::kSingleType && |
| IsFloat32(*node_type)) { |
| return node_type; |
| } |
| } |
| return nullptr; |
| } |
| |
| bool AutoMixedPrecisionImpl::IsSourceOrSinkOp(const string& op) const { |
| const gtl::FlatSet<string> source_and_sink_ops = { |
| "_Arg", |
| "_Retval", |
| "OptionalFromValue", |
| "OptionalGetValue", |
| "PartitionedCall", |
| "Placeholder", |
| "StatefulPartitionedCall", |
| }; |
| return source_and_sink_ops.count(op) || function_library_.Find(op); |
| } |
| |
| // Finds all clusters of float32 Tensor List nodes that are connected via their |
| // handle edges. Unsafe clusters (those with unprocessable nodes, or with edges |
| // that cross untraversable boundaries via _Arg, _Retval, PartitionedCall etc. |
| // nodes) are added to deny_set. The caller should paint all nodes in a cluster |
| // the same color, as they may all refer to the same Tensor List. |
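| // E.g., TensorListReserve -> TensorListSetItem -> TensorListGetItem nodes |
| // chained via their DT_VARIANT handle edges form one cluster; if the handle |
| // also flows into a PartitionedCall, the cluster is denylisted because the |
| // element values cannot be traced through the call boundary. |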
| void AutoMixedPrecisionImpl::FindFloat32TensorListOpClustersAndDenylistUnsafe( |
| std::vector<absl::flat_hash_set<const NodeDef*>>* tensor_list_clusters, |
| absl::flat_hash_set<int>* deny_set) const { |
| absl::flat_hash_set<const NodeDef*> tensor_list_prop_set; |
| for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { |
| const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); |
| if (!ShouldProcess(*root.node) || |
| root.type_attr.fixed_type != DataType::DT_VARIANT || |
| !GetTensorListFloat32NodeTypeId(*root.node) || |
| tensor_list_prop_set.count(root.node)) { |
| continue; |
| } |
| const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node); |
| const absl::optional<int> maybe_root_fp32_idx = |
| graph_type_view_.GetNodeIndex(*root_fp32); |
| DCHECK(maybe_root_fp32_idx.has_value()) |
| << "Type attribute " << root_fp32->type_attr.DebugString() |
| << " of node " << root.node->name() << " not found in graph view"; |
| int root_fp32_idx = maybe_root_fp32_idx.value(); |
| // Traverse Tensor List handle edges (DT_VARIANT) to find cluster of all |
| // connected Tensor List nodes. |
| absl::flat_hash_set<const NodeDef*> cluster({root.node}); |
| DfsTypeTraversal(graph_type_view_, {&root}, |
| TypeTraversalDirection::kFollowInputsAndOutputs, |
| DfsTypePredicates::Enter([&](int idx) -> bool { |
| const NodeTypeId& item = *graph_type_view_.GetNode(idx); |
| return !tensor_list_prop_set.count(item.node); |
| }), |
| DfsTypeCallbacks::PreOrder([&](int idx) { |
| const NodeTypeId& item = *graph_type_view_.GetNode(idx); |
| const NodeDef* node = item.node; |
| if (GetTensorListFloat32NodeTypeId(*node)) { |
| cluster.insert(node); |
| if (!ShouldProcess(*node)) { |
| // The cluster contains an un-processable node. |
| deny_set->insert(root_fp32_idx); |
| } |
| // TODO(benbarsdell): In a theoretical pathological |
| // case of a Tensor List of Tensor List handles, the |
| // Tensor List itself would need to be treated as a |
| // sink. |
| } else if (IsSourceOrSinkOp(node->op())) { |
| // The cluster crosses an untraversable boundary. |
| deny_set->insert(root_fp32_idx); |
| } |
| })); |
| tensor_list_clusters->push_back(cluster); |
| } |
| } |
| |
| // Finds all writer -> reader pairs in the given set that are connected via |
| // their handles, and adds corresponding float32 edges to *implicit_fp32_edges. |
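| // E.g., if a TensorListSetItem's handle output eventually reaches a |
| // TensorListGetItem, an ephemeral edge is added from the writer's |
| // element_dtype type attribute to the reader's, so the painting passes see |
| // the written float32 values as flowing directly to the reader. |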
| void AutoMixedPrecisionImpl::FindTensorListImplicitFloat32Edges( |
| const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes, |
| std::vector<NodeTypeIdEdge>* implicit_fp32_edges) const { |
| for (const NodeDef* root_node : tensor_list_nodes) { |
| if (!IsTensorListReaderOp(root_node->op())) continue; |
| NodeTypeId root(root_node, TypeAttrId(DataType::DT_VARIANT)); |
| const NodeTypeId* root_fp32 = GetTensorListFloat32NodeTypeId(*root.node); |
| CHECK(root_fp32) << "No float32 type attribute found for " // Crash OK |
| << root.node->op() << " node " << root.node->name(); |
| // Search backwards through handle edges (DT_VARIANT) for all writer ops, |
| // adding direct implicit edges between them and the reader. |
| DfsTypeTraversal( |
| graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs, |
| DfsTypePredicates::Enter([&](int idx) -> bool { |
| const NodeTypeId& item = *graph_type_view_.GetNode(idx); |
| return ShouldProcess(*item.node); |
| }), |
| DfsTypeCallbacks::PreOrder([&](int idx) { |
| const NodeTypeId& item = *graph_type_view_.GetNode(idx); |
| if (IsTensorListWriterOp(item.node->op())) { |
| const NodeTypeId* item_fp32 = |
| GetTensorListFloat32NodeTypeId(*item.node); |
| CHECK(item_fp32) // Crash OK |
| << "No float32 type attribute found for " << item.node->op() |
| << " node " << item.node->name(); |
| VLOG(2) << "Adding ephemeral float32 edge from " |
| << item_fp32->node->op() << " node " |
| << item_fp32->node->name() << " to " |
| << root_fp32->node->op() << " node " |
| << root_fp32->node->name(); |
| implicit_fp32_edges->emplace_back(*item_fp32, *root_fp32); |
| } |
| })); |
| } |
| } |
| |
| void AutoMixedPrecisionImpl::AddAllowlistOps( |
| absl::flat_hash_set<int>* allow_set) const { |
| // Add allowlisted ops to allow_set. |
| for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { |
| const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); |
| if (!ShouldProcess(*root.node)) continue; |
| bool force_allow = force_all_fp16_ && CanForceFP16(*root.node); |
| if (f16_allowlist_.count(root.node->op()) || force_allow) { |
| bool inserted = allow_set->insert(root_idx).second; |
| if (VLOG_IS_ON(2) && inserted) { |
| VLOG(2) << "Painting type " << root.type_attr.DebugString() |
| << " of node " << root.node->name() << " ALLOW because its op " |
| << root.node->op() << " is on the allowlist"; |
| } |
| } |
| } |
| } |
| |
| // Adds nodes to deny_set iff they are on the denylist or they are on a |
| // forward path from a denylist node to a deny/infer node (including the node |
| // at the end of the path) through clear and infer nodes. |
| // E.g., deny -> infer -> clear -> infer -> clear -> allow -> infer |
| // becomes: deny -> deny -> deny -> deny -> clear -> allow -> infer. |
| void AutoMixedPrecisionImpl::PropagateDenyFwdThroughClearAndInfer( |
| absl::flat_hash_set<int>* deny_set) const { |
| if (force_all_fp16_) return; |
| |
| // Find clear nodes that are upstream of deny or infer. |
| absl::flat_hash_set<int> upstream_of_deny_or_infer_set; |
| for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { |
| const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); |
| if (!(f16_denylist_.count(root.node->op()) || |
| f16_inferlist_.count(root.node->op()))) { |
| continue; |
| } |
| DfsTypeTraversal(graph_type_view_, {&root}, |
| TypeTraversalDirection::kFollowInputs, |
| DfsTypePredicates::Enter([&](int idx) -> bool { |
| const NodeTypeId& item = *graph_type_view_.GetNode(idx); |
| return idx == root_idx || |
| (!upstream_of_deny_or_infer_set.count(idx) && |
| f16_clearlist_.count(item.node->op())); |
| }), |
| DfsTypeCallbacks::PreOrder([&](int idx) { |
| upstream_of_deny_or_infer_set.insert(idx); |
| })); |
| } |
| |
| // Propagate deny forward through nodes in upstream_of_deny_or_infer_set. |
| for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { |
| const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); |
| if (deny_set->count(root_idx) || !f16_denylist_.count(root.node->op())) { |
| continue; |
| } |
| DfsTypeTraversal( |
| graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs, |
| DfsTypePredicates::Enter([&](int idx) -> bool { |
| return idx == root_idx || (!deny_set->count(idx) && |
| upstream_of_deny_or_infer_set.count(idx)); |
| }), |
| DfsTypeCallbacks::PreOrder([&](int idx) { |
| bool inserted = deny_set->insert(idx).second; |
| if (VLOG_IS_ON(2) && inserted) { |
| const NodeTypeId& item = *graph_type_view_.GetNode(idx); |
| VLOG(2) << "Painting type " << item.type_attr.DebugString() |
| << " of " << item.node->op() << " node " |
| << item.node->name() << " DENY"; |
| } |
| })); |
| } |
| } |
| |
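| // Adds clearlist and inferlist nodes to allow_set if they are both |
| // downstream and upstream of allow_set nodes (pass 3 above). |
| // E.g., allow -> infer -> clear -> allow becomes |
| //       allow -> allow -> allow -> allow, |
| // avoiding a pair of casts between the two original allow nodes. |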
| void AutoMixedPrecisionImpl::AddClearAndInferToAllowIfBetweenAllow( |
| const absl::flat_hash_set<int>& deny_set, |
| absl::flat_hash_set<int>* allow_set) const { |
| // Find clear/inferlist ops that are downstream of allow ops. |
| absl::flat_hash_set<int> downstream_of_allow_set; |
| for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { |
| const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); |
| if (!ShouldProcess(*root.node) || !f16_allowlist_.count(root.node->op())) { |
| continue; |
| } |
| DfsTypeTraversal( |
| graph_type_view_, {&root}, TypeTraversalDirection::kFollowOutputs, |
| DfsTypePredicates::Enter([&](int idx) -> bool { |
| const NodeTypeId& item = *graph_type_view_.GetNode(idx); |
| return idx == root_idx || |
| (!downstream_of_allow_set.count(idx) && |
| !f16_allowlist_.count(item.node->op()) && |
| !deny_set.count(idx) && ShouldProcess(*item.node) && |
| // TODO(benbarsdell): Consider allowing propagation through |
| // ops that are already float16 in order to reduce the number |
| // of casts. |
| IsFloat32(item) && SupportsF16(item) && |
| (f16_clearlist_.count(item.node->op()) || |
| f16_inferlist_.count(item.node->op()))); |
| }), |
| DfsTypeCallbacks::PreOrder( |
| [&](int idx) { downstream_of_allow_set.insert(idx); })); |
| } |
| |
| // Set nodes that are both downstream and upstream of allow ops to allow. |
| absl::flat_hash_set<int> upstream_of_allow_set; |
| for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { |
| const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); |
| if (!ShouldProcess(*root.node) || upstream_of_allow_set.count(root_idx) || |
| !f16_allowlist_.count(root.node->op())) { |
| continue; |
| } |
| DfsTypeTraversal( |
| graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs, |
| DfsTypePredicates::Enter([&](int idx) -> bool { |
| return idx == root_idx || (!upstream_of_allow_set.count(idx) && |
| downstream_of_allow_set.count(idx)); |
| }), |
| DfsTypeCallbacks::PreOrder([&](int idx) { |
| upstream_of_allow_set.insert(idx); |
| bool inserted = allow_set->insert(idx).second; |
| if (VLOG_IS_ON(2) && inserted) { |
| const NodeTypeId& item = *graph_type_view_.GetNode(idx); |
| VLOG(2) << "Painting type " << item.type_attr.DebugString() |
| << " of " << item.node->op() << " node " |
| << item.node->name() << " ALLOW"; |
| } |
| })); |
| } |
| } |
| |
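| // Propagates ALLOW from allow_set nodes in both directions through chains |
| // of clearlist nodes (pass 5 above). E.g., clear <- allow -> clear -> clear |
| // becomes allow <- allow -> allow -> allow, provided the clear nodes are |
| // float32, support F16, and are not painted DENY. |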
| void AutoMixedPrecisionImpl::PropagateAllowThroughClear( |
| const absl::flat_hash_set<int>& deny_set, |
| absl::flat_hash_set<int>* allow_set) const { |
| // Propagate allow from allow nodes through clearlist ops. |
| absl::flat_hash_set<int> clear_prop_set; |
| for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { |
| const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); |
| if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) || |
| !allow_set->count(root_idx)) { |
| continue; |
| } |
| DfsTypeTraversal( |
| graph_type_view_, {&root}, |
| TypeTraversalDirection::kFollowInputsAndOutputs, |
| DfsTypePredicates::Enter([&](int idx) -> bool { |
| const NodeTypeId& item = *graph_type_view_.GetNode(idx); |
| return idx == root_idx || |
| (!allow_set->count(idx) && !deny_set.count(idx) && |
| ShouldProcess(*item.node) && IsFloat32(item) && |
| SupportsF16(item) && |
| (f16_clearlist_.count(item.node->op())) && |
| // We don't propagate (backwards) through nodes that read |
| // Variables because it can break the behavior of TensorBoard |
| // visualization and/or (in the case of Enter nodes) the model |
| // itself. This is only a problem for non-resource variables. |
| !NodeImplicitlyReadsNonResourceVariable(*item.node)); |
| }), |
| DfsTypeCallbacks::PreOrder([&](int idx) { |
| clear_prop_set.insert(idx); |
| bool inserted = allow_set->insert(idx).second; |
| if (VLOG_IS_ON(2) && inserted) { |
| const NodeTypeId& item = *graph_type_view_.GetNode(idx); |
| VLOG(2) << "Painting type " << item.type_attr.DebugString() |
| << " of " << item.node->op() << " node " |
| << item.node->name() << " ALLOW"; |
| } |
| })); |
| } |
| } |
| |
| // Sets an infer node to allow if at least one of its immediate upstream |
| // nodes is in the allow set and none of them is in the deny set. |
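| // E.g., allow -> infer becomes allow -> allow, while deny -> infer and |
| // (unpainted) infer -> infer are left unchanged. |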
| void AutoMixedPrecisionImpl::AddInferToAllowIfFollowAllow( |
| const absl::flat_hash_set<int>& deny_set, |
| absl::flat_hash_set<int>* allow_set) const { |
| // This pass currently only targets the MKL mode. |
| if (mode_ != AutoMixedPrecisionMode::MKL) { |
| return; |
| } |
| for (int item_idx = 0; item_idx < graph_type_view_.num_nodes(); ++item_idx) { |
| const NodeTypeId& item = *graph_type_view_.GetNode(item_idx); |
| if (!ShouldProcess(*item.node) || deny_set.count(item_idx) || |
| allow_set->count(item_idx) || !f16_inferlist_.count(item.node->op()) || |
| !IsFloat32(item) || !SupportsF16DataType(item)) { |
| continue; |
| } |
| |
| bool has_allow_fanin = false; |
| for (const int fanin : graph_type_view_.GetFanin(item_idx)) { |
| if (deny_set.count(fanin)) { |
| has_allow_fanin = false; |
| break; |
| } |
| if (allow_set->count(fanin)) { |
| has_allow_fanin = true; |
| } |
| } |
| if (has_allow_fanin) { |
| bool inserted = allow_set->insert(item_idx).second; |
| if (VLOG_IS_ON(2) && inserted) { |
| VLOG(2) << "Painting type " << item.type_attr.DebugString() << " of " |
| << item.node->op() << " node " << item.node->name() << " ALLOW"; |
| } |
| } |
| } |
| } |
| |
| // Removes a node from allow_set if it has a type_attr that cannot be |
| // converted to F16, e.g. the type_attr 'U' of FusedBatchNormV2 and |
| // FusedBatchNormV3 only supports float. Quantized ops are likewise not |
| // converted to FP16. |
| void AutoMixedPrecisionImpl::RemoveAllowsetWithFp32( |
| absl::flat_hash_set<int>* allow_set) const { |
| for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { |
| const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); |
| if (f16_allowlist_.count(root.node->op()) && allow_set->count(root_idx) && |
| (!SupportsF16DataType(root) || IsQuantized(root))) { |
| auto erased = allow_set->erase(root_idx); |
| if (VLOG_IS_ON(2) && erased) { |
| VLOG(2) << "UnPainting type " << root.type_attr.DebugString() |
| << " of node " << root.node->name() << " ALLOW because its op " |
| << root.node->op() << " is not support F16 DataType"; |
| } |
| } |
| } |
| } |
| |
| // Forces NextIteration nodes and their output Merge node(s) to have the same |
| // color. Specifically, it removes them all from allow_set if any of the Merge |
| // nodes is not in allow_set, otherwise it adds the NextIteration node to |
| // allow_set. |
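| // E.g., if a NextIteration node were painted ALLOW while one of its output |
| // Merge nodes stayed fp32, a cast would be required on the recurrent (back) |
| // edge of the loop; forcing a color match keeps the loop-carried value in a |
| // single data type. |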
| Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges( |
| absl::flat_hash_set<int>* allow_set) const { |
| for (const NodeDef& node : graph_->node()) { |
| if (node.op() == "NextIteration") { |
| GraphView::OutputPort output_port(&node, 0); |
| const auto& fanout = graph_view_.GetFanout(output_port); |
| std::vector<int> merge_idxs; |
| merge_idxs.reserve(fanout.size()); |
| bool any_merge_is_not_allow = false; |
| for (const auto& output : fanout) { |
| const NodeDef& merge_node = *output.node; |
| if (merge_node.op() != "Merge") { |
| return errors::FailedPrecondition( |
| "Expected Merge node after NextIteration, got ", merge_node.op()); |
| } |
| const absl::optional<int> maybe_merge_idx = |
| graph_type_view_.GetNodeIndex(merge_node.name(), TypeAttrId("T")); |
| if (!maybe_merge_idx.has_value()) { |
| return errors::Internal("Type attribute T of Merge node ", |
| merge_node.name(), |
| " not found in graph view"); |
| } |
| int merge_idx = maybe_merge_idx.value(); |
| merge_idxs.push_back(merge_idx); |
| any_merge_is_not_allow = |
| any_merge_is_not_allow || !allow_set->count(merge_idx); |
| } |
| const absl::optional<int> maybe_nextiter_idx = |
| graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T")); |
| if (!maybe_nextiter_idx.has_value()) { |
| return errors::Internal("Type attribute T of NextIteration node ", |
| node.name(), " not found in graph view"); |
| } |
| int nextiter_idx = maybe_nextiter_idx.value(); |
| if (any_merge_is_not_allow) { |
| for (int merge_idx : merge_idxs) { |
| if (allow_set->erase(merge_idx)) { |
| VLOG(2) << "Painting type T of Merge node " |
| << graph_type_view_.GetNode(merge_idx)->node->name() |
| << " DENY to match the color of its sibling Merge nodes " |
| "with common NextIteration node " |
| << node.name(); |
| } |
| } |
| if (allow_set->erase(nextiter_idx)) { |
| VLOG(2) << "Painting type T of NextIteration node " << node.name() |
| << " DENY to match the color of its output Merge node(s)"; |
| } |
| } else { |
| if (allow_set->insert(nextiter_idx).second) { |
| VLOG(2) << "Painting type T of NextIteration node " << node.name() |
| << " ALLOW to match the color of its output Merge node(s)"; |
| } |
| } |
| } |
| } |
| return OkStatus(); |
| } |
| |
| // Forces all of the given Tensor List nodes into the same color set. |
| void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps( |
| const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes, |
| absl::flat_hash_set<int>* allow_set, |
| absl::flat_hash_set<int>* deny_set) const { |
| bool any_deny = false; |
| bool any_allow = false; |
| std::vector<int> node_type_idxs; |
| node_type_idxs.reserve(tensor_list_nodes.size()); |
| for (const NodeDef* node : tensor_list_nodes) { |
| const NodeTypeId& node_type = *GetTensorListFloat32NodeTypeId(*node); |
| const absl::optional<int> maybe_node_type_idx = |
| graph_type_view_.GetNodeIndex(node_type); |
| DCHECK(maybe_node_type_idx.has_value()) |
| << "Type attribute " << node_type.type_attr.DebugString() << " of node " |
| << node->name() << " not found in graph view"; |
| node_type_idxs.push_back(maybe_node_type_idx.value()); |
| } |
| for (int node_type_idx : node_type_idxs) { |
| if (deny_set->count(node_type_idx)) { |
| any_deny = true; |
| break; |
| } else if (allow_set->count(node_type_idx)) { |
| any_allow = true; |
| } |
| } |
| if (!any_deny && !any_allow) return; |
| for (int node_type_idx : node_type_idxs) { |
| const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx); |
| VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of " |
| << node_type.node->op() << " node " << node_type.node->name() << " " |
| << (any_deny ? "DENY" : "ALLOW") |
| << " because at least one of its siblings is " |
| << (any_deny ? "DENY" : "ALLOW"); |
| if (any_deny) { |
| allow_set->erase(node_type_idx); |
| deny_set->insert(node_type_idx); |
| } else { |
| allow_set->insert(node_type_idx); |
| } |
| } |
| } |
| |
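| // Returns true if node is an Identity that reads a (non-resource) Variable |
| // or VariableV2, or an Enter that is fed (possibly through a chain of |
| // Enters) by such an Identity. |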
| bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable( |
| const NodeDef& node) const { |
| if (node.op() == "Identity" || node.op() == "Enter") { |
| GraphView::InputPort node_input(&node, 0); |
| MutableGraphView::OutputPort prev_output = |
| graph_view_.GetRegularFanin(node_input); |
| const NodeDef* input = prev_output.node; |
| if (input && ((node.op() == "Identity" && (input->op() == "Variable" || |
| input->op() == "VariableV2")) || |
| (node.op() == "Enter" && |
| NodeImplicitlyReadsNonResourceVariable(*input)))) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| // This adds existing Cast nodes to allow_set if all of their outputs are allow, |
| // avoiding the need to add a new Cast node after an existing Cast. |
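| // E.g., for an existing Cast with DstT=DT_FLOAT whose fanouts are all |
| // painted ALLOW, painting the Cast's DstT ALLOW lets the final pass change |
| // DstT to f16 directly instead of appending a second Cast after it. |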
| void AutoMixedPrecisionImpl::MakeCastsAllowIfAllOutputsAllow( |
| absl::flat_hash_set<int>* allow_set) const { |
| int num_nodes_preop = graph_->node_size(); |
| for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) { |
| NodeDef* node = graph_->mutable_node(node_idx); |
| NodeTypeId node_type(node, TypeAttrId("DstT")); |
| if (node->op() != "Cast" || !IsFloat32(node_type)) { |
| continue; |
| } |
| bool all_fanouts_allow = true; |
| MutableGraphView::OutputPort src(node, 0); |
| const auto& fanout = graph_view_.GetFanout(src); |
| for (const MutableGraphView::InputPort& dst : fanout) { |
| TypeAttrId dst_type_attr = |
| node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id); |
| const absl::optional<int> maybe_dst_type_idx = |
| graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr); |
| DCHECK(maybe_dst_type_idx.has_value()) |
| << "Type attribute " << dst_type_attr.DebugString() << " of node " |
| << dst.node->name() << " not found in graph view"; |
| int dst_type_idx = maybe_dst_type_idx.value(); |
| bool dst_is_allow = allow_set->count(dst_type_idx); |
| if (!dst_is_allow) { |
| all_fanouts_allow = false; |
| break; |
| } |
| } |
| if (!fanout.empty() && all_fanouts_allow) { |
| const absl::optional<int> maybe_node_type_idx = |
| graph_type_view_.GetNodeIndex(node_type); |
| DCHECK(maybe_node_type_idx.has_value()) |
| << "Type attribute " << node_type.type_attr.DebugString() |
| << " of node " << node_type.node->name() |
| << " not found in graph view"; |
| int node_type_idx = maybe_node_type_idx.value(); |
| allow_set->insert(node_type_idx); |
| } |
| } |
| } |
| |
| // Inserts a Cast op at the output of a node. |
| // CastType indicates the data type of the inserted Cast op: |
| // FP16: cast to float16 |
| // FP32: cast to float32 |
| // AUTO: cast to a data type that matches the fanout data type |
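| // E.g., with CastType::AUTO, an edge whose src is ALLOW but whose dst is |
| // not gets a cast to fp32, an edge whose dst is ALLOW but whose src is not |
| // gets a cast to f16, and edges whose endpoints match are left unchanged. |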
| StatusOr<NodeDef*> AutoMixedPrecisionImpl::InsertCastNodeAtFanout( |
| const absl::flat_hash_set<int>& allow_set, const bool src_is_allow, |
| const CastType& cast_type, MutableGraphView::OutputPort& src) { |
| NodeDef* added_cast_node = nullptr; |
| // Note: This is copied so that edges can be modified inside the loop. |
| auto fanout = graph_view_.GetFanout(src); |
| for (const MutableGraphView::InputPort& dst : fanout) { |
| TypeAttrId dst_type_attr = |
| node_type_map_.GetInputTypeAttr(*dst.node, dst.port_id); |
| const absl::optional<int> maybe_dst_type_idx = |
| graph_type_view_.GetNodeIndex(dst.node->name(), dst_type_attr); |
| if (!maybe_dst_type_idx.has_value()) { |
| return errors::Internal("Type attribute ", dst_type_attr.DebugString(), |
| " of ", dst.node->op(), " node ", |
| dst.node->name(), " not found in graph view"); |
| } |
| int dst_type_idx = maybe_dst_type_idx.value(); |
| bool dst_is_allow = allow_set.count(dst_type_idx); |
| bool to_f16 = false; |
| bool should_cast = false; |
| switch (cast_type) { |
| case CastType::AUTO: |
| if (src_is_allow != dst_is_allow) { |
| to_f16 = dst_is_allow; |
| should_cast = true; |
| } |
| break; |
| case CastType::FP16: |
| to_f16 = true; |
| should_cast = true; |
| break; |
| case CastType::FP32: |
| to_f16 = false; |
| should_cast = true; |
| break; |
| default: |
| return errors::Internal("Invalid Cast Type: ", |
| static_cast<int>(cast_type)); |
| } |
| |
| if (!should_cast) continue; |
| if (added_cast_node == nullptr) { |
| VLOG(1) << "Inserting cast to " |
| << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT") << " at " |
| << src.node->op() << " " << src.node->name() << ":" |
| << src.port_id; |
| added_cast_node = graph_view_.AddNode( |
| BuildCastNode(src, dst, to_f16, src.node->device())); |
| if (to_f16 && !IsConstant(*src.node) && !IsVariable(*src.node) && |
| !NodeImplicitlyReadsNonResourceVariable(*src.node)) { |
| ++num_nonvar_casts_to_f16_; |
| } |
| } |
| TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort( |
| dst.node->name(), dst.port_id, {added_cast_node->name(), 0})); |
| } |
| return added_cast_node; |
| } |
| |
| // Returns the destination data type of a Cast op. CHECK-fails if the node |
| // is not a Cast op. |
| StatusOr<DataType> AutoMixedPrecisionImpl::GetCastToType( |
| const NodeDef* node) const { |
| CHECK_EQ(node->op(), "Cast") // Crash OK |
| << "Node " << node->name() << " is not a Cast op"; |
| return node->attr().at("DstT").type(); |
| } |
| |
| // Collects the output ports of `node` associated with `type_attr` and |
| // appends them to `output_ports`. |
| void AutoMixedPrecisionImpl::CollectOutputPorts( |
| const TypeAttrId& type_attr, NodeDef* node, |
| std::vector<MutableGraphView::OutputPort>& output_ports) const { |
| for (int port_id : node_type_map_.GetOutputPorts(*node, type_attr)) { |
| output_ports.emplace_back(node, port_id); |
| } |
| } |
| |
| // Changes all allow-painted type attributes to DT_HALF or DT_BFLOAT16, and |
| // inserts Cast nodes at node outputs for all edges that connect |
| // allow-painted <-> non-allow-painted type attributes. |
| Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts( |
| const absl::flat_hash_set<int>& allow_set) { |
| int num_nodes_changed = 0; |
| const int num_nodes_preop = graph_->node_size(); |
| |
| bool emulate_f16 = false; |
| if (mode_ == AutoMixedPrecisionMode::CPU) { |
| TF_CHECK_OK( |
| ReadBoolFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_EMULATE_FP16", |
| /*default_val=*/true, &emulate_f16)); |
| } |
| |
| VLOG(1) << "Setting emulate_f16 = " << emulate_f16; |
| |
| for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) { |
| NodeDef* node = graph_->mutable_node(node_idx); |
| for (const TypeAttrId& type_attr : node_type_map_.GetTypeAttrs(*node)) { |
| const absl::optional<int> maybe_node_type_idx = |
| graph_type_view_.GetNodeIndex(node->name(), type_attr); |
| if (!maybe_node_type_idx.has_value()) { |
| return errors::Internal("Type attribute ", type_attr.DebugString(), |
| " of ", node->op(), " node ", node->name(), |
| " not found in graph view"); |
| } |
| int node_type_idx = maybe_node_type_idx.value(); |
| if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue; |
| bool src_is_allow = allow_set.count(node_type_idx); |
| |
| // Include output ports of fp32 nodes, real fp16 nodes, |
| // and the fp16 Cast nodes at the fanout of emulated fp16 ops. |
| std::vector<MutableGraphView::OutputPort> output_ports; |
| |
| if (src_is_allow) { |
| if (emulate_f16) { |
| // For an emulated fp16 op, we do not change the type attribute but |
| // instead insert a fp32 Cast at each fanin and a fp16 Cast at each fanout. |
| for (int port_id : node_type_map_.GetInputPorts(*node, type_attr)) { |
| VLOG(2) << "Cast to F32 at fanin of node " << node->name() << ":" |
| << port_id; |
| MutableGraphView::InputPort dst(node, port_id); |
| MutableGraphView::OutputPort src = graph_view_.GetRegularFanin(dst); |
| NodeDef* added_cast_node = graph_view_.AddNode( |
| BuildCastNode(src, dst, /*to_f16=*/false, src.node->device())); |
| VLOG(1) << "Inserting cast to DT_FLOAT at " << src.node->op() << " " |
| << src.node->name() << ":" << src.port_id; |
| TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort( |
| dst.node->name(), dst.port_id, {added_cast_node->name(), 0})); |
| } |
| // Cast to fp16 at outputs |
| for (int port_id : node_type_map_.GetOutputPorts(*node, type_attr)) { |
| MutableGraphView::OutputPort src(node, port_id); |
| VLOG(2) << "Cast to F16 at fanout of node " << node->name() << ":" |
| << port_id; |
| TF_ASSIGN_OR_RETURN(NodeDef * added_cast_node, |
| InsertCastNodeAtFanout(allow_set, src_is_allow, |
| CastType::FP16, src)); |
| if (added_cast_node != nullptr) { |
| output_ports.emplace_back(added_cast_node, /*port_id=*/0); |
| } |
| } |
| } else { |
| VLOG(1) << "Changing type " << type_attr.DebugString() << " of " |
| << node->op() << " node " << node->name() << " to " |
| << DataTypeString(target_dtype_); |
| if (!SetDataType(node, type_attr, target_dtype_)) { |
| return errors::Internal("Failed to set type attribute"); |
| } |
| ++num_nodes_changed; |
| CollectOutputPorts(type_attr, node, output_ports); |
| } |
| } else { |
| CollectOutputPorts(type_attr, node, output_ports); |
| } |
| |
| // If the fanouts require a different data type from the output of the |
| // current node, insert a Cast op. |
| for (auto output_port : output_ports) { |
| VLOG(2) << "Cast to required data type at fanout of node " |
| << output_port.node->name() << ":" << output_port.port_id; |
| TF_RETURN_IF_ERROR(InsertCastNodeAtFanout(allow_set, src_is_allow, |
| CastType::AUTO, output_port) |
| .status()); |
| } |
| } |
| } |
| |
| // Use Python type names (e.g. float16) instead of C++ type names (e.g. half) |
| // since many Python users will see this message. |
| const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16"; |
| LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop |
| << " nodes to " << type_str << " precision using " |
| << num_nonvar_casts_to_f16_ << " cast(s) to " << type_str |
| << " (excluding Const and Variable casts)"; |
| return OkStatus(); |
| } |
| |
| int GetNumGPUs(const Cluster& cluster) { |
| if (ShouldSimulateGpu()) { |
| return 1; |
| } |
| auto devices = cluster.GetDevices(); |
| int num_gpus = 0; |
| for (const auto& device : devices) { |
| const DeviceProperties& device_properties = device.second; |
| if (device_properties.type() == "GPU" && |
| (ShouldIgnorePerformance() || HasFastFP16Support(device_properties))) { |
| num_gpus++; |
| } |
| } |
| return num_gpus; |
| } |
| |
| } // end namespace |
| |
| Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, |
| GraphDef* output) { |
| if (cluster == nullptr) { |
| return errors::InvalidArgument("cluster == nullptr"); |
| } |
| |
| #if !defined(INTEL_MKL) |
| if (mode_ == AutoMixedPrecisionMode::MKL) { |
| return errors::Unimplemented( |
| "The auto_mixed_precision_mkl optimizer cannot be used since " |
| "this build of TensorFlow is not compiled with MKL support for " |
| "bfloat16. " |
| "For information on MKL builds, see: " |
| "https://software.intel.com/en-us/articles/intel-optimization-for-" |
| "tensorflow-installation-guide"); |
| } |
| #endif // INTEL_MKL |
| |
| // Start by copying input graph to output. |
| *output = item.graph; |
| |
| int num_gpus = GetNumGPUs(*cluster); |
| if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) { |
| // AutoMixedPrecision is currently only tuned for GPU. |
| LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name() |
| << " graph optimizer"; |
| return OkStatus(); |
| } |
| |
| // Optimize the output graph in-place. |
| AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output, |
| item.id, mode_); |
| if (item.id == "tf_graph") { |
| LOG(INFO) << "Running " << name() << " graph optimizer"; |
| } else { |
| VLOG(1) << "Running " << name() << " graph optimizer on " << item.id; |
| } |
| Status status = optimizer.Optimize(); |
| if (!status.ok()) { |
| // Restore the original graph. |
| *output = item.graph; |
| LOG(WARNING) << name() << " graph optimizer FAILED: " << status.ToString(); |
| } |
| return status; |
| } |
| |
| } // end namespace grappler |
| } // end namespace tensorflow |