# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPU Embeddings mid level API on TPU."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import os
from absl import flags
from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.util import nest
FLAGS = flags.FLAGS
flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
flags.DEFINE_string('model_dir', os.environ.get('TEST_TMPDIR'),
'A temporary directory.')
class TPUEmbeddingCorrectness(parameterized.TestCase, test.TestCase):
def setUp(self):
super(TPUEmbeddingCorrectness, self).setUp()
self.embedding_values = np.array(list(range(32)), dtype=np.float64)
self.initializer = init_ops_v2.Constant(self.embedding_values)
# Embedding for video initialized to
# 0 1 2 3
# 4 5 6 7
# ...
self.table_video = tpu_embedding_v2_utils.TableConfig(
vocabulary_size=8,
dim=4,
initializer=self.initializer,
combiner='sum',
name='video')
# Embedding for user initialized to
# 0 1
# 2 3
# 4 5
# 6 7
# ...
self.table_user = tpu_embedding_v2_utils.TableConfig(
vocabulary_size=16,
dim=2,
initializer=self.initializer,
combiner='mean',
name='user')
self.feature_config = (
tpu_embedding_v2_utils.FeatureConfig(
table=self.table_video, name='watched'),
tpu_embedding_v2_utils.FeatureConfig(
table=self.table_video, name='favorited'),
tpu_embedding_v2_utils.FeatureConfig(
table=self.table_user, name='friends'))
self.batch_size = 2
self.data_batch_size = 4
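    # `data_batch_size` rows of input are defined below; the dataset helpers
    # rebatch them to `batch_size` rows per replica before enqueue.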
# One (global) batch of inputs
# sparse tensor for watched:
# row 0: 0
# row 1: 0, 1
# row 2: 0, 1
# row 3: 1
self.feature_watched_indices = [[0, 0], [1, 0], [1, 1],
[2, 0], [2, 1], [3, 0]]
self.feature_watched_values = [0, 0, 1, 0, 1, 1]
self.feature_watched_row_lengths = [1, 2, 2, 1]
# sparse tensor for favorited:
# row 0: 0, 1
# row 1: 1
# row 2: 0
# row 3: 0, 1
self.feature_favorited_indices = [[0, 0], [0, 1], [1, 0],
[2, 0], [3, 0], [3, 1]]
self.feature_favorited_values = [0, 1, 1, 0, 0, 1]
self.feature_favorited_row_lengths = [2, 1, 1, 2]
# sparse tensor for friends:
# row 0: 3
# row 1: 0, 1, 2
# row 2: 3
# row 3: 0, 1, 2
self.feature_friends_indices = [[0, 0], [1, 0], [1, 1], [1, 2],
[2, 0], [3, 0], [3, 1], [3, 2]]
self.feature_friends_values = [3, 0, 1, 2, 3, 0, 1, 2]
self.feature_friends_row_lengths = [1, 3, 1, 3]
self.resolver = None
def _get_strategy(self):
self.resolver = tpu_cluster_resolver.TPUClusterResolver(
tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
remote.connect_to_cluster(self.resolver)
tpu_strategy_util.initialize_tpu_system(self.resolver)
return tpu_strategy.TPUStrategy(self.resolver)
def _create_strategy_and_mid_level(self, optimizer_name):
strategy = self._get_strategy()
with strategy.scope():
if optimizer_name == 'sgd':
optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
elif optimizer_name == 'adagrad':
optimizer = tpu_embedding_v2_utils.Adagrad(learning_rate=0.1)
elif optimizer_name == 'adam':
optimizer = tpu_embedding_v2_utils.Adam(learning_rate=0.1)
else:
        raise ValueError(
            'optimizer is not recognized: {}'.format(optimizer_name))
mid_level_api = self._create_mid_level(optimizer=optimizer)
return strategy, mid_level_api, optimizer
@parameterized.parameters(
*itertools.product(
['sgd', 'adagrad', 'adam'],
[True, False],
[True, False]))
def test_embedding(self, optimizer_name, training, sparse):
strategy, mid_level_api, optimizer = (
self._create_strategy_and_mid_level(optimizer_name))
if sparse:
dataset = self._create_sparse_dataset(strategy)
else:
dataset = self._create_ragged_dataset(strategy)
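    # The enqueue op runs on the host CPU, so the input data must stay on the
    # host rather than being prefetched to the TPU device.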
dist = strategy.experimental_distribute_dataset(
dataset,
options=distribute_lib.InputOptions(
experimental_prefetch_to_device=False))
dist_iter = iter(dist)
@def_function.function
def test_fn():
def step():
"""Create and run computation that returns the embedding activations."""
if not training:
activations = mid_level_api.dequeue()
total_loss = _get_total_loss_tensor(activations)
ret_val = [total_loss] + list(activations)
return ret_val
else:
with backprop.GradientTape() as tape:
activations = mid_level_api.dequeue()
tape.watch(activations)
total_loss = _get_total_loss_tensor(activations)
loss_per_replica = total_loss / strategy.num_replicas_in_sync
gradients = tape.gradient(loss_per_replica, activations)
mid_level_api.apply_gradients(gradients)
ret_val = [total_loss] + list(activations)
return ret_val
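      # Enqueue the input batch on the host, then run the TPU step, which
      # dequeues the matching activations.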
mid_level_api.enqueue(next(dist_iter), training=training)
result = strategy.run(step)
return result
# Run model.
shard_out_val = test_fn()
# Retrieve TPU weights to CPU.
mid_level_api._retrieve_variables()
# Compute sparse tensors for global batch.
input_data = next(iter(self._create_sparse_dataset(strategy)))
# Check results.
self._check_results(strategy, shard_out_val, training, input_data,
mid_level_api._variables,
optimizer)
def _create_mid_level(self, optimizer=None):
# Create `TPUEmbedding` object.
if optimizer is None:
optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
return tpu_embedding_v2.TPUEmbedding(
feature_config=self.feature_config,
optimizer=optimizer)
def _create_sparse_data(self, include_weights, weight=0.5):
sparse_features = (
sparse_tensor.SparseTensor(
indices=self.feature_watched_indices,
values=self.feature_watched_values,
dense_shape=[self.data_batch_size, 2]),
sparse_tensor.SparseTensor(
indices=self.feature_favorited_indices,
values=self.feature_favorited_values,
dense_shape=[self.data_batch_size, 2]),
sparse_tensor.SparseTensor(
indices=self.feature_friends_indices,
values=self.feature_friends_values,
dense_shape=[self.data_batch_size, 3]))
if include_weights:
weights = []
for sparse in sparse_features:
values = (
array_ops.ones_like(sparse.values, dtype=dtypes.float32) * weight)
weights.append(sparse_tensor.SparseTensor(
indices=sparse.indices,
values=values,
dense_shape=sparse.dense_shape))
sparse_features = (sparse_features, tuple(weights))
return sparse_features
def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
# Create dataset for enqueue operation
sparse_features = self._create_sparse_data(include_weights, weight)
dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)
    # Data is batched to self.data_batch_size; rebatch it to the global batch
    # size.
return dataset.unbatch().repeat().batch(
self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
def _create_ragged_dataset(self, strategy, include_weights=False, weight=0.5):
# Create dataset for enqueue operation
sparse_features = self._create_sparse_data(include_weights, weight)
ragged_features = nest.map_structure(ragged_tensor.RaggedTensor.from_sparse,
sparse_features)
dataset = dataset_ops.DatasetV2.from_tensors(ragged_features)
    # Data is batched to self.data_batch_size; rebatch it to the global batch
    # size.
return dataset.unbatch().repeat().batch(
self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)
def _create_dense_input_fn(self, strategy, include_weights=False, weight=0.5):
def input_fn(ctx):
del ctx
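      # Each replica looks up the same two fixed ids per feature, taken from
      # the last two values of each feature's value list.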
features = (
constant_op.constant(self.feature_watched_values[-2:],
dtype=dtypes.int32),
constant_op.constant(self.feature_favorited_values[-2:],
dtype=dtypes.int32),
constant_op.constant(self.feature_friends_values[-2:],
dtype=dtypes.int32))
if include_weights:
weights = [array_ops.ones_like(t, dtype=dtypes.float32) * weight
for t in features]
features = (features, tuple(weights))
return dataset_ops.DatasetV2.from_tensors(features).repeat()
return input_fn
def _check_results(self, strategy, shard_out_val, training, input_data,
table_to_variable, optimizer):
num_replicas = strategy.num_replicas_in_sync
# Unpack the values `strategy.run()` returns.
loss = _unpack(strategy, shard_out_val[0])
activation_watched = _unpack(strategy, shard_out_val[1])
activation_favorited = _unpack(strategy, shard_out_val[2])
activation_friends = _unpack(strategy, shard_out_val[3])
# Core 0:
# Calculate the values of embedding activations.
activation_watched_gold0 = np.array([[0, 1, 2, 3], [4, 6, 8, 10]])
activation_favorited_gold0 = np.array([[4, 6, 8, 10], [4, 5, 6, 7]])
# Second row of `activation_friends_gold0` is the mean of the following.
# row 0: 0 1
# row 1: 2 3
# row 2: 4 5
activation_friends_gold0 = np.array([[6, 7], [2, 3]])
loss_gold0 = _compute_loss(activation_watched_gold0,
activation_favorited_gold0,
activation_friends_gold0)
# Add on values from other cores:
# Activations for watched are an alternating sequence of
# activation_watched_gold0 and activation_favorited_gold0.
# For favorited it is the same but in the opposite order.
activation_watched_gold = np.concatenate(
(np.concatenate((np.expand_dims(activation_watched_gold0, axis=0),) *
(num_replicas // 2)),
np.concatenate((np.expand_dims(activation_favorited_gold0, axis=0),) *
(num_replicas // 2))),
axis=1).reshape([self.batch_size * num_replicas, 4])
activation_favorited_gold = np.concatenate(
(activation_watched_gold[self.batch_size:,],
activation_watched_gold[0:self.batch_size,]))
activation_friends_gold = np.concatenate(
(activation_friends_gold0,) * num_replicas)
loss_gold = [loss_gold0] * num_replicas
# Test values.
self.assertAllClose(activation_watched_gold, activation_watched)
self.assertAllClose(activation_favorited_gold, activation_favorited)
self.assertAllClose(activation_friends_gold, activation_friends)
self.assertAllClose(loss_gold, loss)
embedding_table_video_before = np.copy(
np.reshape(self.embedding_values, [8, 4]))
embedding_table_user_before = np.copy(
np.reshape(self.embedding_values, [16, 2]))
global_batch_size = self.batch_size * num_replicas
if training:
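      # The loss is the mean over the global batch of the sum of squared
      # activations, so d(loss)/d(activation) is
      # 2 * activation / global_batch_size.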
gradient_wrt_watched_gold = (2 * activation_watched_gold /
global_batch_size)
gradient_wrt_favorited_gold = (2 * activation_favorited_gold /
global_batch_size)
gradient_wrt_friends_gold = (2 * activation_friends_gold /
global_batch_size)
# Calculate gradients wrt embedding tables.
gradients_wrt_user = (
_compute_gradients_wrt_embedding_table(
global_batch_size, gradient_wrt_friends_gold,
embedding_table_user_before, input_data[2].indices.numpy(),
input_data[2].values.numpy(), self.table_user.combiner))
gradients_wrt_video = (
_compute_gradients_wrt_embedding_table(
global_batch_size, gradient_wrt_favorited_gold,
embedding_table_video_before, input_data[1].indices.numpy(),
input_data[1].values.numpy(), self.table_video.combiner) +
_compute_gradients_wrt_embedding_table(
global_batch_size, gradient_wrt_watched_gold,
embedding_table_video_before, input_data[0].indices.numpy(),
input_data[0].values.numpy(), self.table_video.combiner))
self._check_embedding_and_slot_variables(embedding_table_user_before,
gradients_wrt_user,
embedding_table_video_before,
gradients_wrt_video,
optimizer,
table_to_variable)
def _check_embedding_and_slot_variables(self, embedding_table_user_before,
gradients_wrt_user,
embedding_table_video_before,
gradients_wrt_video,
optimizer,
table_to_variable):
if isinstance(optimizer, tpu_embedding_v2_utils.SGD):
check_fn = self._check_embedding_and_slot_variables_for_sgd
elif isinstance(optimizer, tpu_embedding_v2_utils.Adagrad):
check_fn = self._check_embedding_and_slot_variables_for_adagrad
elif isinstance(optimizer, tpu_embedding_v2_utils.Adam):
check_fn = self._check_embedding_and_slot_variables_for_adam
else:
      raise ValueError(
          'optimizer is not recognized: {}'.format(type(optimizer)))
check_fn(embedding_table_user_before, gradients_wrt_user,
optimizer, table_to_variable[self.table_user.name])
check_fn(embedding_table_video_before, gradients_wrt_video,
optimizer, table_to_variable[self.table_video.name])
def _check_embedding_and_slot_variables_for_sgd(self, embedding_table_before,
gradients,
optimizer,
variables):
embedding_table = np.copy(embedding_table_before)
embedding_table -= optimizer.learning_rate * np.sum(gradients, axis=0)
self.assertAllClose(_get_variable(variables['parameters']).numpy(),
embedding_table)
def _check_embedding_and_slot_variables_for_adagrad(self,
embedding_table_before,
gradients,
optimizer,
variable):
embedding_table = np.copy(embedding_table_before)
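    # One Adagrad step: add the squared summed gradient to the accumulator,
    # then scale the gradient by 1 / sqrt(accumulator).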
accumulator = (
optimizer.initial_accumulator_value + np.sum(gradients, axis=0)**2)
embedding_table -= (
optimizer.learning_rate * np.sum(gradients, axis=0) /
np.sqrt(accumulator))
self.assertAllClose(_get_variable(variable['parameters']).numpy(),
embedding_table)
self.assertAllClose(_get_variable(variable['accumulators']).numpy(),
accumulator)
def _check_embedding_and_slot_variables_for_adam(self, embedding_table_before,
gradients,
optimizer,
variable):
embedding_table = np.copy(embedding_table_before)
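    # Expected first-step Adam statistics starting from zero-initialized slot
    # variables: m = (1 - beta_1) * g and v = (1 - beta_2) * g**2.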
g = np.sum(gradients, axis=0)
v = g**2 * (1 - optimizer.beta_2)
m = g * (1 - optimizer.beta_1)
epsilon = optimizer.epsilon
# TPU Embeddings don't have the LR decay factor for Adam.
lr_modifier = 1
embedding_table -= (
m * optimizer.learning_rate * lr_modifier / (np.sqrt(v) + epsilon))
self.assertAllClose(_get_variable(variable['parameters']).numpy(),
embedding_table, rtol=1e-4)
self.assertAllClose(_get_variable(variable['momenta']).numpy(),
m, rtol=1e-4)
self.assertAllClose(_get_variable(variable['velocities']).numpy(),
v, rtol=1e-4)
def _get_replica_numpy(self, structured, strategy, replica_id):
def select_replica(x):
x = strategy.experimental_local_results(x)
if len(x) == 1:
        # `experimental_local_results` returns a tuple of per-replica values,
        # so index into it before converting to numpy.
        return x[0].numpy()
return x[replica_id].numpy()
return nest.map_structure(select_replica, structured)
def test_dense_lookup(self):
strategy, mid_level_api, _ = self._create_strategy_and_mid_level('sgd')
input_fn = self._create_dense_input_fn(strategy)
dist = strategy.distribute_datasets_from_function(
input_fn,
options=distribute_lib.InputOptions(
experimental_prefetch_to_device=False))
dist_iter = iter(dist)
@def_function.function
def test_fn():
def step():
return mid_level_api.dequeue()
mid_level_api.enqueue(next(dist_iter), training=False)
return strategy.run(step)
# Run model.
shard0 = self._get_replica_numpy(test_fn(), strategy, 0)
# embedding_values is a linear list, so we reshape to match the correct
# shape of the corresponding table before performing the lookup.
numpy_videos = np.reshape(self.embedding_values, (8, 4))
numpy_users = np.reshape(self.embedding_values, (16, 2))
golden = ((numpy_videos[self.feature_watched_values[-2:]],
numpy_videos[self.feature_favorited_values[-2:]],
numpy_users[self.feature_friends_values[-2:]]))
self.assertAllClose(shard0, golden)
@parameterized.parameters([True, False])
def test_sequence_embeddings(self, sparse):
feature_config = (
tpu_embedding_v2_utils.FeatureConfig(
table=self.table_video, name='watched',
max_sequence_length=2),
tpu_embedding_v2_utils.FeatureConfig(
table=self.table_video, name='favorited',
max_sequence_length=2),
tpu_embedding_v2_utils.FeatureConfig(
table=self.table_user, name='friends',
max_sequence_length=3))
optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
strategy = self._get_strategy()
num_replicas = strategy.num_replicas_in_sync
with strategy.scope():
mid_level = tpu_embedding_v2.TPUEmbedding(
feature_config=feature_config,
optimizer=optimizer)
    # Call build here. Because `next` is called outside the tf.function, the
    # sparse tensor shapes in the data are tensors whose values are unknown at
    # tracing time, so the batch size cannot be inferred automatically.
mid_level.build(self.batch_size)
if sparse:
dataset = self._create_sparse_dataset(strategy)
else:
dataset = self._create_ragged_dataset(strategy)
data = next(iter(strategy.experimental_distribute_dataset(
dataset,
options=distribute_lib.InputOptions(
experimental_prefetch_to_device=False))))
@def_function.function
def embedding_and_set_gradients(data):
def tpu_fn():
activations = mid_level.dequeue()
mid_level.apply_gradients(nest.map_structure(array_ops.ones_like,
activations))
return activations
mid_level.enqueue(data)
return strategy.run(tpu_fn)
@def_function.function
def embedding_only(data):
def tpu_fn():
return mid_level.dequeue()
mid_level.enqueue(data)
return strategy.run(tpu_fn)
# Only check core 0.
before_update = self._get_replica_numpy(
embedding_and_set_gradients(data), strategy, 0)
after_update = self._get_replica_numpy(embedding_only(data), strategy, 0)
# For videos table, row 0 and row 1 are looked up 3*num_replicas times as
# they occur 3 times per replica (considering the features 0 and 1 which are
# both looked up in the videos table).
# Feature 0 has ids [0, 0, 1], [0, 1, 1], ... repeated over num_replicas
# Feature 1 has ids [0, 1, 1], [0, 0, 1], ... repeated over num_replicas
# This means that both rows 0 and 1 get a -0.1*3*num_replicas update
# For users table, each row is looked up twice:
# Feature 2 has ids [3, 0, 1, 2], .. repeated over num_replicas
    # This means that each user table row gets a -0.1*num_replicas update from
    # feature 2.
# In general this means that after the update, if we lookup feature 0 and 1
# the values will be 0.3*num_replicas lower per entry and for feature 2 they
# will be 0.1*num_replicas lower.
# The one issue is that these lookups contain padding values.
# For core 0, we get the first 2 elements of the 4 element batch.
# For feature 0, the indices are [[0, 0], [1, 0], [1, 1]] with max sequence
# length of 2, which means that [0, 1] will be 0s.
# For feature 1, the indices are [[0, 0], [0, 1], [1, 0]] with max sequence
# length of 2, which means that [1, 1] will be 0s.
# For feature 2, the indices are [[0, 0], [1, 0], [1, 1], [1, 2]] with max
# sequence length of 3, which means that [0, 1], [0, 2] will be 0s.
# The following masks represent that so that we only apply the above updates
# to the non-padding rows:
masks = (
np.array([[[1], [0]], [[1], [1]]]),
np.array([[[1], [1]], [[1], [0]]]),
np.array([[[1], [0], [0]], [[1], [1], [1]]]))
per_row_update = (0.3 * num_replicas,
0.3 * num_replicas,
0.1 * num_replicas)
golden = tuple([before - update * mask for before, update, mask in
zip(before_update, per_row_update, masks)])
self.assertAllClose(golden, after_update)
def _compute_gradients_wrt_embedding_table(batch_size,
gradient_wrt_activation,
embedding_table,
feature_indices,
feature_values,
combiner,
max_sequence_length=0):
"""Compute gradients wrt embedding_table.
Args:
batch_size: `int`, batch size.
gradient_wrt_activation: `np.array` with shape `batch_size` by
embedding `dimension`.
embedding_table: `np.array` with shape `vocabulary_size` by embedding
`dimension`.
feature_indices: `indices` as used to construct `SparseTensor`.
feature_values: `values` as used to construct `SparseTensor`.
combiner: `String`, 'mean' or 'sum'.
    max_sequence_length: If non-zero, a sequence feature with the given length.

  Returns:
    Gradients wrt `embedding_table`, an `np.array` with shape `batch_size` by
    `vocabulary_size` by embedding `dimension`.

  Raises:
ValueError: if `combiner` is not one of 'mean' or 'sum'.
"""
if combiner not in ('mean', 'sum'):
raise ValueError('`combiner` must be mean or sum; got {}.'.format(combiner))
grads = []
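  # For each example, scatter its activation gradient into the embedding table
  # rows it looked up; a 'mean' combiner divides by the number of ids.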
for i in range(batch_size):
grad = np.zeros_like(embedding_table)
count = 0
for (batch_i, seq_index), vocabulary_id in zip(feature_indices,
feature_values):
if batch_i == i:
count += 1
if max_sequence_length > 0:
if seq_index < max_sequence_length:
grad[vocabulary_id, :] += gradient_wrt_activation[i, seq_index, :]
else:
grad[vocabulary_id, :] += gradient_wrt_activation[i, :]
if combiner == 'mean' and not max_sequence_length:
grad = grad / count
grads.append(grad)
return np.stack(grads)
def _unpack(strategy, per_replica_output):
per_replica_output = strategy.experimental_local_results(per_replica_output)
per_replica_output = array_ops.concat(per_replica_output, axis=0).numpy()
return per_replica_output
def _get_total_loss_tensor(activations):
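  # Per-feature loss: the mean over the batch of the sum over the embedding
  # dimension of activation**2 (squared_difference against 0).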
losses = []
for activation in activations:
losses.append(
math_ops.reduce_mean(
math_ops.reduce_sum(
gen_math_ops.squared_difference(activation, 0), 1)))
total_loss = array_ops.expand_dims_v2(sum(losses), 0)
return total_loss
def _compute_loss(activation_watched, activation_favorited,
                  activation_friends):
watched_loss = np.mean(np.sum(activation_watched**2, axis=1))
if len(activation_favorited.shape) == 2:
favorited_loss = np.mean(np.sum(activation_favorited**2, axis=1))
else:
favorited_loss = np.mean(np.sum(activation_favorited**2, axis=(1, 2)))
if len(activation_friends.shape) == 2:
friends_loss = np.mean(np.sum(activation_friends**2, axis=1))
else:
friends_loss = np.mean(np.sum(activation_friends**2, axis=(1, 2)))
loss = watched_loss + favorited_loss + friends_loss
return loss
def _get_variable(variable):
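  # Embedding tables may be represented as a `TPUShardedVariable`; the checks
  # in this test read the first shard.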
if isinstance(variable, tpu_embedding_v2.TPUShardedVariable):
return variable.variables[0]
return variable
if __name__ == '__main__':
v2_compat.enable_v2_behavior()
test.main()