blob: cb4d785e6cddd18f95086441cd424adb4bf7234a [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUClusterResolver."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import six
from six.moves.urllib.error import URLError
from tensorflow.python import eager
from tensorflow.python.client import session
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
mock = test.mock
class MockRequestClass(object):
def __init__(self, name, tpu_map):
self._name = name
self._tpu_map = tpu_map
def execute(self):
if self._name in self._tpu_map:
return self._tpu_map[self._name]
else:
raise KeyError('Resource %s was not found' % self._name)
class MockNodeClass(object):
def __init__(self, tpu_map):
self._tpu_map = tpu_map
def get(self, name):
return MockRequestClass(name, self._tpu_map)
def mock_request_compute_metadata(cls, *args, **kwargs):
del cls, kwargs # Unused.
if args[0] == 'project/project-id':
return 'test-project'
elif args[0] == 'instance/zone':
return 'projects/test-project/locations/us-central1-c'
elif args[0] == 'instance/network-interfaces/0/ip':
return '10.128.1.2'
return ''
def mock_is_running_in_gce():
return True
def mock_is_not_running_in_gce():
return False
def mock_running_in_gce_urlopen(cls, *args, **kwargs):
del cls, args, kwargs # Unused.
mock_response = mock.MagicMock()
mock_response.info.return_value = {'Metadata-Flavor': 'Google'}
return mock_response
def mock_not_running_in_gce_urlopen(cls, *args, **kwargs):
del cls, args, kwargs # Unused.
raise URLError(reason='Host does not exist.')
@test_util.run_all_in_graph_and_eager_modes
class TPUClusterResolverTest(test.TestCase):
def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
"""Verifies that the ClusterSpec generates the correct proto.
We are testing this four different ways to ensure that the ClusterSpec
returned by the TPUClusterResolver behaves identically to a normal
ClusterSpec when passed into the generic ClusterSpec libraries.
Args:
cluster_spec: ClusterSpec returned by the TPUClusterResolver
expected_proto: Expected protobuf
"""
self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
self.assertProtoEquals(
expected_proto,
server_lib.ClusterSpec(cluster_spec).as_cluster_def())
self.assertProtoEquals(expected_proto,
server_lib.ClusterSpec(
cluster_spec.as_cluster_def()).as_cluster_def())
self.assertProtoEquals(expected_proto,
server_lib.ClusterSpec(
cluster_spec.as_dict()).as_cluster_def())
def mock_service_client(self, tpu_map=None):
if tpu_map is None:
tpu_map = {}
mock_locations = mock.MagicMock()
mock_locations.nodes.return_value = MockNodeClass(tpu_map)
mock_project = mock.MagicMock()
mock_project.locations.return_value = mock_locations
mock_client = mock.MagicMock()
mock_client.projects.return_value = mock_project
return mock_client
@mock.patch.object(resolver, 'is_running_in_gce',
mock_is_running_in_gce)
def testCheckRunningInGceWithNoTpuName(self):
with self.assertRaisesRegexp(RuntimeError, '.*Google Cloud.*'):
resolver.TPUClusterResolver(tpu='')
@mock.patch.object(six.moves.urllib.request,
'urlopen',
mock_running_in_gce_urlopen)
def testIsRunningInGce(self):
self.assertTrue(resolver.is_running_in_gce())
@mock.patch.object(six.moves.urllib.request,
'urlopen',
mock_not_running_in_gce_urlopen)
def testIsNotRunningInGce(self):
self.assertFalse(resolver.is_running_in_gce())
@mock.patch.object(resolver.TPUClusterResolver,
'_request_compute_metadata', mock_request_compute_metadata)
def testRetrieveProjectAndZoneFromMetadata(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
'port': '8470',
'health': 'HEALTHY'
}
}
cluster_resolver = resolver.TPUClusterResolver(
project=None,
zone=None,
tpu=['test-tpu-1'],
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map),
coordinator_name='coordinator')
actual_cluster_spec = cluster_resolver.cluster_spec()
expected_proto = """
job {
name: 'coordinator'
tasks { key: 0 value: '10.128.1.2:%s' }
}
job {
name: 'worker'
tasks { key: 0 value: '10.1.2.3:8470' }
}
""" % cluster_resolver._coordinator_port
self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
@mock.patch.object(resolver.TPUClusterResolver,
'_request_compute_metadata', mock_request_compute_metadata)
def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
'port': '8470',
'health': 'HEALTHY'
}
}
cluster_resolver = resolver.TPUClusterResolver(
project=None,
zone=None,
tpu=['test-tpu-1'],
coordinator_name=None,
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
actual_cluster_spec = cluster_resolver.cluster_spec()
expected_proto = """
job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
@mock.patch.object(resolver.TPUClusterResolver,
'_request_compute_metadata', mock_request_compute_metadata)
def testNotReadyCloudTpu(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
'port': '8470',
'state': 'CREATING'
}
}
cluster_resolver = resolver.TPUClusterResolver(
project=None,
zone=None,
tpu='test-tpu-1',
coordinator_name=None,
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
with self.assertRaises(RuntimeError):
cluster_resolver.cluster_spec()
def testSimpleSuccessfulRetrieval(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
'port': '8470',
'health': 'HEALTHY'
}
}
cluster_resolver = resolver.TPUClusterResolver(
project='test-project',
zone='us-central1-c',
tpu=['test-tpu-1'],
coordinator_name='coordinator',
coordinator_address='10.128.1.5:10203',
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
actual_cluster_spec = cluster_resolver.cluster_spec()
expected_proto = """
job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
def testNewNetworkEndpointFormat(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'health': 'HEALTHY',
'networkEndpoints': [{
'ipAddress': '10.2.3.4',
'port': 8470,
}]
}
}
cluster_resolver = resolver.TPUClusterResolver(
project='test-project',
zone='us-central1-c',
tpu='test-tpu-1',
coordinator_name='coordinator',
coordinator_address='10.128.1.5:10203',
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
actual_cluster_spec = cluster_resolver.cluster_spec()
expected_proto = """
job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } }
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())
@mock.patch.object(resolver.TPUClusterResolver,
'_request_compute_metadata', mock_request_compute_metadata)
def testPodResolution(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'health':
'HEALTHY',
'networkEndpoints': [
{
'ipAddress': '10.2.3.4',
'port': 8470,
},
{
'ipAddress': '10.2.3.5',
'port': 8470,
},
{
'ipAddress': '10.2.3.6',
'port': 8470,
},
{
'ipAddress': '10.2.3.7',
'port': 8470,
},
]
}
}
cluster_resolver = resolver.TPUClusterResolver(
tpu='test-tpu-1',
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map),
coordinator_name='coordinator')
actual_cluster_spec = cluster_resolver.cluster_spec()
expected_proto = """
job {
name: 'coordinator',
tasks { key: 0 value: '10.128.1.2:%s'}
}
job {
name: 'worker'
tasks { key: 0 value: '10.2.3.4:8470' }
tasks { key: 1 value: '10.2.3.5:8470' }
tasks { key: 2 value: '10.2.3.6:8470' }
tasks { key: 3 value: '10.2.3.7:8470' }
}
""" % cluster_resolver._coordinator_port
self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.4:8470')
def testPodResolutionNoCoordinator(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'health':
'HEALTHY',
'networkEndpoints': [
{
'ipAddress': '10.2.3.4',
'port': 8470,
},
{
'ipAddress': '10.2.3.5',
'port': 8470,
},
{
'ipAddress': '10.2.3.6',
'port': 8470,
},
{
'ipAddress': '10.2.3.7',
'port': 8470,
},
]
}
}
cluster_resolver = resolver.TPUClusterResolver(
project='test-project',
zone='us-central1-c',
tpu='test-tpu-1',
coordinator_name=None,
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
actual_cluster_spec = cluster_resolver.cluster_spec()
expected_proto = """
job {
name: 'worker'
tasks { key: 0 value: '10.2.3.4:8470' }
tasks { key: 1 value: '10.2.3.5:8470' }
tasks { key: 2 value: '10.2.3.6:8470' }
tasks { key: 3 value: '10.2.3.7:8470' }
}
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.4:8470')
def testGetMasterNoEntries(self):
tpu_map = {}
with self.assertRaises(ValueError):
resolver.TPUClusterResolver(
project='test-project',
zone='us-central1-c',
tpu=[],
coordinator_name=None,
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
# TODO(saeta): Convert to parameterized test when included in OSS TF.
def verifyShouldResolve(self, tpu, should_resolve):
cluster_resolver = resolver.TPUClusterResolver(
project='test-project',
zone='us-central1-c',
tpu=tpu,
coordinator_name=None,
credentials=None,
service=self.mock_service_client(tpu_map={}))
self.assertEqual(should_resolve, cluster_resolver._should_resolve(),
"TPU: '%s'" % tpu)
@mock.patch.object(resolver, 'is_running_in_gce',
mock_is_not_running_in_gce)
def testShouldResolveNoName(self):
self.verifyShouldResolve('', False)
def testShouldResolveLocal(self):
self.verifyShouldResolve('local', False)
def testShouldResolveLocalhost(self):
self.verifyShouldResolve('localhost:12345', False)
def testShouldResolveGrpc(self):
self.verifyShouldResolve('grpc://10.1.2.3:8470', False)
def testShouldResolveBns(self):
self.verifyShouldResolve('/bns/foo/bar', False)
def testShouldResolveName(self):
self.verifyShouldResolve('mytpu', True)
def testShouldResolveList(self):
self.verifyShouldResolve(['myothertpu'], True)
def testShouldResolveGrpcPrefix(self):
self.verifyShouldResolve('grpctpu', True)
def testNoCallComputeMetadata(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/foo/bar')
self.assertEqual('/bns/foo/bar', cluster_resolver.master())
if ops.executing_eagerly_outside_functions():
self.assertEqual(
server_lib.ClusterSpec({
'worker': ['/bns/foo/bar']
}).as_dict(),
cluster_resolver.cluster_spec().as_dict())
else:
self.assertEqual(None, cluster_resolver.cluster_spec())
def testLocalhostMaster(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='localhost:12345')
self.assertEqual('localhost:12345', cluster_resolver.master())
def testGkeEnvironmentForDonut(self):
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
self.assertTrue(resolver.TPUClusterResolver._in_gke())
self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470'),
compat.as_bytes(
resolver.TPUClusterResolver._gke_endpoints()))
cluster_resolver = resolver.TPUClusterResolver()
self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470'),
compat.as_bytes(cluster_resolver.master()))
actual_cluster_spec = cluster_resolver.cluster_spec()
expected_proto = """
job {
name: 'worker'
tasks { key: 0 value: '10.120.27.5:8470' }
}
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
def testGkeEnvironmentForPod(self):
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,'
'grpc://10.120.27.6:8470,'
'grpc://10.120.27.7:8470,'
'grpc://10.120.27.8:8470')
self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
self.assertTrue(resolver.TPUClusterResolver._in_gke())
self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470,'
'grpc://10.120.27.6:8470,'
'grpc://10.120.27.7:8470,'
'grpc://10.120.27.8:8470'),
compat.as_bytes(
resolver.TPUClusterResolver._gke_endpoints()))
cluster_resolver = resolver.TPUClusterResolver()
self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470'),
compat.as_bytes(cluster_resolver.master()))
actual_cluster_spec = cluster_resolver.cluster_spec()
expected_proto = """
job {
name: 'worker'
tasks { key: 0 value: '10.120.27.5:8470' }
tasks { key: 1 value: '10.120.27.6:8470' }
tasks { key: 2 value: '10.120.27.7:8470' }
tasks { key: 3 value: '10.120.27.8:8470' }
}
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
def testEnvironmentDiscoveryUrl(self):
os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
self.assertEqual(
'https://{api}.internal/{apiVersion}',
(resolver.TPUClusterResolver._environment_discovery_url()))
def testEnvironmentAndRpcDetectionForGoogle(self):
cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef')
self.assertEqual(cluster_resolver.environment, 'google')
self.assertEqual(cluster_resolver.rpc_layer, None)
def testEnvironmentAndRpcDetectionForGrpcString(self):
cluster_resolver = resolver.TPUClusterResolver(
tpu='grpc://10.1.2.3:8470')
self.assertEqual(cluster_resolver.environment, '')
self.assertEqual(cluster_resolver.rpc_layer, 'grpc')
self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
def testOverrideTaskTypeAndIndexAndGetMaster(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'health':
'HEALTHY',
'networkEndpoints': [
{
'ipAddress': '10.2.3.4',
'port': 8470,
},
{
'ipAddress': '10.2.3.5',
'port': 8470,
},
{
'ipAddress': '10.2.3.6',
'port': 8470,
},
{
'ipAddress': '10.2.3.7',
'port': 8470,
},
]
}
}
cluster_resolver = resolver.TPUClusterResolver(
project='test-project',
zone='us-central1-c',
tpu='test-tpu-1',
coordinator_name=None,
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.4:8470')
cluster_resolver.task_type = 'worker'
cluster_resolver.task_id = 3
self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.7:8470')
self.assertEqual(
cluster_resolver.master(
task_type='worker', task_id=2, rpc_layer='test'),
'test://10.2.3.6:8470')
def testGetDeviceDictAndCoresWithTPUs(self):
device_names = [
'/job:tpu_worker/task:0/device:TPU:0',
'/job:tpu_worker/task:1/device:TPU:1',
'/job:tpu_worker/task:2/device:TPU:0',
'/job:tpu_worker/task:3/device:TPU:1',
'/job:tpu_worker/task:0/device:TPU:4',
'/job:tpu_worker/task:1/device:TPU:5',
'/job:tpu_worker/task:2/device:TPU:4',
'/job:tpu_worker/task:3/device:TPU:5',
]
device_list = [
session._DeviceAttributes(
name, 'TPU', 1024, 0) for name in device_names
]
device_details = resolver.TPUClusterResolver._get_device_dict_and_cores(
device_list)
self.assertEqual(device_details.total_cores, 8)
self.assertEqual(device_details.device_map,
{'0': ['0', '4'],
'1': ['1', '5'],
'2': ['0', '4'],
'3': ['1', '5']})
def testGetDeviceDictAndCoresWithCPUsAndGPUs(self):
device_names = [
'/job:tpu_worker/task:0/device:CPU:0',
'/job:tpu_worker/task:1/device:CPU:0',
'/job:tpu_worker/task:2/device:CPU:0',
'/job:tpu_worker/task:3/device:CPU:0',
'/job:tpu_worker/task:0/device:GPU:1',
'/job:tpu_worker/task:1/device:GPU:1',
'/job:tpu_worker/task:2/device:GPU:1',
'/job:tpu_worker/task:3/device:GPU:1',
]
device_list = [
session._DeviceAttributes(
name, 'XLA', 1024, 0) for name in device_names
]
device_dict, num_cores =\
resolver.TPUClusterResolver._get_device_dict_and_cores(device_list)
self.assertEqual(num_cores, 0)
self.assertEqual(device_dict, {})
def testVerifySameCoreCount(self):
self.assertEqual(
resolver.TPUClusterResolver
._verify_and_return_same_core_count({0: [0, 1, 2, 3, 4, 5, 6, 7]}), 8)
self.assertEqual(
resolver.TPUClusterResolver
._verify_and_return_same_core_count({
0: [0, 1],
1: [2, 3]
}), 2)
with self.assertRaises(RuntimeError):
resolver.TPUClusterResolver._verify_and_return_same_core_count(
{
0: [0],
1: [1, 2]
})
@mock.patch.object(eager.context, 'list_devices')
@mock.patch.object(session.BaseSession, 'list_devices')
@mock.patch.object(resolver, 'is_running_in_gce',
mock_is_not_running_in_gce)
def testNumAcceleratorsSuccess(self, mock_list_devices,
mock_eager_list_devices):
device_names = [
'/job:tpu_worker/task:0/device:TPU:0',
'/job:tpu_worker/task:1/device:TPU:1',
'/job:tpu_worker/task:2/device:TPU:0',
'/job:tpu_worker/task:3/device:TPU:1',
'/job:tpu_worker/task:0/device:TPU:4',
'/job:tpu_worker/task:1/device:TPU:5',
'/job:tpu_worker/task:2/device:TPU:4',
'/job:tpu_worker/task:3/device:TPU:5',
]
device_list = [
session._DeviceAttributes(
name, 'TPU', 1024, 0) for name in device_names
]
mock_eager_list_devices.return_value = device_names
mock_list_devices.return_value = device_list
cluster_resolver = resolver.TPUClusterResolver(tpu='')
self.assertEqual(cluster_resolver.num_accelerators(), {'TPU': 2})
@mock.patch.object(eager.context, 'list_devices')
@mock.patch.object(session.BaseSession, 'list_devices')
@mock.patch.object(resolver, 'is_running_in_gce',
mock_is_not_running_in_gce)
def testNumAcceleratorsRetryFailure(self, mock_list_devices,
mock_eager_list_devices):
cluster_resolver = resolver.TPUClusterResolver(tpu='')
mock_list_devices.side_effect = errors.DeadlineExceededError(
None, None, 'timeout')
mock_eager_list_devices.side_effect = errors.DeadlineExceededError(
None, None, 'timeout')
with self.assertRaises(RuntimeError):
cluster_resolver.num_accelerators()
if __name__ == '__main__':
test.main()