| # Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Test utilities.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import functools |
| |
| from absl import app |
| |
| from tensorflow.python.compat import v2_compat |
| from tensorflow.python.distribute import collective_all_reduce_strategy |
| from tensorflow.python.distribute import multi_process_runner |
| from tensorflow.python.distribute import values |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import config |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.util import nest |
| |
| |
def gather(strategy, value):
  """Collects `value` from every worker, leaf by leaf.

  A test-only stand-in until an official all-gather API exists.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure whose leaves are n-dim
      `tf.distribute.DistributedValue` of `tf.Tensor`, or plain `tf.Tensor`
      when the strategy only has one replica. `tf.sparse.SparseTensor` leaves
      are not supported.

  Returns:
    a (n+1)-dim `tf.Tensor` for each leaf; `nest.map_structure` is used, so the
    result is arranged like `value`.
  """

  def _gather_leaf(leaf):
    # Each leaf is handled independently by the single-value helper.
    return _gather(strategy, leaf)

  return nest.map_structure(_gather_leaf, value)
| |
| |
def _gather(strategy, value):
  """Gathers one (non-nested) value across the replicas of `strategy`.

  Args:
    strategy: the strategy whose `extended` object decides how to gather.
    value: a `values.DistributedValues`, or anything accepted by
      `ops.convert_to_tensor` (it is then wrapped in a one-element
      `values.PerReplica`).

  Returns:
    `array_ops.stack` of the per-replica components when `strategy.extended`
    is not a `CollectiveAllReduceExtended`; otherwise the result of
    `strategy.gather` over the components, each expanded with a new leading
    axis.
  """
  # pylint: disable=protected-access
  if isinstance(value, values.DistributedValues):
    per_replica = value
  else:
    # Plain values are treated as a single-replica PerReplica.
    per_replica = values.PerReplica([ops.convert_to_tensor(value)])

  uses_collectives = isinstance(
      strategy.extended,
      collective_all_reduce_strategy.CollectiveAllReduceExtended)
  components = per_replica._values
  if not uses_collectives:
    return array_ops.stack(components)

  # The collective path expects exactly one component per worker device.
  assert len(strategy.extended.worker_devices) == len(components)
  expanded = [
      array_ops.expand_dims_v2(component, axis=0) for component in components
  ]
  return strategy.gather(values.PerReplica(expanded), axis=0)
  # pylint: enable=protected-access
| |
| |
def set_logical_devices_to_at_least(device, num):
  """Ensures there are at least `num` logical devices of type `device`.

  Does nothing when there are already at least `num` physical devices of that
  type. Otherwise the last physical device is configured with enough logical
  devices to make up the difference.

  Args:
    device: device type string passed to `config.list_physical_devices`, e.g.
      "GPU" or "CPU".
    num: the minimum number of logical devices wanted; must be at least 1.

  Raises:
    ValueError: if `num` is less than 1.
    RuntimeError: if no physical device of type `device` is found.
  """
  if num < 1:
    raise ValueError("`num` must be at least 1 not %r" % (num,))

  physical_devices = config.list_physical_devices(device)
  if not physical_devices:
    raise RuntimeError("No {} found".format(device))

  shortfall = num - len(physical_devices)
  if shortfall <= 0:
    return

  # By default each physical device corresponds to one logical device, so the
  # last physical device has to host the missing ones in addition to its own.
  count_on_last = shortfall + 1

  # Logical GPUs get an explicit memory limit; other device types use the
  # default configuration.
  config_kwargs = {"memory_limit": 2048} if device.upper() == "GPU" else {}
  logical_devices = [
      context.LogicalDeviceConfiguration(**config_kwargs)
      for _ in range(count_on_last)
  ]

  # Split the last device rather than the first, since sometimes the first GPU
  # is the primary graphic card and may have less memory available.
  config.set_logical_device_configuration(physical_devices[-1], logical_devices)
| |
| |
def _set_logical_devices():
  """Requests at least two logical devices for each present GPU/CPU type."""
  # GPU is handled before CPU; a device type with no physical devices is
  # skipped instead of raising.
  for device_type in ("GPU", "CPU"):
    if config.list_physical_devices(device_type):
      set_logical_devices_to_at_least(device_type, 2)
| |
| |
def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one entry point for tf.distribute tests.

  Args:
    enable_v2_behavior: if true, calls `v2_compat.enable_v2_behavior()`;
      otherwise calls `v2_compat.disable_v2_behavior()`.
    config_logical_devices: if true, registers `_set_logical_devices` with
      `app.call_after_init`.
  """
  if config_logical_devices:
    app.call_after_init(_set_logical_devices)

  toggle_v2 = (
      v2_compat.enable_v2_behavior
      if enable_v2_behavior else v2_compat.disable_v2_behavior)
  toggle_v2()

  # TODO(b/131360402): configure default logical devices.
  multi_process_runner.test_main()