Change the TPU embedding API to allow passing a learning rate function to the TableConfig constructor instead of a string key. This maps more closely to the feature column API.
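
Example of the new usage (an illustrative sketch: the sizes and decay
schedule are made up, and "embedding" / "feature_to_gradient_dict" stand
in for a constructed TPUEmbedding object and its gradient dict):

  def lr_fn(global_step):
    # Any callable taking the global step works; its result is cast to
    # float32 before being sent to the TPU.
    return 0.1 * tf.math.pow(0.5, tf.cast(global_step, tf.float32) / 10000.0)

  table_config = tpu_embedding.TableConfig(
      vocabulary_size=100000,
      dimension=16,
      learning_rate_fn=lr_fn)  # previously: learning_rate_key='my_lr'

  # Training call sites now pass the global step instead of a dict of
  # dynamic learning rates:
  send_op = embedding.generate_send_gradients_op(
      feature_to_gradient_dict, step=tf.train.get_or_create_global_step())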

PiperOrigin-RevId: 281926084
Change-Id: I6f653d048aa0c00940b70a4616dbc63376ab25c2
diff --git a/tensorflow/python/tpu/tpu_embedding.py b/tensorflow/python/tpu/tpu_embedding.py
index 5b61abc..7648369 100644
--- a/tensorflow/python/tpu/tpu_embedding.py
+++ b/tensorflow/python/tpu/tpu_embedding.py
@@ -32,6 +32,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
@@ -49,7 +50,7 @@
 class TableConfig(
     collections.namedtuple('TableConfig', [
         'vocabulary_size', 'dimension', 'initializer', 'combiner',
-        'hot_id_replication', 'learning_rate', 'learning_rate_key'
+        'hot_id_replication', 'learning_rate', 'learning_rate_fn'
     ])):
   """Embedding table configuration."""
 
@@ -60,7 +61,7 @@
               combiner='mean',
               hot_id_replication=False,
               learning_rate=None,
-              learning_rate_key=None):
+              learning_rate_fn=None):
     """Embedding table configuration.
 
     Args:
@@ -79,17 +80,16 @@
       hot_id_replication: If true, enables hot id replication, which can make
         embedding lookups faster if there are some hot rows in the table.
       learning_rate: float, static learning rate for this table. If
-        learning_rate and learning_rate_key are both `None`, global
+        learning_rate and learning_rate_fn are both `None`, global
         static learning rate as specified in `optimization_parameters` in
-        `TPUEmbedding` constructor will be used. `learning_rate_key` must be
+        `TPUEmbedding` constructor will be used. `learning_rate_fn` must be
-        `None` if `learning_rate` is not `None.
+        `None` if `learning_rate` is not `None`.
-      learning_rate_key: string, use dynamic learning rate of
-        `learning_rates[learning_rate_key]` for this table, where
-        `learning_rates` is the second argument of
-        `generate_send_gradients_op()`. If learning_rate and learning_rate_key
-        are both `None`, global static learning rate as specified in
-        `optimization_parameters` in `TPUEmbedding` constructor will be used.
-        `learning_rate` must be `None` if `learning_rate_key` is not `None.
+      learning_rate_fn: callable, use a dynamic learning rate given by this
+        function. The function will be passed the current global step. If
+        learning_rate and learning_rate_fn are both `None`, global static
+        learning rate as specified in `optimization_parameters` in
+        `TPUEmbedding` constructor will be used. `learning_rate` must be `None`
+        if `learning_rate_fn` is not `None`.
 
     Returns:
       `TableConfig`.
@@ -99,7 +99,7 @@
       ValueError: if `dimension` is not positive integer.
       ValueError: if `initializer` is specified and is not callable.
       ValueError: if `combiner` is not supported.
-      ValueError: if `learning_rate` and `learning_rate_key` are both not
+      ValueError: if `learning_rate` and `learning_rate_fn` are both not
         `None`.
     """
     if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
@@ -117,14 +117,14 @@
     if combiner not in ('mean', 'sum', 'sqrtn', None):
       raise ValueError('Invalid combiner {}'.format(combiner))
 
-    if learning_rate is not None and learning_rate_key is not None:
-      raise ValueError('At most one of learning_rate and learning_rate_key '
+    if learning_rate is not None and learning_rate_fn is not None:
+      raise ValueError('At most one of learning_rate and learning_rate_fn '
-                       'can be None; got {} and {}'
+                       'can be specified; got {} and {}'
-                       .format(learning_rate, learning_rate_key))
+                       .format(learning_rate, learning_rate_fn))
 
     return super(TableConfig, cls).__new__(
         cls, vocabulary_size, dimension, initializer, combiner,
-        hot_id_replication, learning_rate, learning_rate_key)
+        hot_id_replication, learning_rate, learning_rate_fn)
 
 
 class FeatureConfig(
@@ -694,6 +694,14 @@
         self._optimization_parameters)
     self._pipeline_execution_with_tensor_core = (
         pipeline_execution_with_tensor_core)
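+    # Deduplicate the dynamic learning rate functions and assign each a
+    # unique tag; the tag indexes the learning_rates list passed to
+    # send_tpu_embedding_gradients.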
+    self._learning_rate_fn = list(set(
+        c.learning_rate_fn for c in self._table_to_config_dict.values()
+        if c.learning_rate_fn is not None))
+    self._learning_rate_fn_to_tag = {
+        fn: i for i, fn in enumerate(self._learning_rate_fn)}
 
     self._config_proto = self._create_config_proto()
 
@@ -767,10 +775,6 @@
 
   def _create_config_proto(self):
     """Create `TPUEmbeddingConfiguration`."""
-    self._learning_rate_keys = list(
-        set(c.learning_rate_key
-            for c in self._table_to_config_dict.values()
-            if c.learning_rate_key is not None))
     config_proto = elc.TPUEmbeddingConfiguration()
     for table in self._table_to_config_dict:
       table_descriptor = config_proto.table_descriptor.add()
@@ -788,9 +792,9 @@
       parameters = table_descriptor.optimization_parameters
       if table_config.learning_rate:
         parameters.learning_rate.constant = (table_config.learning_rate)
-      elif table_config.learning_rate_key:
+      elif table_config.learning_rate_fn:
         parameters.learning_rate.dynamic.tag = (
-            self._learning_rate_keys.index(table_config.learning_rate_key))
+            self._learning_rate_fn_to_tag[table_config.learning_rate_fn])
       else:
         parameters.learning_rate.constant = (
             self._optimization_parameters.learning_rate)
@@ -1097,14 +1101,13 @@
 
   def generate_send_gradients_op(self,
                                  feature_to_gradient_dict,
-                                 learning_rates=None):
+                                 step=None):
     """Send gradient to TPU embedding.
 
     Args:
       feature_to_gradient_dict: dict mapping feature names to gradient wrt
         activations.
-      learning_rates: dict mapping from learning rate key to dynamic learning
-        rate. Defaults to `None`.
+      step: the current global step, required for dynamic learning rates.
 
     Returns:
       SendTPUEmbeddingGradients Op.
@@ -1116,9 +1119,10 @@
       raise RuntimeError('Only in training mode gradients need to '
                          'be sent to TPU embedding; got mode {}.'
                          .format(self._mode))
-
-    if learning_rates is None:
-      learning_rates = dict()
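+    # Dynamic learning rates are functions of the global step, so a step
+    # tensor is required whenever any table uses learning_rate_fn.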
+    if step is None and self._learning_rate_fn:
+      raise ValueError('There are dynamic learning rates but step is None.')
 
     gradients = []
     for table in self._table_to_features_dict:
@@ -1137,9 +1141,10 @@
 
     return tpu_ops.send_tpu_embedding_gradients(
         inputs=gradients,
-        learning_rates=[
-            learning_rates[tag] for tag in self._learning_rate_keys
-        ],
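+        # Evaluate each dynamic learning rate function at the current step,
+        # in the same order as the tags assigned in the constructor.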
+        learning_rates=[math_ops.cast(fn(step), dtype=dtypes.float32)
+                        for fn in self._learning_rate_fn],
         config=self.config_proto.SerializeToString())