# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUStrategy."""
import os
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.module import module
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util

FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")


def get_tpu_cluster_resolver():
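  """Returns a TPUClusterResolver for the TPU named by the test flags."""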
resolver = tpu_cluster_resolver.TPUClusterResolver(
tpu=FLAGS.tpu,
zone=FLAGS.zone,
project=FLAGS.project,
)
return resolver


def get_tpu_strategy(enable_spmd=False):
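  """Returns a (strategy, num_replicas) pair with two cores per replica."""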
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
topology = tpu_strategy_util.initialize_tpu_system(resolver)
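  # computation_shape=[1, 1, 1, 2] assigns two cores to each replica, so the
  # number of replicas is half the total number of cores.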
num_replicas = resolver.get_tpu_system_metadata().num_cores // 2
device_assignment = device_assignment_lib.DeviceAssignment.build(
topology, num_replicas=num_replicas, computation_shape=[1, 1, 1, 2])
strategy = tpu_lib.TPUStrategyV2(
resolver,
experimental_device_assignment=device_assignment,
experimental_spmd_xla_partitioning=enable_spmd)
return strategy, num_replicas


class TPUStrategyModelParallelismTest(
strategy_test_lib.DistributionTestBase,
strategy_test_lib.TwoDeviceDistributionTestBase):
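  """Tests TPUStrategy with two logical devices (cores) per replica."""
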
def test_logical_device_assignment(self):
if test_util.is_mlir_bridge_enabled():
self.skipTest("TODO(b/238811067): fix MLIR bridge")
strategy, num_replicas = get_tpu_strategy()
with strategy.scope():
v = variables.Variable(2.)
with strategy.extended.experimental_logical_device(1):
w = variables.Variable(3.)
self.assertLen(strategy.experimental_local_results(v), num_replicas)
self.assertLen(strategy.experimental_local_results(w), num_replicas)
self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:0",
strategy.experimental_local_results(v)[0].device)
self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:1",
strategy.experimental_local_results(w)[0].device)
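
    # Record, at tracing time, the logical device each intermediate is
    # placed on.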
logical_devices = []
@def_function.function
def f(x):
replica_ctx = distribution_strategy_context.get_replica_context()
with replica_ctx.experimental_logical_device(0):
y = v * x
with replica_ctx.experimental_logical_device(1):
z = w * y
logical_devices.append((y.device, z.device))
return z
result = strategy.run(f, args=(5.,))
self.assertEqual(
[("/device:TPU_REPLICATED_CORE:0", "/device:TPU_REPLICATED_CORE:1")],
logical_devices)
with self.cached_session():
self.evaluate(variables.global_variables_initializer())
self.assertEqual(30. * num_replicas,
self.evaluate(strategy.reduce("SUM", result, axis=None)))
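
  # Saves a model whose variables live on different logical devices, mutates
  # the weights, then restores and verifies the original values come back.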
  def test_partitioned_model_checkpointing(self):
if test_util.is_mlir_bridge_enabled():
self.skipTest("TODO(b/238811067): fix MLIR bridge")
class PartitionedModel(module.Module):
def __init__(self, v, w):
super(PartitionedModel, self).__init__()
assert distribution_strategy_context.has_strategy()
strategy = distribution_strategy_context.get_strategy()
with strategy.extended.experimental_logical_device(0):
self.v = variables.Variable(v)
with strategy.extended.experimental_logical_device(1):
self.w = variables.Variable(w)
def __call__(self, x):
replica_ctx = distribution_strategy_context.get_replica_context()
with replica_ctx.experimental_logical_device(0):
y = self.v * x
with replica_ctx.experimental_logical_device(1):
z = self.w * y
return z
def change_weights_op(self, v_new, w_new):
return control_flow_ops.group(
[self.v.assign(v_new), self.w.assign(w_new)])
strategy, num_replicas = get_tpu_strategy()
with strategy.scope():
model = PartitionedModel(2., 3.)
checkpoint_dir = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = util.Checkpoint(model=model)
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
checkpoint.save(file_prefix=checkpoint_prefix)
self.evaluate(model.change_weights_op(1., 4.))
result = strategy.run(def_function.function(model), args=(5.0,))
self.assertEqual(20. * num_replicas,
self.evaluate(strategy.reduce("SUM", result, axis=None)))
status = checkpoint.restore(
checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # Restore ops must run in non-eager mode.
status.assert_consumed()
status.assert_existing_objects_matched()
result = strategy.run(def_function.function(model), args=(5.0,))
self.assertEqual(30. * num_replicas,
self.evaluate(strategy.reduce("SUM", result, axis=None)))
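
  # With SPMD (XLA auto-partitioning) enabled, tensors cannot be manually
  # assigned to a logical device.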
def test_spmd_cannot_assign_tensor_to_logical_device(self):
strategy, _ = get_tpu_strategy(enable_spmd=True)
x = constant_op.constant([0, 1])
with self.assertRaises(ValueError):
strategy.experimental_assign_to_logical_device(x, 0)
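
  # In SPMD mode the initializer runs once and the result is replicated, so
  # every per-core component should hold identical data.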
def test_spmd_variable_created_from_callable(self):
    initializer = lambda: random_ops.random_normal(shape=(16, 16))
strategy, _ = get_tpu_strategy(enable_spmd=True)
with strategy.scope():
      w = variables.Variable(initializer)
value0 = w.values[0]
for v in value0.variables:
self.assertAllEqual(v, value0.variables[0])
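
  # A replicated SPMD variable should read back the same value on every
  # replica, so the summed matmul equals num_replicas times one product.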
def test_spmd_variable_read(self):
batch_size = 32
num_feature_in = 16
num_feature_out = 8
x = random_ops.random_uniform((batch_size, num_feature_in),
dtype=dtypes.float32)
w_init = random_ops.random_uniform((num_feature_in, num_feature_out),
dtype=dtypes.float32)
strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
if test_util.is_mlir_bridge_enabled() and num_replicas == 1:
self.skipTest("TODO(b/238811067): fix MLIR bridge for num_replicas=1")
with strategy.scope():
w = variables.Variable(w_init, dtype=dtypes.float32)
self.assertEqual(w.values[0].variables[0].shape.as_list(),
[num_feature_in, num_feature_out])
self.assertEqual(w.shape.as_list(), [num_feature_in, num_feature_out])
def step_fn(batch_features):
predict = math_ops.matmul(batch_features, w)
return predict
@def_function.function
def train_fn(batch_features):
return strategy.run(step_fn, args=(batch_features,))
result = train_fn(x)
self.assertAllClose(
strategy.reduce("SUM", result, axis=None),
math_ops.matmul(x, w_init) * num_replicas,
rtol=5e-03,
atol=5e-03)
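
  # assign, assign_sub and assign_add issued inside strategy.run should
  # update every per-core component of an SPMD variable consistently.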
def test_spmd_variable_update(self):
batch_size = 1024
num_feature_in = 256
x = random_ops.random_uniform((batch_size, num_feature_in),
dtype=dtypes.float32)
w_init = random_ops.random_uniform((batch_size, num_feature_in),
dtype=dtypes.float32)
strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
if test_util.is_mlir_bridge_enabled() and num_replicas == 1:
self.skipTest("TODO(b/238811067): fix MLIR bridge for num_replicas=1")
with strategy.scope():
w = variables.Variable(w_init, dtype=dtypes.float32)
self.assertIsInstance(w, tpu_values.TPUMirroredVariable)
self.assertTrue(w._is_replicated_or_sharded_to_logical_cores())
def make_strategy_run(fn):
def run(value):
return strategy.run(fn, args=(value,))
return def_function.function(run)
result = make_strategy_run(w.assign)(x)
self.assertAllClose(
strategy.reduce("SUM", result, axis=None), x * num_replicas)
delta = random_ops.random_uniform((batch_size, num_feature_in),
dtype=dtypes.float32)
result = make_strategy_run(w.assign_sub)(delta)
x -= delta
self.assertAllClose(
strategy.reduce("SUM", result, axis=None), x * num_replicas)
delta = random_ops.random_uniform((batch_size, num_feature_in),
dtype=dtypes.float32)
result = make_strategy_run(w.assign_add)(delta)
x += delta
self.assertAllClose(
strategy.reduce("SUM", result, axis=None), x * num_replicas)
def test_spmd_variable_eager_update(self):
batch_size = 32
num_feature_in = 16
x = random_ops.random_uniform((batch_size, num_feature_in),
dtype=dtypes.float32)
w_init = random_ops.random_uniform((batch_size, num_feature_in),
dtype=dtypes.float32)
strategy, _ = get_tpu_strategy(enable_spmd=True)
with strategy.scope():
w = variables.Variable(w_init, dtype=dtypes.float32)
w.assign(x)
result = w.numpy()
self.assertAllClose(result, x)
x1 = random_ops.random_uniform((batch_size, num_feature_in),
dtype=dtypes.float32)
w.assign_sub(x1)
result = w.numpy()
self.assertAllClose(result, x - x1)
x2 = random_ops.random_uniform((batch_size, num_feature_in),
dtype=dtypes.float32)
w.assign(x)
w.assign_add(x2)
result = w.numpy()
self.assertAllClose(result, x + x2)
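
  # Round-trips a checkpoint for a model whose input is split across the
  # logical devices by SPMD partitioning.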
def test_spmd_model_checkpointing(self):
class LinearModel(module.Module):
def __init__(self, w):
super(LinearModel, self).__init__()
self.w = variables.Variable(w)
def __call__(self, x):
return math_ops.matmul(x, self.w)
def change_weights_op(self, w_new):
return self.w.assign(w_new)
batch_size = 32
num_feature_in = 16
num_feature_out = 8
w1 = random_ops.random_uniform((num_feature_in, num_feature_out),
dtype=dtypes.float32)
w2 = random_ops.random_uniform((num_feature_in, num_feature_out),
dtype=dtypes.float32)
x = random_ops.random_uniform((batch_size, num_feature_in),
dtype=dtypes.float32)
strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
if test_util.is_mlir_bridge_enabled() and num_replicas == 1:
self.skipTest("TODO(b/238811067): fix MLIR bridge for num_replicas=1")
with strategy.scope():
model = LinearModel(w1)
checkpoint_dir = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = util.Checkpoint(model=model)
@def_function.function
def step_fn(x):
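      # [1, 2] keeps the batch dimension whole and splits the feature
      # dimension across the two logical devices.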
x = strategy.experimental_split_to_logical_devices(x, [1, 2])
return model(x)
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
checkpoint.save(file_prefix=checkpoint_prefix)
self.evaluate(model.change_weights_op(w2))
result = strategy.run(step_fn, args=(x,))
self.assertAllClose(
math_ops.matmul(x, w2) * num_replicas,
self.evaluate(strategy.reduce("SUM", result, axis=None)),
rtol=5e-3,
atol=5e-3)
status = checkpoint.restore(
checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # Restore ops must run in non-eager mode.
status.assert_consumed()
status.assert_existing_objects_matched()
result = strategy.run(step_fn, args=(x,))
self.assertAllClose(
math_ops.matmul(x, w1) * num_replicas,
self.evaluate(strategy.reduce("SUM", result, axis=None)),
rtol=5e-3,
atol=5e-3)
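
  # Summary ops are not supported on TPU; soft device placement lets them
  # fall back to the host instead of failing.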
def test_spmd_with_summary(self):
if test_util.is_mlir_bridge_enabled():
self.skipTest("TODO(b/232580663): fix MLIR bridge")
original_device_placement = config.get_soft_device_placement()
config.set_soft_device_placement(True)
strategy, _ = get_tpu_strategy(enable_spmd=True)
summary_dir = self.get_temp_dir()
writer = summary_ops.create_file_writer_v2(summary_dir)
with strategy.scope():
step = variables.Variable(0, dtype=dtypes.int64)
@def_function.function
def run():
with writer.as_default():
summary_ops.scalar("result", step * 2, step=step)
step.assign_add(1)
for _ in range(10):
strategy.run(run, args=())
for val in step.values:
for var in val.variables:
self.assertAllEqual(10, var)
config.set_soft_device_placement(original_device_placement)
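
  # tpu.outside_compilation runs host_inc on the host CPU in the middle of
  # the replicated TPU computation.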
def test_spmd_with_outside_comp(self):
strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
if test_util.is_mlir_bridge_enabled() and num_replicas == 1:
self.skipTest("TODO(b/238673134): fix MLIR bridge for num_replicas=1")
def host_inc(x):
return x + 1
@def_function.function
def fn(x):
y = x + 1
z = tpu.outside_compilation(host_inc, y)
a = z + 1
return a
arg = constant_op.constant(0, shape=(), dtype=dtypes.int64)
result = strategy.run(fn, args=(arg,))
self.assertEqual(3 * num_replicas,
self.evaluate(strategy.reduce("SUM", result, axis=None)))


if __name__ == "__main__":
test.main()