Reenable full mixed precision warning.
The warning was previously disabled since it called `list_local_devices`, which does not work if `set_visible_devices` is also called. The warning now uses `tf.config.experimental.get_device_details()` instead.
PiperOrigin-RevId: 316141769
Change-Id: I9c77f8c0a2e13f793f4124960d5da3d16af682dd
diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD
index 4b52b44..ec89fa0 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/BUILD
+++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD
@@ -80,9 +80,7 @@
srcs = ["device_compatibility_check.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:device_lib",
- "//tensorflow/python:gpu_util",
- "//tensorflow/python/eager:context",
+ "//tensorflow/python:config",
],
)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py
index 0b28661..5a759d3 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check.py
@@ -20,10 +20,7 @@
import itertools
-from tensorflow.python.client import device_lib
-from tensorflow.python.eager import context
from tensorflow.python.framework import config
-from tensorflow.python.framework import gpu_util
from tensorflow.python.platform import tf_logging
@@ -61,14 +58,16 @@
return new_device_strs
-def _log_device_compatibility_check(policy_name, device_attr_list):
+def _log_device_compatibility_check(policy_name, gpu_details_list):
"""Logs a compatibility check if the devices support the policy.
Currently only logs for the policy mixed_float16.
Args:
policy_name: The name of the dtype policy.
- device_attr_list: A list of DeviceAttributes.
+ gpu_details_list: A list of dicts, one dict per GPU. Each dict
+ is the device details for a GPU, as returned by
+ `tf.config.experimental.get_device_details()`.
"""
if policy_name != 'mixed_float16':
# TODO(b/145686977): Log if the policy is 'mixed_bfloat16'. This requires
@@ -76,19 +75,18 @@
return
supported_device_strs = []
unsupported_device_strs = []
- for device in device_attr_list:
- if device.device_type == 'GPU':
- name, cc = gpu_util.compute_capability_from_device_desc(device)
- name = name or 'Unknown GPU'
- if cc:
- device_str = '%s, compute capability %s.%s' % (name, cc[0], cc[1])
- if cc >= (7, 0):
- supported_device_strs.append(device_str)
- else:
- unsupported_device_strs.append(device_str)
+ for details in gpu_details_list:
+ name = details.get('device_name', 'Unknown GPU')
+ cc = details.get('compute_capability')
+ if cc:
+ device_str = '%s, compute capability %s.%s' % (name, cc[0], cc[1])
+ if cc >= (7, 0):
+ supported_device_strs.append(device_str)
else:
- unsupported_device_strs.append(
- name + ', no compute capability (probably not an Nvidia GPU)')
+ unsupported_device_strs.append(device_str)
+ else:
+ unsupported_device_strs.append(
+ name + ', no compute capability (probably not an Nvidia GPU)')
if unsupported_device_strs:
warning_str = _COMPAT_CHECK_WARNING_PREFIX + '\n'
@@ -134,7 +132,7 @@
_logged_compatibility_check = False
-def log_device_compatibility_check(policy_name, skip_local):
+def log_device_compatibility_check(policy_name):
"""Logs a compatibility check if the devices support the policy.
Currently only logs for the policy mixed_float16. A log is shown only the
@@ -142,29 +140,11 @@
Args:
policy_name: The name of the dtype policy.
- skip_local: If True, do not call list_local_devices(). This is useful since
- if list_local_devices() and tf.config.set_visible_devices() are both
- called, TensorFlow will crash. However, since GPU names and compute
- capabilities cannot be checked without list_local_devices(), setting this
- to True means the function will only warn if there are no GPUs.
"""
global _logged_compatibility_check
- # In graph mode, calling list_local_devices may initialize some session state,
- # so we only call it in eager mode.
- if not context.executing_eagerly() or _logged_compatibility_check:
+ if _logged_compatibility_check:
return
_logged_compatibility_check = True
- if not skip_local:
- device_attr_list = device_lib.list_local_devices()
- _log_device_compatibility_check(policy_name, device_attr_list)
- return
-
- # TODO(b/146009447): Create an API to replace list_local_devices(), then
- # remove the skip_local paramater.
gpus = config.list_physical_devices('GPU')
- if not gpus and policy_name == 'mixed_float16':
- tf_logging.warn(
- '%s\n'
- 'The dtype policy mixed_float16 may run slowly because '
- 'this machine does not have a GPU.\n%s' %
- (_COMPAT_CHECK_WARNING_PREFIX, _COMPAT_CHECK_WARNING_SUFFIX))
+ gpu_details_list = [config.get_device_details(g) for g in gpus]
+ _log_device_compatibility_check(policy_name, gpu_details_list)
diff --git a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py
index a88594a..33d5b9d 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/device_compatibility_check_test.py
@@ -20,26 +20,19 @@
import re
-from tensorflow.core.framework import device_attributes_pb2
from tensorflow.python.keras import combinations
from tensorflow.python.keras.mixed_precision.experimental import device_compatibility_check
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
-def _get_device_attrs(device_type, device_name=None, cc_major=None,
- cc_minor=None):
- if device_type == 'CPU':
- return device_attributes_pb2.DeviceAttributes(device_type='CPU')
- assert device_type == 'GPU', 'Invalid device type: %s' % (device_type,)
- if not device_name:
- return device_attributes_pb2.DeviceAttributes(device_type='GPU')
- physical_device_desc = (
- 'device: 0, name: %s, pci bus id: 0:0:0.0' % device_name)
- if cc_major:
- physical_device_desc += ', compute capability: %d.%d' % (cc_major, cc_minor)
- return device_attributes_pb2.DeviceAttributes(
- device_type='GPU', physical_device_desc=physical_device_desc)
+def device_details(device_name, compute_capability=None):
+ details = {}
+ if device_name:
+ details['device_name'] = device_name
+ if compute_capability:
+ details['compute_capability'] = compute_capability
+ return details
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
@@ -59,57 +52,49 @@
mock_warn.assert_not_called()
def test_supported(self):
- device_attrs_list = [_get_device_attrs('GPU', 'GPU 1', 7, 1)]
+ details_list = [device_details('GPU 1', (7, 1))]
regex = re.compile(
r'.*compatibility check \(mixed_float16\): OK\n'
r'Your GPU will likely run quickly with dtype policy mixed_float16 as '
r'it has compute capability of at least 7.0. Your GPU: GPU 1, compute '
r'capability 7.1', flags=re.MULTILINE)
- self._test_compat_check(device_attrs_list, False, regex)
- device_attrs_list.append(_get_device_attrs('CPU'))
- self._test_compat_check(device_attrs_list, False, regex)
+ self._test_compat_check(details_list, False, regex)
- device_attrs_list = [
- _get_device_attrs('CPU', 'CPU'),
- _get_device_attrs('GPU', 'GPU 1', 7, 0),
- _get_device_attrs('GPU', 'GPU 2', 7, 1),
- _get_device_attrs('GPU', 'GPU 3', 8, 0),
+ details_list = [
+ device_details('GPU 1', (7, 0)),
+ device_details('GPU 2', (7, 1)),
+ device_details('GPU 3', (8, 0)),
]
regex = re.compile(
r'.*compatibility check \(mixed_float16\): OK\n'
r'Your GPUs will likely run quickly with dtype policy mixed_float16 as '
r'they all have compute capability of at least 7.0', flags=re.MULTILINE)
- self._test_compat_check(device_attrs_list, False, regex)
+ self._test_compat_check(details_list, False, regex)
def test_unsupported(self):
- device_attrs_list = [
- _get_device_attrs('GPU', 'GPU 1', 6, 0)
+ details_list = [
+ device_details('GPU 1', (6, 0))
]
regex = re.compile(
r'.*compatibility check \(mixed_float16\): WARNING\n'
r'Your GPU may run slowly with dtype policy mixed_float16.*\n'
r' GPU 1, compute capability 6.0\n'
r'See.*', flags=re.MULTILINE)
- self._test_compat_check(device_attrs_list, True, regex)
- device_attrs_list.append(_get_device_attrs('CPU'))
- self._test_compat_check(device_attrs_list, True, regex)
+ self._test_compat_check(details_list, True, regex)
- device_attrs_list = [
- _get_device_attrs('GPU')
+ details_list = [
+ device_details(None)
]
regex = re.compile(
r'.*compatibility check \(mixed_float16\): WARNING\n'
r'Your GPU may run slowly with dtype policy mixed_float16.*\n'
r' Unknown GPU, no compute capability \(probably not an Nvidia GPU\)\n'
r'See.*', flags=re.MULTILINE)
- self._test_compat_check(device_attrs_list, True, regex)
- device_attrs_list.append(_get_device_attrs('CPU'))
- self._test_compat_check(device_attrs_list, True, regex)
+ self._test_compat_check(details_list, True, regex)
- device_attrs_list = [
- _get_device_attrs('CPU', 'CPU'),
- _get_device_attrs('GPU', 'GPU 1', 6, 0),
- _get_device_attrs('GPU', 'GPU 2', 3, 10),
+ details_list = [
+ device_details('GPU 1', (6, 0)),
+ device_details('GPU 2', (3, 10)),
]
regex = re.compile(
r'.*compatibility check \(mixed_float16\): WARNING\n'
@@ -117,14 +102,13 @@
r' GPU 1, compute capability 6.0\n'
r' GPU 2, compute capability 3.10\n'
r'See.*', flags=re.MULTILINE)
- self._test_compat_check(device_attrs_list, True, regex)
+ self._test_compat_check(details_list, True, regex)
- device_attrs_list = [
- _get_device_attrs('CPU', 'CPU'),
- _get_device_attrs('GPU', 'GPU 1', 6, 0),
- _get_device_attrs('GPU', 'GPU 1', 6, 0),
- _get_device_attrs('GPU', 'GPU 1', 6, 0),
- _get_device_attrs('GPU', 'GPU 2', 3, 10),
+ details_list = [
+ device_details('GPU 1', (6, 0)),
+ device_details('GPU 1', (6, 0)),
+ device_details('GPU 1', (6, 0)),
+ device_details('GPU 2', (3, 10)),
]
regex = re.compile(
r'.*compatibility check \(mixed_float16\): WARNING\n'
@@ -132,20 +116,20 @@
r' GPU 1, compute capability 6.0 \(x3\)\n'
r' GPU 2, compute capability 3.10\n'
r'See.*', flags=re.MULTILINE)
- self._test_compat_check(device_attrs_list, True, regex)
+ self._test_compat_check(details_list, True, regex)
- device_attrs_list = [_get_device_attrs('CPU')]
+ details_list = []
regex = re.compile(
r'.*compatibility check \(mixed_float16\): WARNING\n'
r'The dtype policy mixed_float16 may run slowly because this machine '
r'does not have a GPU', flags=re.MULTILINE)
- self._test_compat_check(device_attrs_list, True, regex)
+ self._test_compat_check(details_list, True, regex)
def test_mix_of_supported_and_unsupported(self):
- device_attrs_list = [
- _get_device_attrs('GPU', 'GPU 1', 7, 0),
- _get_device_attrs('GPU', 'GPU 1', 7, 0),
- _get_device_attrs('GPU', 'GPU 2', 6, 0)
+ details_list = [
+ device_details('GPU 1', (7, 0)),
+ device_details('GPU 1', (7, 0)),
+ device_details('GPU 2', (6, 0))
]
regex = re.compile(
r'.*compatibility check \(mixed_float16\): WARNING\n'
@@ -153,9 +137,7 @@
r' GPU 1, compute capability 7.0 \(x2\)\n'
r' GPU 2, compute capability 6.0\n'
r'See.*', flags=re.MULTILINE)
- self._test_compat_check(device_attrs_list, True, regex)
- device_attrs_list.append(_get_device_attrs('CPU'))
- self._test_compat_check(device_attrs_list, True, regex)
+ self._test_compat_check(details_list, True, regex)
if __name__ == '__main__':
diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py
index 0b809e6..592057f 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/policy.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py
@@ -337,8 +337,7 @@
self._loss_scale = keras_loss_scale_module.get(loss_scale)
if name in ('mixed_float16', 'mixed_bloat16'):
- device_compatibility_check.log_device_compatibility_check(name,
- skip_local=True)
+ device_compatibility_check.log_device_compatibility_check(name)
def _parse_name(self, name):
"""Parses a Policy name into a compute and variable dtype.