Reenable full mixed precision warning.

The warning was previously disabled since it called `list_local_devices`, which does not work if `set_visible_devices` is also called. The warning now uses `tf.config.experimental.get_device_details()` instead.

PiperOrigin-RevId: 316141769
Change-Id: I9c77f8c0a2e13f793f4124960d5da3d16af682dd
diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD
index 4b52b44..ec89fa0 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/BUILD
+++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD
@@ -80,9 +80,7 @@
     srcs = ["device_compatibility_check.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:device_lib",
-        "//tensorflow/python:gpu_util",
-        "//tensorflow/python/eager:context",
+        "//tensorflow/python:config",
     ],
 )
 
diff --git a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py
index 0b28661..5a759d3 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py
@@ -20,10 +20,7 @@
 
 import itertools
 
-from tensorflow.python.client import device_lib
-from tensorflow.python.eager import context
 from tensorflow.python.framework import config
-from tensorflow.python.framework import gpu_util
 from tensorflow.python.platform import tf_logging
 
 
@@ -61,14 +58,16 @@
   return new_device_strs
 
 
-def _log_device_compatibility_check(policy_name, device_attr_list):
+def _log_device_compatibility_check(policy_name, gpu_details_list):
   """Logs a compatibility check if the devices support the policy.
 
   Currently only logs for the policy mixed_float16.
 
   Args:
     policy_name: The name of the dtype policy.
-    device_attr_list: A list of DeviceAttributes.
+    gpu_details_list: A list of dicts, one dict per GPU. Each dict
+      is the device details for a GPU, as returned by
+      `tf.config.experimental.get_device_details()`.
   """
   if policy_name != 'mixed_float16':
     # TODO(b/145686977): Log if the policy is 'mixed_bfloat16'. This requires
@@ -76,19 +75,18 @@
     return
   supported_device_strs = []
   unsupported_device_strs = []
-  for device in device_attr_list:
-    if device.device_type == 'GPU':
-      name, cc = gpu_util.compute_capability_from_device_desc(device)
-      name = name or 'Unknown GPU'
-      if cc:
-        device_str = '%s, compute capability %s.%s' % (name, cc[0], cc[1])
-        if cc >= (7, 0):
-          supported_device_strs.append(device_str)
-        else:
-          unsupported_device_strs.append(device_str)
+  for details in gpu_details_list:
+    name = details.get('device_name', 'Unknown GPU')
+    cc = details.get('compute_capability')
+    if cc:
+      device_str = '%s, compute capability %s.%s' % (name, cc[0], cc[1])
+      if cc >= (7, 0):
+        supported_device_strs.append(device_str)
       else:
-        unsupported_device_strs.append(
-            name + ', no compute capability (probably not an Nvidia GPU)')
+        unsupported_device_strs.append(device_str)
+    else:
+      unsupported_device_strs.append(
+          name + ', no compute capability (probably not an Nvidia GPU)')
 
   if unsupported_device_strs:
     warning_str = _COMPAT_CHECK_WARNING_PREFIX + '\n'
@@ -134,7 +132,7 @@
 _logged_compatibility_check = False
 
 
-def log_device_compatibility_check(policy_name, skip_local):
+def log_device_compatibility_check(policy_name):
   """Logs a compatibility check if the devices support the policy.
 
   Currently only logs for the policy mixed_float16. A log is shown only the
@@ -142,29 +140,11 @@
 
   Args:
     policy_name: The name of the dtype policy.
-    skip_local: If True, do not call list_local_devices(). This is useful since
-      if list_local_devices() and tf.config.set_visible_devices() are both
-      called, TensorFlow will crash. However, since GPU names and compute
-      capabilities cannot be checked without list_local_devices(), setting this
-      to True means the function will only warn if there are no GPUs.
   """
   global _logged_compatibility_check
-  # In graph mode, calling list_local_devices may initialize some session state,
-  # so we only call it in eager mode.
-  if not context.executing_eagerly() or _logged_compatibility_check:
+  if _logged_compatibility_check:
     return
   _logged_compatibility_check = True
-  if not skip_local:
-    device_attr_list = device_lib.list_local_devices()
-    _log_device_compatibility_check(policy_name, device_attr_list)
-    return
-
-  # TODO(b/146009447): Create an API to replace list_local_devices(), then
-  # remove the skip_local paramater.
   gpus = config.list_physical_devices('GPU')
-  if not gpus and policy_name == 'mixed_float16':
-    tf_logging.warn(
-        '%s\n'
-        'The dtype policy mixed_float16 may run slowly because '
-        'this machine does not have a GPU.\n%s' %
-        (_COMPAT_CHECK_WARNING_PREFIX, _COMPAT_CHECK_WARNING_SUFFIX))
+  gpu_details_list = [config.get_device_details(g) for g in gpus]
+  _log_device_compatibility_check(policy_name, gpu_details_list)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py
index a88594a..33d5b9d 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py
@@ -20,26 +20,19 @@
 
 import re
 
-from tensorflow.core.framework import device_attributes_pb2
 from tensorflow.python.keras import combinations
 from tensorflow.python.keras.mixed_precision.experimental import device_compatibility_check
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
 
 
-def _get_device_attrs(device_type, device_name=None, cc_major=None,
-                      cc_minor=None):
-  if device_type == 'CPU':
-    return device_attributes_pb2.DeviceAttributes(device_type='CPU')
-  assert device_type == 'GPU', 'Invalid device type: %s' % (device_type,)
-  if not device_name:
-    return device_attributes_pb2.DeviceAttributes(device_type='GPU')
-  physical_device_desc = (
-      'device: 0, name: %s, pci bus id: 0:0:0.0' % device_name)
-  if cc_major:
-    physical_device_desc += ', compute capability: %d.%d' % (cc_major, cc_minor)
-  return device_attributes_pb2.DeviceAttributes(
-      device_type='GPU', physical_device_desc=physical_device_desc)
+def device_details(device_name, compute_capability=None):
+  details = {}
+  if device_name:
+    details['device_name'] = device_name
+  if compute_capability:
+    details['compute_capability'] = compute_capability
+  return details
 
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
@@ -59,57 +52,49 @@
       mock_warn.assert_not_called()
 
   def test_supported(self):
-    device_attrs_list = [_get_device_attrs('GPU', 'GPU 1', 7, 1)]
+    details_list = [device_details('GPU 1', (7, 1))]
     regex = re.compile(
         r'.*compatibility check \(mixed_float16\): OK\n'
         r'Your GPU will likely run quickly with dtype policy mixed_float16 as '
         r'it has compute capability of at least 7.0. Your GPU: GPU 1, compute '
         r'capability 7.1', flags=re.MULTILINE)
-    self._test_compat_check(device_attrs_list, False, regex)
-    device_attrs_list.append(_get_device_attrs('CPU'))
-    self._test_compat_check(device_attrs_list, False, regex)
+    self._test_compat_check(details_list, False, regex)
 
-    device_attrs_list = [
-        _get_device_attrs('CPU', 'CPU'),
-        _get_device_attrs('GPU', 'GPU 1', 7, 0),
-        _get_device_attrs('GPU', 'GPU 2', 7, 1),
-        _get_device_attrs('GPU', 'GPU 3', 8, 0),
+    details_list = [
+        device_details('GPU 1', (7, 0)),
+        device_details('GPU 2', (7, 1)),
+        device_details('GPU 3', (8, 0)),
     ]
     regex = re.compile(
         r'.*compatibility check \(mixed_float16\): OK\n'
         r'Your GPUs will likely run quickly with dtype policy mixed_float16 as '
         r'they all have compute capability of at least 7.0', flags=re.MULTILINE)
-    self._test_compat_check(device_attrs_list, False, regex)
+    self._test_compat_check(details_list, False, regex)
 
   def test_unsupported(self):
-    device_attrs_list = [
-        _get_device_attrs('GPU', 'GPU 1', 6, 0)
+    details_list = [
+        device_details('GPU 1', (6, 0))
     ]
     regex = re.compile(
         r'.*compatibility check \(mixed_float16\): WARNING\n'
         r'Your GPU may run slowly with dtype policy mixed_float16.*\n'
         r'  GPU 1, compute capability 6.0\n'
         r'See.*', flags=re.MULTILINE)
-    self._test_compat_check(device_attrs_list, True, regex)
-    device_attrs_list.append(_get_device_attrs('CPU'))
-    self._test_compat_check(device_attrs_list, True, regex)
+    self._test_compat_check(details_list, True, regex)
 
-    device_attrs_list = [
-        _get_device_attrs('GPU')
+    details_list = [
+        device_details(None)
     ]
     regex = re.compile(
         r'.*compatibility check \(mixed_float16\): WARNING\n'
         r'Your GPU may run slowly with dtype policy mixed_float16.*\n'
         r'  Unknown GPU, no compute capability \(probably not an Nvidia GPU\)\n'
         r'See.*', flags=re.MULTILINE)
-    self._test_compat_check(device_attrs_list, True, regex)
-    device_attrs_list.append(_get_device_attrs('CPU'))
-    self._test_compat_check(device_attrs_list, True, regex)
+    self._test_compat_check(details_list, True, regex)
 
-    device_attrs_list = [
-        _get_device_attrs('CPU', 'CPU'),
-        _get_device_attrs('GPU', 'GPU 1', 6, 0),
-        _get_device_attrs('GPU', 'GPU 2', 3, 10),
+    details_list = [
+        device_details('GPU 1', (6, 0)),
+        device_details('GPU 2', (3, 10)),
     ]
     regex = re.compile(
         r'.*compatibility check \(mixed_float16\): WARNING\n'
@@ -117,14 +102,13 @@
         r'  GPU 1, compute capability 6.0\n'
         r'  GPU 2, compute capability 3.10\n'
         r'See.*', flags=re.MULTILINE)
-    self._test_compat_check(device_attrs_list, True, regex)
+    self._test_compat_check(details_list, True, regex)
 
-    device_attrs_list = [
-        _get_device_attrs('CPU', 'CPU'),
-        _get_device_attrs('GPU', 'GPU 1', 6, 0),
-        _get_device_attrs('GPU', 'GPU 1', 6, 0),
-        _get_device_attrs('GPU', 'GPU 1', 6, 0),
-        _get_device_attrs('GPU', 'GPU 2', 3, 10),
+    details_list = [
+        device_details('GPU 1', (6, 0)),
+        device_details('GPU 1', (6, 0)),
+        device_details('GPU 1', (6, 0)),
+        device_details('GPU 2', (3, 10)),
     ]
     regex = re.compile(
         r'.*compatibility check \(mixed_float16\): WARNING\n'
@@ -132,20 +116,20 @@
         r'  GPU 1, compute capability 6.0 \(x3\)\n'
         r'  GPU 2, compute capability 3.10\n'
         r'See.*', flags=re.MULTILINE)
-    self._test_compat_check(device_attrs_list, True, regex)
+    self._test_compat_check(details_list, True, regex)
 
-    device_attrs_list = [_get_device_attrs('CPU')]
+    details_list = []
     regex = re.compile(
         r'.*compatibility check \(mixed_float16\): WARNING\n'
         r'The dtype policy mixed_float16 may run slowly because this machine '
         r'does not have a GPU', flags=re.MULTILINE)
-    self._test_compat_check(device_attrs_list, True, regex)
+    self._test_compat_check(details_list, True, regex)
 
   def test_mix_of_supported_and_unsupported(self):
-    device_attrs_list = [
-        _get_device_attrs('GPU', 'GPU 1', 7, 0),
-        _get_device_attrs('GPU', 'GPU 1', 7, 0),
-        _get_device_attrs('GPU', 'GPU 2', 6, 0)
+    details_list = [
+        device_details('GPU 1', (7, 0)),
+        device_details('GPU 1', (7, 0)),
+        device_details('GPU 2', (6, 0))
     ]
     regex = re.compile(
         r'.*compatibility check \(mixed_float16\): WARNING\n'
@@ -153,9 +137,7 @@
         r'  GPU 1, compute capability 7.0 \(x2\)\n'
         r'  GPU 2, compute capability 6.0\n'
         r'See.*', flags=re.MULTILINE)
-    self._test_compat_check(device_attrs_list, True, regex)
-    device_attrs_list.append(_get_device_attrs('CPU'))
-    self._test_compat_check(device_attrs_list, True, regex)
+    self._test_compat_check(details_list, True, regex)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py
index 0b809e6..592057f 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/policy.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py
@@ -337,8 +337,7 @@
     self._loss_scale = keras_loss_scale_module.get(loss_scale)
 
     if name in ('mixed_float16', 'mixed_bloat16'):
-      device_compatibility_check.log_device_compatibility_check(name,
-                                                                skip_local=True)
+      device_compatibility_check.log_device_compatibility_check(name)
 
   def _parse_name(self, name):
     """Parses a Policy name into a compute and variable dtype.