# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote execution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib


class SingleWorkerTest(test.TestCase):
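  """Tests remote eager execution against a single remote worker."""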

  def setUp(self):
    super(SingleWorkerTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)
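
  # Places one operand on the local host and the other on the remote worker,
  # then checks that the sum is fetched back correctly.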
  def testMultiDeviceFunctionBasic(self):

    @def_function.function
    def basic(i):
      with ops.device('/job:localhost/replica:0/task:0/cpu:0'):
        a = constant_op.constant([2]) + i
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        b = constant_op.constant([1])
      return a + b

    self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5])
    self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4])
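
  # Reads a variable that lives on the remote worker from inside a
  # tf.function called on the local host.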
  def testMultiDeviceFunctionVariable(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def with_variable(i):
      return i + variable_b

    self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])
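
  # Returning a tensor that lives on a remote device is not supported;
  # the call should fail with an UnimplementedError.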
  def testMultiDeviceFunctionRemoteOutput(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      return variable_b, i + variable_b

    with self.assertRaises(errors.UnimplementedError) as cm:
      remote_output(constant_op.constant([1]))
    self.assertIn(
        'Currently, outputting tensors on remote devices is not supported.',
        cm.exception.message)
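
  # A bare 'cpu:0' inside the function does not pin the output to a single
  # task, so placement is ambiguous and should be rejected.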
  def testMultiDeviceFunctionAmbiguousDevice(self):
    self.skipTest('b/139212497')

    @def_function.function
    def ambiguous_device(i):
      with ops.device('cpu:0'):
        return i + constant_op.constant([2])

    with self.assertRaises(ValueError) as cm:
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        ambiguous_device(constant_op.constant([2])).numpy()
    self.assertIn('the output node must match exactly one device',
                  cm.exception.message)

  def testStreaming(self):
    """A mini stress test for streaming - issuing many RPCs back to back."""
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 2])
      y = array_ops.zeros([2, 2])
      num_iters = 200
      for _ in range(num_iters):
        y = x + y
        # Ask for y's shape after every 10 additions on average.
        # This exercises waiting for remote shape logic in TensorHandle.
        if random.randint(1, 10) == 1:
          _ = y.shape
    np.testing.assert_array_equal(
        [[num_iters, num_iters], [num_iters, num_iters]], y.numpy())


class MultiWorkersTest(test.TestCase):
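  """Tests that place variables and ops across multiple remote workers."""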

  def setUp(self):
    super(MultiWorkersTest, self).setUp()

    workers, _ = test_util.create_local_cluster(3, 0)
    remote.connect_to_remote_host(
        [workers[0].target, workers[1].target, workers[2].target])
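
  # Calls a function locally that reads a variable on one worker and runs an
  # op pinned to another worker.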
  def testMultiDeviceFunctionOnLocalDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
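
  # Same as above, but the function itself is invoked on a remote worker,
  # under both mirroring policies (and on GPU when one is available).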
  def testMultiDeviceFunctionOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    context.context().mirroring_policy = context.MIRRORING_NONE

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    context.context().mirroring_policy = context.MIRRORING_ALL

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
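
  # Runs a while loop whose body touches a variable and a device on two
  # different workers.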
  def testMultiDeviceWhileLoopOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):

      def body(i, _):
        with ops.device('/job:worker/replica:0/task:0'):
          a = i + variable_b
        return a + 1.0, 1

      return control_flow_ops.while_loop_v2(lambda _, d: d < 1, body, [i, 0])[0]

    context.context().mirroring_policy = context.MIRRORING_NONE

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    context.context().mirroring_policy = context.MIRRORING_ALL

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])
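
  # Treats one worker as a parameter server: variables live on task 2 and are
  # updated by functions invoked on tasks 0 and 1.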
  def testSimpleParameterServer(self):
    with ops.device('/job:worker/task:2/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)
    with ops.device('/job:worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)


_GRPC_PREFIX = 'grpc://'


class MultiJobsTest(test.TestCase):
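  """Tests for a cluster with separate worker and parameter-server jobs."""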

  def setUp(self):
    super(MultiJobsTest, self).setUp()

    workers, ps = test_util.create_local_cluster(2, 1)
    cluster = {
        'my_worker': [
            _strip_prefix(workers[0].target, _GRPC_PREFIX),
            _strip_prefix(workers[1].target, _GRPC_PREFIX),
        ],
        'my_ps': [_strip_prefix(ps[0].target, _GRPC_PREFIX)],
    }
    remote.connect_to_cluster(server_lib.ClusterSpec(cluster))
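
  # Variables live on the 'my_ps' job and are updated by functions invoked on
  # each 'my_worker' task.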
  def testSimpleParameterServer(self):
    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)


def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s


if __name__ == '__main__':
  test.main()