| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for remote execution.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import random |
| |
| from absl.testing import parameterized |
| import numpy as np |
| import six |
| |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import remote |
| from tensorflow.python.eager import test |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.training import server_lib |
| from tensorflow.python.training.server_lib import ClusterSpec |
| |
| |
| class SingleWorkerTest(test.TestCase, parameterized.TestCase): |
| |
| def setUp(self): |
| super(SingleWorkerTest, self).setUp() |
| |
| workers, _ = test_util.create_local_cluster(1, 0) |
| remote.connect_to_remote_host(workers[0].target) |
| |
| def tearDown(self): |
| super(SingleWorkerTest, self).tearDown() |
| |
| # Clear the current device scope to avoid polluting other test cases. |
| ops.device(None).__enter__() |
| # Reset the context to avoid polluting other test cases. |
| context._reset_context() |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testMultiDeviceFunctionBasic(self): |
| |
| @def_function.function |
| def basic(i): |
| with ops.device('/job:localhost/replica:0/task:0/cpu:0'): |
| a = constant_op.constant([2]) + i |
| with ops.device('/job:worker/replica:0/task:0/cpu:0'): |
| b = constant_op.constant([1]) |
| |
| return a + b |
| |
| self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5]) |
| self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4]) |
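
  # A minimal companion sketch (an assumption, not part of the original
  # suite): a tensor produced under a remote device scope can be fetched
  # back to the client by calling .numpy(), the same mechanism testStreaming
  # below relies on.
  @test_util.eager_lazy_remote_copy_on_and_off
  def testFetchRemoteTensorToClient(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      remote_tensor = constant_op.constant([1, 2]) + constant_op.constant(
          [3, 4])
    # Calling .numpy() copies the remote tensor to the client.
    self.assertAllEqual(remote_tensor.numpy(), [4, 6])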
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testMultiDeviceFunctionVariable(self): |
| with ops.device('/job:worker/replica:0/task:0/cpu:0'): |
| variable_b = variables.Variable(1) |
| |
| @def_function.function |
| def with_variable(i): |
| return i + variable_b |
| |
| self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3]) |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testMultiDeviceFunctionRemoteOutput(self): |
| with ops.device('/job:worker/replica:0/task:0/cpu:0'): |
| variable_b = variables.Variable(1) |
| |
| @def_function.function |
| def remote_output(i): |
| with ops.device('/job:worker/replica:0/task:0/cpu:0'): |
| c = variable_b + 1 |
| return c, i + variable_b |
| |
| self.assertAllEqual( |
| remote_output(constant_op.constant([1]))[0].numpy(), 2) |
| |
| # TODO(b/148235520): Re-enable this test. |
| def DISABLED_testMultiDeviceFunctionAmbiguousDevice(self): |
| |
| @def_function.function |
| def ambiguous_device(i): |
| with ops.device('cpu:0'): |
| return i + constant_op.constant([2]) |
| |
| with self.assertRaises(errors.InvalidArgumentError) as cm: |
| with ops.device('/job:worker/replica:0/task:0/cpu:0'): |
| ambiguous_device(constant_op.constant([2])).numpy() |
| |
| self.assertIn('the output node must match exactly one device', |
| cm.exception.message) |
| |
| def testStreaming(self): |
| """A mini stress test for streaming - issuing many RPCs back to back.""" |
    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
| x = array_ops.ones([2, 2]) |
| y = array_ops.zeros([2, 2]) |
| num_iters = 200 |
| for _ in range(num_iters): |
| y = x + y |
        # On average, ask for y's shape once every 10 additions. This
        # exercises the wait-for-remote-shape logic in TensorHandle.
| if random.randint(1, 10) == 1: |
| _ = y.shape |
| np.testing.assert_array_equal( |
| [[num_iters, num_iters], [num_iters, num_iters]], y.numpy()) |
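
  # A companion sketch (an assumption, not in the original suite): the same
  # kind of remote op chain, but synchronized explicitly with
  # context.async_wait() instead of relying on shape queries.
  def testStreamingExplicitSync(self):
    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
      y = array_ops.zeros([2, 2])
      for _ in range(10):
        y = y + array_ops.ones([2, 2])
    # Block until every pending remote operation has finished.
    context.async_wait()
    np.testing.assert_array_equal([[10.0, 10.0], [10.0, 10.0]], y.numpy())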
| |
| def testShapeError_OpByOp(self): |
    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
| x = array_ops.ones([2, 3]) |
| y = array_ops.zeros([2, 2]) |
| with self.assertRaises(errors.InvalidArgumentError) as cm: |
| math_ops.matmul(x, y) |
| |
| self.assertIn('Dimensions must be equal', cm.exception.message) |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testShapeError_Function(self): |
| |
| @def_function.function |
| def matmul_func(x, y): |
| return math_ops.matmul(x, y) |
| |
| x = array_ops.ones([2, 3]) |
| y = array_ops.zeros([2, 2]) |
| |
    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
| with self.assertRaises(ValueError) as cm: |
| matmul_func(x, y) |
| |
| if six.PY2: |
| self.assertIn('Dimensions must be equal', cm.exception.message) |
| else: |
| self.assertIn('Dimensions must be equal', cm.exception.args[0]) |
| |
| |
| class RemoteAsyncTest(test.TestCase): |
| |
| def setUp(self): |
| super(RemoteAsyncTest, self).setUp() |
| |
| workers, _ = test_util.create_local_cluster(1, 0) |
| remote.connect_to_remote_host(workers[0].target) |
| |
| def tearDown(self): |
| super(RemoteAsyncTest, self).tearDown() |
| |
| # Reset the context to avoid polluting other test cases. |
| context._reset_context() |
| |
| def test_out_of_range_with_while_loop(self): |
| |
| with ops.device('/job:worker/task:0'): |
| dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0]) |
| dataset = dataset.batch(1, drop_remainder=False) |
| iterator = iter(dataset) |
| v = variables.Variable(1.0) |
| |
| @def_function.function |
| def train_step(iterator): |
| i = next(iterator) |
| v.assign_add(math_ops.reduce_mean(i)) |
| |
| while True: |
| try: |
| with ops.device('/job:worker/task:0'): |
| train_step(iterator) |
| except (errors.OutOfRangeError, errors.InternalError): |
| context.async_clear_error() |
| break |
| |
| self.assertAllEqual(v.numpy(), 4.0) |
| |
| def test_out_of_range_with_for_loop(self): |
| |
| with ops.device('/job:worker/task:0'): |
| dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0]) |
| dataset = dataset.batch(1, drop_remainder=False) |
| iterator = iter(dataset) |
| v = variables.Variable(1.0) |
| |
| @def_function.function |
| def train_step(iterator): |
| i = next(iterator) |
| v.assign_add(math_ops.reduce_mean(i)) |
| |
| num_steps = 3 |
| for i in range(num_steps): |
| try: |
| with ops.device('/job:worker/task:0'): |
| train_step(iterator) |
| if i == num_steps - 1: |
| context.async_wait() |
| except errors.OutOfRangeError: |
| context.async_clear_error() |
| break |
| |
| self.assertAllEqual(v.numpy(), 4.0) |
| |
| def test_out_of_range_with_async_scope(self): |
| |
| with ops.device('/job:worker/task:0'): |
| dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0]) |
| dataset = dataset.batch(1, drop_remainder=False) |
| iterator = iter(dataset) |
| v = variables.Variable(1.0) |
| |
| @def_function.function |
| def train_step(iterator): |
| i = next(iterator) |
| v.assign_add(math_ops.reduce_mean(i)) |
| |
| num_steps = 3 |
| try: |
| with context.async_scope(): |
| for _ in range(num_steps): |
| with ops.device('/job:worker/task:0'): |
| train_step(iterator) |
| except errors.OutOfRangeError: |
| context.async_clear_error() |
| |
| self.assertAllEqual(v.numpy(), 4.0) |
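
  # An illustrative sketch (an assumption, not in the original suite): after
  # async_clear_error(), the context accepts new work, so remote ops can run
  # on the same worker again.
  def test_recover_after_async_clear_error(self):
    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0])
      iterator = iter(dataset)

    @def_function.function
    def consume(iterator):
      return next(iterator)

    try:
      with ops.device('/job:worker/task:0'):
        consume(iterator)  # Consumes the only element.
        consume(iterator)  # Raises OutOfRangeError.
    except errors.OutOfRangeError:
      context.async_clear_error()

    # The context accepts new work after the error has been cleared.
    with ops.device('/job:worker/task:0'):
      self.assertAllEqual((constant_op.constant(1.0) + 1.0).numpy(), 2.0)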
| |
| |
| class MultiWorkersTest(test.TestCase, parameterized.TestCase): |
| |
| def setUp(self): |
| super(MultiWorkersTest, self).setUp() |
| |
| workers, _ = test_util.create_local_cluster(3, 0) |
| remote.connect_to_remote_host( |
| [workers[0].target, workers[1].target, workers[2].target]) |
| |
| def tearDown(self): |
| super(MultiWorkersTest, self).tearDown() |
| |
| # Clear the current device scope to avoid polluting other test cases. |
| ops.device(None).__enter__() |
| # Reset the context to avoid polluting other test cases. |
| context._reset_context() |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testReturnRemoteArgument(self): |
| |
| @def_function.function |
| def local_func(i): |
| return i |
| |
| with ops.device('/job:worker/replica:0/task:0'): |
| x = constant_op.constant([2, 1]) |
| |
| with ops.device('/job:worker/replica:0/task:1'): |
| self.assertAllEqual(local_func(x), [2, 1]) |
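
  # A minimal companion sketch (an assumption, not in the original file):
  # an eager op can also consume inputs that live on two different workers;
  # the runtime copies them to the op's device implicitly.
  @test_util.eager_lazy_remote_copy_on_and_off
  def testImplicitCopyBetweenWorkers(self):
    with ops.device('/job:worker/replica:0/task:0'):
      a = constant_op.constant(1.0)
    with ops.device('/job:worker/replica:0/task:1'):
      b = constant_op.constant(2.0)

    with ops.device('/job:worker/replica:0/task:2'):
      self.assertAllEqual((a + b).numpy(), 3.0)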
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testMultiDeviceFunctionOnLocalDevice(self): |
| with ops.device('/job:worker/replica:0/task:1'): |
| variable_b = variables.Variable(1.0) |
| |
| @def_function.function |
| def remote_function(i): |
| with ops.device('/job:worker/replica:0/task:0'): |
| a = i + variable_b |
| c = a + 1.0 |
| return c |
| |
| self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testMultiDeviceFunctionOnRemoteDeviceWithWait(self): |
| with ops.device('/job:worker/replica:0/task:1'): |
| variable_b = variables.Variable([1.0]) |
| |
    @def_function.function
    def remote_function(i):
      # Extra computation intended to make this function run much slower
      # than remote_function2 below.
      x = array_ops.ones([1000, 1000])
      for _ in range(1, 1000):
        x = x * x
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a
| |
| @def_function.function |
| def remote_function2(i): |
| variable_b.assign_add(i) |
| a = 1.0 + variable_b |
| return a |
| |
    # Runs the first function:
    # - on a remote device
    # - with a remote input
    # - with side effects (it updates variable_b)
    # - much more slowly than the second function
| with ops.device('/job:worker/replica:0/task:0'): |
| remote_function(constant_op.constant([2.0])) |
| |
    # Runs the second function:
    # - on a different remote device
    # - with side effects
    # There should be a sync point here, so the second function executes only
    # after the first function has completed.
| with ops.device('/job:worker/replica:0/task:2'): |
| self.assertAllEqual(remote_function2(constant_op.constant([3.0])), [7.0]) |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testMultiDeviceFunctionOnRemoteDevice(self): |
| with ops.device('/job:worker/replica:0/task:1'): |
| variable_b = variables.Variable(1.0) |
| |
| @def_function.function |
| def remote_function(i): |
| with ops.device('/job:worker/replica:0/task:0'): |
| a = i + variable_b |
| c = a + 1.0 |
| return c |
| |
| context.context().mirroring_policy = context.MIRRORING_NONE |
| |
| with ops.device('/job:worker/replica:0/task:0'): |
| self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) |
| |
| if test_util.is_gpu_available(): |
| with ops.device('/job:worker/replica:0/task:0/device:GPU:0'): |
| self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) |
| |
| context.context().mirroring_policy = context.MIRRORING_ALL |
| |
| with ops.device('/job:worker/replica:0/task:0'): |
| self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) |
| |
| if test_util.is_gpu_available(): |
| with ops.device('/job:worker/replica:0/task:0/device:GPU:0'): |
| self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testMultiDeviceWhileLoopOnRemoteDevice(self): |
| with ops.device('/job:worker/replica:0/task:1'): |
| variable_b = variables.Variable(1.0) |
| |
| @def_function.function |
| def remote_function(i): |
| |
| def body(i, _): |
| with ops.device('/job:worker/replica:0/task:0'): |
| a = i + variable_b |
| return a + 1.0, 1 |
| |
      return control_flow_ops.while_loop_v2(
          lambda _, d: d < 1, body, [i, 0])[0]
| |
| context.context().mirroring_policy = context.MIRRORING_NONE |
| |
| with ops.device('/job:worker/replica:0/task:0'): |
| self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) |
| |
| if test_util.is_gpu_available(): |
| with ops.device('/job:worker/replica:0/task:0/device:GPU:0'): |
| self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) |
| |
| context.context().mirroring_policy = context.MIRRORING_ALL |
| |
| with ops.device('/job:worker/replica:0/task:0'): |
| self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) |
| |
| if test_util.is_gpu_available(): |
| with ops.device('/job:worker/replica:0/task:0/device:GPU:0'): |
| self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0]) |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testSimpleParameterServer(self): |
| |
| with ops.device('/job:worker/task:2/device:CPU:0'): |
| v1 = variables.Variable(initial_value=0) |
| v2 = variables.Variable(initial_value=10) |
| |
| @def_function.function |
| def worker_fn(): |
| v1.assign_add(1) |
| v2.assign_sub(2) |
| return v1.read_value() + v2.read_value() |
| |
| with ops.device('/job:worker/task:0/device:CPU:0'): |
| self.assertAllEqual(worker_fn(), 9) |
| |
| with ops.device('/job:worker/task:1/device:CPU:0'): |
| self.assertAllEqual(worker_fn(), 8) |
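
  # A companion sketch (an assumption, not in the original suite): a function
  # can also mix variables that live on two different workers; the runtime
  # routes each read and update to the owning task.
  @test_util.eager_lazy_remote_copy_on_and_off
  def testVariablesOnTwoWorkers(self):
    with ops.device('/job:worker/task:1/device:CPU:0'):
      v1 = variables.Variable(1.0)
    with ops.device('/job:worker/task:2/device:CPU:0'):
      v2 = variables.Variable(2.0)

    @def_function.function
    def reader():
      return v1 + v2

    with ops.device('/job:worker/task:0/device:CPU:0'):
      self.assertAllEqual(reader(), 3.0)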
| |
| |
| _GRPC_PREFIX = 'grpc://' |
| |
| |
| class MultiJobsTest(test.TestCase, parameterized.TestCase): |
| |
| def setUp(self): |
| super(MultiJobsTest, self).setUp() |
| |
| workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2) |
| cluster = { |
| 'my_worker': [_strip_prefix(t.target, _GRPC_PREFIX) for t in workers], |
| 'my_ps': [_strip_prefix(t.target, _GRPC_PREFIX) for t in ps], |
| } |
| self._cluster = server_lib.ClusterSpec(cluster) |
| self._cluster_resolver = SimpleClusterResolver( |
| cluster_spec=self._cluster, master=ps[0].target) |
| |
| def tearDown(self): |
| super(MultiJobsTest, self).tearDown() |
| |
| # Clear the current device scope to avoid polluting other test cases. |
| ops.device(None).__enter__() |
| # Reset the context to avoid polluting other test cases. |
| context._reset_context() |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testSimpleParameterServer(self): |
| remote.connect_to_cluster(self._cluster) |
| |
| with ops.device('/job:my_ps/task:0/device:CPU:0'): |
| v1 = variables.Variable(initial_value=0) |
| v2 = variables.Variable(initial_value=10) |
| |
| @def_function.function |
| def worker_fn(): |
| v1.assign_add(1) |
| v2.assign_sub(2) |
| return v1.read_value() + v2.read_value() |
| |
| with ops.device('/job:my_worker/task:0/device:CPU:0'): |
| self.assertAllEqual(worker_fn(), 9) |
| |
| with ops.device('/job:my_worker/task:1/device:CPU:0'): |
| self.assertAllEqual(worker_fn(), 8) |
| |
| # TODO(b/152224115): Re-enable this test. |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def DISABLED_testSimpleParameterServerWithDeviceFilters(self): |
| cluster_device_filters = server_lib.ClusterDeviceFilters() |
| for i in range(2): |
| cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps']) |
| cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker']) |
| remote.connect_to_cluster( |
| self._cluster, cluster_device_filters=cluster_device_filters) |
| |
| with ops.device('/job:my_ps/task:0/device:CPU:0'): |
| v1 = variables.Variable(initial_value=0) |
| with ops.device('/job:my_ps/task:1/device:CPU:0'): |
| v2 = variables.Variable(initial_value=10) |
| |
| @def_function.function |
| def worker_fn(): |
| v1.assign_add(1) |
| v2.assign_sub(2) |
| return v1.read_value() + v2.read_value() |
| |
| with ops.device('/job:my_worker/task:0/device:CPU:0'): |
| self.assertAllEqual(worker_fn(), 9) |
| with ops.device('/job:my_worker/task:1/device:CPU:0'): |
| self.assertAllEqual(worker_fn(), 8) |
| |
| # The following remote call would fail because the ps nodes cannot see each |
| # other due to the device filters. |
| with self.assertRaises(errors.InvalidArgumentError) as cm: |
| with ops.device('/job:my_ps/task:0/device:CPU:0'): |
| worker_fn().numpy() |
| self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device', |
| cm.exception.message) |
| |
| with self.assertRaises(errors.InvalidArgumentError) as cm: |
| with ops.device('/job:my_ps/task:1/device:CPU:0'): |
| worker_fn().numpy() |
| self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device', |
| cm.exception.message) |
| |
| with ops.device('/job:my_worker/task:0/device:CPU:0'): |
| self.assertAllEqual(worker_fn(), 7) |
| with ops.device('/job:my_worker/task:1/device:CPU:0'): |
| self.assertAllEqual(worker_fn(), 6) |
    # Explicitly delete the variables to avoid errors when they are
    # garbage-collected during subsequent tests.
    del v1, v2
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testConnectWithClusterResolver(self): |
| remote.connect_to_cluster(self._cluster_resolver) |
| |
| v1 = variables.Variable(initial_value=0) |
| v2 = variables.Variable(initial_value=10) |
| |
| @def_function.function |
| def worker_fn(): |
| v1.assign_add(1) |
| v2.assign_sub(2) |
| return v1.read_value() + v2.read_value() |
| |
| with ops.device('/job:my_worker/task:0/device:CPU:0'): |
| self.assertAllEqual(worker_fn(), 9) |
| |
| with ops.device('/job:my_worker/task:1/device:CPU:0'): |
| self.assertAllEqual(worker_fn(), 8) |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testConnectToClusterTwiceOk(self): |
| remote.connect_to_cluster(self._cluster_resolver) |
| remote.connect_to_cluster(self._cluster_resolver) |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testConnectToClusterOnMismatchedDevice(self): |
| remote.connect_to_cluster(self._cluster_resolver) |
| |
    # Enter another device scope.
| ops.device('/job:my_worker/task:0/device:CPU:0').__enter__() |
| |
| with self.assertRaises(ValueError): |
| remote.connect_to_cluster(self._cluster_resolver) |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testConnectToClusterWithLocalMaster(self): |
| local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local') |
| remote.connect_to_cluster(local_resolver) |
| |
| @test_util.eager_lazy_remote_copy_on_and_off |
| def testConnectToClusterInGraphModeWillFail(self): |
| ops.disable_eager_execution() |
| with self.assertRaises(ValueError): |
| remote.connect_to_cluster(self._cluster_resolver) |
| ops.enable_eager_execution() |
| |
| |
def _strip_prefix(s, prefix):
  """Strips `prefix` from `s` if present, e.g. 'grpc://host:0' -> 'host:0'."""
  return s[len(prefix):] if s.startswith(prefix) else s
| |
| |
| if __name__ == '__main__': |
| test.main() |