Add the TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU environment variable to the auto_mixed_precision optimizer so that the GPU-targeted graph rewrite can be run on machines without GPUs.

PiperOrigin-RevId: 460485299
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index aa9d87e..84e272e 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -589,6 +589,7 @@
         "//tensorflow/core/grappler/costs:virtual_placer",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc
index 1ac329c..4690874 100644
--- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc
+++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc
@@ -17,9 +17,11 @@
 
 #include <fstream>
 #include <memory>
+#include <unordered_map>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_format.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op.h"
@@ -40,12 +42,27 @@
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/util/env_var.h"
 
 namespace tensorflow {
 namespace grappler {
 namespace {
 
+// Returns true iff the TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU
+// environment variable is set to a true value. The variable is deliberately
+// re-read on every call (no static caching) so that tests can toggle it
+// with setenv/unsetenv between runs.
+bool ShouldSimulateGpu() {
+  bool is_enabled = false;
+  // Defaults to false when the variable is unset; TF_CHECK_OK aborts on a
+  // malformed (non-boolean) value.
+  TF_CHECK_OK(
+      ReadBoolFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU",
+                         /*default_val=*/false, &is_enabled));
+  return is_enabled;
+}
+
 #if GOOGLE_CUDA
 const std::pair<int, int> kMinGPUArch = {7, 0};
 #else
@@ -99,7 +116,7 @@
   std::vector<std::string> gpu_arch = absl::StrSplit(gcnArchName, ":");
   return !gpu_arch.empty() && FP16SupportedDevices.contains(gpu_arch[0]);
 #endif
-  return false;
+  return ShouldSimulateGpu();
 }
 
 // Instances of this class represent unique type attribute identifiers within a
@@ -940,8 +957,8 @@
          !IsStateful(node) && !HasInputOrOutputRefs(node);
 }
 
-int GetCudaVersion(const Cluster& cluster) {
-  auto devices = cluster.GetDevices();
+int GetCudaVersion(
+    const std::unordered_map<string, DeviceProperties>& devices) {
   for (const auto& device : devices) {
     const DeviceProperties& device_properties = device.second;
     if (device_properties.type() == "GPU") {
@@ -956,8 +973,8 @@
   return 0;
 }
 
-int GetCudnnVersion(const Cluster& cluster) {
-  auto devices = cluster.GetDevices();
+int GetCudnnVersion(
+    const std::unordered_map<string, DeviceProperties>& devices) {
   for (const auto& device : devices) {
     const DeviceProperties& device_properties = device.second;
     if (device_properties.type() == "GPU") {
@@ -972,6 +989,36 @@
   return 0;
 }
 
+// Returns the cluster's devices, augmented with one virtual NVIDIA GPU when
+// GPU simulation is enabled and the cluster has no real GPU. This lets the
+// GPU-targeted mixed-precision rewrite run on CPU-only machines.
+std::unordered_map<string, DeviceProperties> GetDevices(Cluster* cluster) {
+  // Fetch the device map once instead of querying the cluster repeatedly.
+  std::unordered_map<string, DeviceProperties> devices(cluster->GetDevices());
+  if (!ShouldSimulateGpu()) {
+    return devices;
+  }
+
+  // Nothing to simulate if a real GPU is already present.
+  for (const auto& device : devices) {
+    if (device.second.type() == "GPU") {
+      return devices;
+    }
+  }
+
+  // Inject an A100-class GPU (compute capability 8.0, CUDA 11.5, cuDNN 8.3)
+  // so the downstream CUDA/cuDNN version checks pass.
+  DeviceProperties gpu_device_properties;
+  gpu_device_properties.set_type("GPU");
+  gpu_device_properties.set_vendor("NVIDIA");
+  gpu_device_properties.mutable_environment()->insert({"architecture", "8.0"});
+  gpu_device_properties.mutable_environment()->insert({"cuda", "11050"});
+  gpu_device_properties.mutable_environment()->insert({"cudnn", "8302"});
+  devices.emplace("/job:localhost/replica:0/task:0/device:GPU:0",
+                  gpu_device_properties);
+  return devices;
+}
+
 class AutoMixedPrecisionImpl {
  public:
   // CastType indicates the type of inserted Cast op
@@ -983,14 +1030,15 @@
                          const std::unordered_set<string>& nodes_to_preserve,
                          GraphDef* graph, string id,
                          AutoMixedPrecisionMode mode)
-      : virtual_placer_(cluster->GetDevices()),
+      : devices_(GetDevices(cluster)),
+        virtual_placer_(devices_),
         nodes_to_preserve_(nodes_to_preserve),
         graph_(graph),
         function_library_(OpRegistry::Global(), graph->library()),
         id_(id),
         graph_view_(graph),
-        cuda_version_(GetCudaVersion(*cluster)),
-        cudnn_version_(GetCudnnVersion(*cluster)),
+        cuda_version_(GetCudaVersion(devices_)),
+        cudnn_version_(GetCudnnVersion(devices_)),
         num_nonvar_casts_to_f16_(0),
         mode_(mode),
         target_dtype_((mode_ == AutoMixedPrecisionMode::CUDA ||
@@ -1019,7 +1067,7 @@
     }
   }
   Status PrintDebugLogs(bool preop, size_t timestamp);
-  void LogSkippedNode(const NodeDef& node) const;
+  void LogSkippedNode(const NodeDef& node, const string& device_type) const;
   bool MustPreserve(const NodeDef& node) const;
   bool IsOnDevice(const NodeDef& node, const string& device_type) const;
   bool IsOnSuitableGPUArch(const NodeDef& node) const;
@@ -1070,6 +1118,7 @@
       std::vector<MutableGraphView::OutputPort>& output_ports) const;
   Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& allow_set);
 
+  std::unordered_map<string, DeviceProperties> devices_;
   VirtualPlacer virtual_placer_;
   std::unordered_set<string> nodes_to_preserve_;
   GraphDef* graph_;
@@ -1182,12 +1231,15 @@
   return OkStatus();
 }
 
-void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node) const {
+void AutoMixedPrecisionImpl::LogSkippedNode(const NodeDef& node,
+                                            const string& device_type) const {
   VLOG(2) << "Skipping " << node.op() << " node " << node.name()
           << " because it "
           << (MustPreserve(node)
                   ? "must be preserved"
-                  : "is not on the GPU, or the GPU arch is not suitable");
+                  : absl::StrFormat(
+                        "is not on the %s, or the %s arch is not suitable",
+                        device_type, device_type));
 }
 
 bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
@@ -1352,21 +1404,24 @@
   VLOG(2) << "Identifying nodes that should be processed";
   for (const NodeDef& node : graph_->node()) {
     bool should_process;
+    string device_type;
     switch (mode_) {
       case AutoMixedPrecisionMode::CUDA:
+        device_type = DEVICE_GPU;
         should_process =
-            !MustPreserve(node) && IsOnDevice(node, DEVICE_GPU) &&
+            !MustPreserve(node) && IsOnDevice(node, device_type) &&
             (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
         break;
       case AutoMixedPrecisionMode::MKL:
       case AutoMixedPrecisionMode::CPU:
-        should_process = !MustPreserve(node) && IsOnDevice(node, DEVICE_CPU);
+        device_type = DEVICE_CPU;
+        should_process = !MustPreserve(node) && IsOnDevice(node, device_type);
         break;
     }
     if (should_process) {
       should_process_nodes_.insert(&node);
     } else {
-      LogSkippedNode(node);
+      LogSkippedNode(node, device_type);
     }
   }
 
@@ -2190,6 +2245,9 @@
 }
 
 int GetNumGPUs(const Cluster& cluster) {
+  if (ShouldSimulateGpu()) {
+    return 1;
+  }
   auto devices = cluster.GetDevices();
   int num_gpus = 0;
   for (const auto& device : devices) {
diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
index 9233af7..e997369 100644
--- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
+++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
@@ -1347,6 +1347,102 @@
   }
 }
 
+class AutoMixedPrecisionSimulateGpuTest : public GrapplerTest {  // Tests TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU behavior.
+ protected:
+  void SetUp() override {
+    std::unordered_map<string, DeviceProperties> devices;
+    DeviceProperties cpu_device;
+    cpu_device.set_type("CPU");
+    cpu_device.set_frequency(1000);
+    cpu_device.set_num_cores(4);
+    cpu_device.set_memory_size(1024 * 1024);
+    devices["/job:localhost/replica:0/task:0/device:CPU:0"] = cpu_device;
+    // Explicitly creating machine without GPU.
+    virtual_cluster_.reset(new VirtualCluster(devices));
+    TF_CHECK_OK(virtual_cluster_->Provision());
+  }
+  void TearDown() override {
+    unsetenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU");  // Keep the env var from leaking into other tests.
+    TF_CHECK_OK(virtual_cluster_->Shutdown());
+  }
+
+  std::unique_ptr<Cluster> virtual_cluster_;
+
+  void TestSimple(tensorflow::Scope s, bool is_optimized) {  // Builds a small graph; expects the f16 rewrite iff is_optimized.
+    Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
+    Output deny1 = ops::Exp(s.WithOpName("deny1"), input);
+    Output clr1 = ops::Relu(s.WithOpName("clr1"), deny1);
+    Output infer1 = ops::Sqrt(s.WithOpName("infer1"), clr1);
+    Output clr2 = ops::Relu(s.WithOpName("clr2"), infer1);
+    Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
+    Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
+    Output infer2 = ops::Log(s.WithOpName("infer2"), clr3);
+    Output clr4 = ops::Relu(s.WithOpName("clr4"), infer2);
+    Output deny2 = ops::SparseMatMul(s.WithOpName("deny2"), clr4, clr4);
+    Output clr5 = ops::Relu(s.WithOpName("clr5"), deny2);
+    Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    TF_CHECK_OK(s.ToGraphDef(&item.graph));
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+    GraphDef output;
+    AutoMixedPrecision optimizer;
+    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+    VLOG(1) << output.DebugString();
+
+    GraphView output_view(&output);
+    DataType expected_data_type = is_optimized ? DT_HALF : DT_FLOAT;
+    int expected_graph_size =
+        is_optimized ? item.graph.node_size() + 2 : item.graph.node_size();  // +2: Cast nodes inserted by the rewrite.
+
+    EXPECT_EQ(output.node_size(), expected_graph_size);
+    EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
+              DT_FLOAT);
+    EXPECT_EQ(output_view.GetNode("deny1")->attr().at("T").type(), DT_FLOAT);
+    EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
+    EXPECT_EQ(output_view.GetNode("infer1")->attr().at("T").type(), DT_FLOAT);
+    EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(),
+              expected_data_type);
+    EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(),
+              expected_data_type);
+    EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(),
+              expected_data_type);
+    EXPECT_EQ(output_view.GetNode("infer2")->attr().at("T").type(), DT_FLOAT);
+    EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
+    EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Ta").type(), DT_FLOAT);
+    EXPECT_EQ(output_view.GetNode("deny2")->attr().at("Tb").type(), DT_FLOAT);
+    EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
+
+    auto tensors = EvaluateNodes(output, item.fetch);
+    EXPECT_EQ(tensors.size(), tensors_expected.size());
+    EXPECT_EQ(tensors.size(), item.fetch.size());
+    for (int i = 0; i < item.fetch.size(); ++i) {
+      test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
+    }
+  }
+};
+
+TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_NoGpu) {  // Env var unset: no rewrite on a CPU-only cluster.
+  TestSimple(tensorflow::Scope::NewRootScope(), /* is_optimized= */ false);
+}
+
+TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_SimulatedGpu) {  // Simulation on: rewrite runs despite no real GPU.
+  setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU", "true",
+         1 /* replace */);
+  TestSimple(tensorflow::Scope::NewRootScope(), /* is_optimized= */ true);
+}
+
+TEST_F(AutoMixedPrecisionSimulateGpuTest, Simple_SimulatedGpu_CpuScope) {  // Nodes pinned to CPU stay unrewritten even when simulating.
+  setenv("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_SIMULATE_GPU", "true",
+         1 /* replace */);
+  TestSimple(tensorflow::Scope::NewRootScope().WithDevice(
+                 "/job:localhost/replica:0/task:0/device:CPU:0"),
+             /* is_optimized= */ false);
+}
+
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #if INTEL_MKL