| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """TPU embedding APIs.""" |
| |
| import collections |
| import copy |
| import math |
| import re |
| |
| from typing import Optional |
| |
| import six |
| |
| from tensorflow.core.protobuf.tpu import optimization_parameters_pb2 |
| from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import init_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import partitioned_variables |
| from tensorflow.python.ops import state_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib |
| from tensorflow.python.tpu.ops import tpu_ops |
| from tensorflow.python.util.tf_export import tf_export |
| |
| TRAINING = elc.TPUEmbeddingConfiguration.TRAINING |
| INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE |
| |
| |
| # TODO(shizhiw): a more future-proof way is to have optimization_parameter such |
| # as AdagradParameters etc instead of learning_rate. |
| class TableConfig( |
| collections.namedtuple('TableConfig', [ |
| 'vocabulary_size', |
| 'dimension', |
| 'initializer', |
| 'combiner', |
| 'hot_id_replication', |
| 'learning_rate', |
| 'learning_rate_fn', |
| 'optimization_parameters', |
| ])): |
| """Embedding table configuration.""" |
| |
| def __new__(cls, |
| vocabulary_size, |
| dimension, |
| initializer=None, |
| combiner='mean', |
| hot_id_replication=False, |
| learning_rate=None, |
| learning_rate_fn=None, |
| optimization_parameters=None): |
| """Embedding table configuration. |
| |
| Args: |
      vocabulary_size: Number of vocabulary entries (rows) in the table.
| dimension: The embedding dimension. |
| initializer: A variable initializer function to be used in embedding |
| variable initialization. If not specified, defaults to |
| `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and standard |
| deviation `1/sqrt(dimension)`. |
| combiner: A string specifying how to reduce if there are multiple entries |
| in a single row. Currently 'mean', 'sqrtn', 'sum' and None are |
| supported, with 'mean' the default. 'sqrtn' often achieves good |
| accuracy, in particular with bag-of-words columns. For more information, |
| see `tf.nn.embedding_lookup_sparse`. None is only valid for dense rather |
| than sparse tensors. |
| hot_id_replication: If true, enables hot id replication, which can make |
| embedding lookups faster if there are some hot rows in the table. |
| learning_rate: float, static learning rate for this table. If |
| learning_rate and learning_rate_fn are both `None`, static learning rate |
| as specified in local `optimization_parameters` will be used. In case |
| local `optimization_parameters` is `None`, global |
| `optimization_parameters` in `TPUEmbedding` constructor will be used. |
        `learning_rate_fn` must be `None` if `learning_rate` is not `None`.
      learning_rate_fn: A function for a dynamic learning rate. The function
        will be passed the current global step. If learning_rate and
        learning_rate_fn are both `None`, the static learning rate as
        specified in `optimization_parameters` is used. `learning_rate` must
        be `None` if `learning_rate_fn` is not `None`.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`, etc. Specifies the table-level
        optimizer. If `None`, the global optimizer passed to the
        `TPUEmbedding` constructor is used.
| |
| Returns: |
| `TableConfig`. |
| |
| Raises: |
      ValueError: if `vocabulary_size` is not a positive integer.
      ValueError: if `dimension` is not a positive integer.
| ValueError: if `initializer` is specified and is not callable. |
| ValueError: if `combiner` is not supported. |
| ValueError: if `learning_rate` and `learning_rate_fn` are both not |
| `None`. |
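
    For example, a minimal sketch of a table with a table-level optimizer;
    the vocabulary size, dimension and optimizer settings are illustrative
    only:

    ```
    video_table_config = tpu_embedding.TableConfig(
        vocabulary_size=100000,
        dimension=64,
        combiner='mean',
        optimization_parameters=tpu_embedding.AdagradParameters(
            learning_rate=0.05, initial_accumulator=0.1))
    ```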
| """ |
| if not isinstance(vocabulary_size, int) or vocabulary_size < 1: |
      raise ValueError(f'vocabulary_size must be >= 1. '
                       f'Received: {vocabulary_size}.')
| |
| if not isinstance(dimension, int) or dimension < 1: |
| raise ValueError( |
| f'dimension must be a positive int. Received: {dimension}.') |
| |
| if (initializer is not None) and (not callable(initializer)): |
| raise ValueError(f'initializer must be callable if specified. ' |
| f'Received: {initializer}.') |
| if initializer is None: |
| initializer = init_ops.truncated_normal_initializer( |
| mean=0.0, stddev=1 / math.sqrt(dimension)) |
| |
| if combiner not in ('mean', 'sum', 'sqrtn', None): |
| raise ValueError(f'combiner must be "mean", "sum", "sqrtn" or None. ' |
| f'Received: {combiner}.') |
| |
| if learning_rate is not None and learning_rate_fn is not None: |
      raise ValueError('At most one of learning_rate and learning_rate_fn '
                       'can be set. Received: {} and {}'.format(
                           learning_rate, learning_rate_fn))
| |
| if optimization_parameters is not None: |
| if not isinstance(optimization_parameters, _OptimizationParameters): |
| raise ValueError(f'`optimization_parameters` must inherit from ' |
| f'`_OptimizationParameters`. ' |
| f'Received: `type(optimization_parameters)`=' |
| f'{type(optimization_parameters)}.') |
| |
| return super(TableConfig, |
| cls).__new__(cls, vocabulary_size, dimension, initializer, |
| combiner, hot_id_replication, learning_rate, |
| learning_rate_fn, optimization_parameters) |
| |
| |
| class FeatureConfig( |
| collections.namedtuple('FeatureConfig', |
| ['table_id', 'max_sequence_length', 'weight_key'])): |
| """Feature configuration.""" |
| |
| def __new__(cls, table_id, max_sequence_length=0, weight_key=None): |
| """Feature configuration. |
| |
| Args: |
      table_id: Which table the feature uses for embedding lookups.
| max_sequence_length: If positive, the feature is a sequence feature with |
| the corresponding maximum sequence length. If the sequence is longer |
| than this, it will be truncated. If 0, the feature is not a sequence |
| feature. |
| weight_key: If using weights for the combiner, this key specifies which |
| input feature contains the weights. |
| |
| Returns: |
| `FeatureConfig`. |
| |
| Raises: |
      ValueError: if `max_sequence_length` is not an integer or is negative.
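
    For example, a sketch of two features sharing one table, one of them a
    sequence feature; the names and lengths are illustrative only:

    ```
    feature_to_config_dict = {
        'watched': tpu_embedding.FeatureConfig('video'),
        'watch_history': tpu_embedding.FeatureConfig(
            'video', max_sequence_length=10),
    }
    ```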
| """ |
| if not isinstance(max_sequence_length, int) or max_sequence_length < 0: |
| raise ValueError(f'max_sequence_length must be zero or a positive int, ' |
| f'got {max_sequence_length}.') |
| |
| return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length, |
| weight_key) |
| |
| |
| class EnqueueData( |
| collections.namedtuple( |
| 'EnqueueData', |
| ['embedding_indices', 'sample_indices', 'aggregation_weights'])): |
| """Data to be enqueued through generate_enqueue_ops().""" |
| |
| def __new__(cls, |
| embedding_indices, |
| sample_indices=None, |
| aggregation_weights=None): |
| """Data to be enqueued through generate_enqueue_ops(). |
| |
| Args: |
| embedding_indices: A rank 1 Tensor, indices into the embedding tables. It |
| corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32 |
| and int64 are allowed and will be converted to int32 internally. |
| sample_indices: A rank 2 Tensor specifying the training example to which |
| the corresponding embedding_indices and aggregation_weights values |
| belong. It corresponds to sp_ids.indices in embedding_lookup_sparse(). |
| If it is None, we assume each embedding_indices belongs to a different |
| sample. Both int32 and int64 are allowed and will be converted to int32 |
| internally. |
| aggregation_weights: A rank 1 Tensor containing aggregation weights. It |
| corresponds to sp_weights.values in embedding_lookup_sparse(). If it is |
| None, we assume all weights are 1. Both float32 and float64 are allowed |
| and will be converted to float32 internally. |
| |
| Returns: |
| An EnqueueData tuple. |
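
    For example, a minimal sketch building EnqueueData from a
    `tf.sparse.SparseTensor`; the values are illustrative only:

    ```
    sp_ids = tf.sparse.SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]],
        values=[3, 7, 1],
        dense_shape=[2, 2])
    enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(sp_ids)
    ```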
| |
| """ |
| return super(EnqueueData, cls).__new__(cls, embedding_indices, |
| sample_indices, aggregation_weights) |
| |
| @staticmethod |
| def from_sparse_tensor(sp_tensor, weights=None): |
| return EnqueueData( |
| sp_tensor.values, |
| sp_tensor.indices, |
| aggregation_weights=weights.values if weights is not None else None) |
| |
| |
| class RaggedEnqueueData( |
| collections.namedtuple( |
| 'RaggedEnqueueData', |
| ['embedding_indices', 'sample_splits', 'aggregation_weights'])): |
| """RaggedTensor Data to be enqueued through generate_enqueue_ops().""" |
| |
| def __new__(cls, |
| embedding_indices, |
| sample_splits=None, |
| aggregation_weights=None): |
| """Data to be enqueued through generate_enqueue_ops(). |
| |
| Args: |
| embedding_indices: A rank 1 Tensor, indices into the embedding tables. It |
| corresponds to ids.values in embedding_lookup(), when ids is a |
| RaggedTensor. Both int32 and int64 are allowed and will be converted to |
| int32 internally. |
| sample_splits: A rank 1 Tensor specifying the break points for splitting |
| embedding_indices and aggregation_weights into rows. It corresponds to |
| ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both |
| int32 and int64 are allowed and will be converted to int32 internally. |
| aggregation_weights: A rank 1 Tensor containing per training example |
| aggregation weights. It corresponds to the values field of a |
| RaggedTensor with the same row_splits as ids in embedding_lookup(), when |
| ids is a RaggedTensor. |
| |
| Returns: |
      A RaggedEnqueueData tuple.
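
    For example, a minimal sketch building RaggedEnqueueData from a
    `tf.RaggedTensor`; the values are illustrative only:

    ```
    rg_ids = tf.ragged.constant([[3, 7], [1]])
    enqueue_data = tpu_embedding.RaggedEnqueueData.from_ragged_tensor(rg_ids)
    ```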
| |
| """ |
| return super(RaggedEnqueueData, |
| cls).__new__(cls, embedding_indices, sample_splits, |
| aggregation_weights) |
| |
| @staticmethod |
| def from_ragged_tensor(rg_tensor, weights=None): |
| return RaggedEnqueueData( |
| rg_tensor.values, |
| rg_tensor.row_splits, |
| aggregation_weights=weights.values if weights is not None else None) |
| |
| |
| def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list): |
| """Convenient function for generate_enqueue_ops(). |
| |
| Args: |
    sp_tensors_list: a list of dictionaries mapping feature names (strings)
      to SparseTensor. Each dictionary is for one TPU core. Dictionaries for
      the same host should be contiguous in the list.
| |
| Returns: |
    enqueue_datas_list: a list of dictionaries mapping feature names (strings)
      to EnqueueData. Each dictionary is for one TPU core. Dictionaries for
      the same host should be contiguous in the list.
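
  For example, a sketch with inputs for two TPU cores, assuming `embedding`
  is an already constructed `TPUEmbedding`; the feature names and per-core
  SparseTensors are illustrative only:

  ```
  sp_tensors_list = [
      {'watched': watched_sp_core0, 'friends': friends_sp_core0},
      {'watched': watched_sp_core1, 'friends': friends_sp_core1},
  ]
  enqueue_datas_list = (
      tpu_embedding.get_enqueue_datas_list_from_sparse_tensors_list(
          sp_tensors_list))
  enqueue_ops = embedding.generate_enqueue_ops(enqueue_datas_list)
  ```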
| |
| """ |
| enqueue_datas_list = [] |
| for sp_tensors in sp_tensors_list: |
| enqueue_datas = collections.OrderedDict( |
| (k, EnqueueData.from_sparse_tensor(v)) |
| for k, v in six.iteritems(sp_tensors)) |
| enqueue_datas_list.append(enqueue_datas) |
| return enqueue_datas_list |
| |
| |
| def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list): |
| """Convenient function for generate_enqueue_ops(). |
| |
| Args: |
    rg_tensors_list: a list of dictionaries mapping feature names (strings)
      to RaggedTensor. Each dictionary is for one TPU core. Dictionaries for
      the same host should be contiguous in the list.
| |
| Returns: |
    enqueue_datas_list: a list of dictionaries mapping feature names (strings)
      to RaggedEnqueueData. Each dictionary is for one TPU core. Dictionaries
      for the same host should be contiguous in the list.
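
  For example, mirroring the sparse helper above, assuming `rg_tensors_list`
  holds one dictionary of RaggedTensors per TPU core:

  ```
  enqueue_datas_list = (
      tpu_embedding.get_enqueue_datas_list_from_ragged_tensors_list(
          rg_tensors_list))
  ```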
| |
| """ |
| enqueue_datas_list = [] |
| for rg_tensors in rg_tensors_list: |
| enqueue_datas = collections.OrderedDict( |
| (k, RaggedEnqueueData.from_ragged_tensor(v)) |
| for k, v in six.iteritems(rg_tensors)) |
| enqueue_datas_list.append(enqueue_datas) |
| return enqueue_datas_list |
| |
| |
| AdamSlotVariableNames = collections.namedtuple('AdamSlotVariableNames', |
| ['m', 'v']) |
| |
| AdagradSlotVariableNames = collections.namedtuple('AdagradSlotVariableNames', |
| ['accumulator']) |
| |
| MomentumSlotVariableNames = collections.namedtuple('MomentumSlotVariableNames', |
| ['momenta']) |
| |
| AdagradMomentumSlotVariableNames = collections.namedtuple( |
| 'AdagradMomentumSlotVariableNames', ['accumulator', 'momenta']) |
| |
| RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames', |
| ['ms', 'mom']) |
| |
| ProximalAdagradSlotVariableNames = collections.namedtuple( |
| 'ProximalAdagradSlotVariableNames', ['accumulator']) |
| |
| FtrlSlotVariableNames = collections.namedtuple('FtrlSlotVariableNames', |
| ['accumulator', 'linear']) |
| |
| ProximalYogiSlotVariableNames = collections.namedtuple( |
| 'ProximalYogiSlotVariableNames', ['v', 'm']) |
| |
| FrequencyEstimatorSlotVariableNames = collections.namedtuple( |
| 'FrequencyEstimatorSlotVariableNames', ['last_hit_step']) |
| |
| AdamSlotVariables = collections.namedtuple('AdamSlotVariables', ['m', 'v']) |
| |
| MomentumSlotVariables = collections.namedtuple('MomentumSlotVariables', |
| ['momenta']) |
| |
| AdagradMomentumSlotVariables = collections.namedtuple( |
| 'AdagradMomentumSlotVariables', ['accumulator', 'momenta']) |
| |
| RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables', |
| ['ms', 'mom']) |
| |
| AdagradSlotVariables = collections.namedtuple('AdagradSlotVariables', |
| ['accumulator']) |
| |
| ProximalAdagradSlotVariables = collections.namedtuple( |
| 'ProximalAdagradSlotVariables', ['accumulator']) |
| |
| FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable', |
| ['accumulator', 'linear']) |
| |
| ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables', |
| ['v', 'm']) |
| |
| FrequencyEstimatorSlotVariables = collections.namedtuple( |
| 'FrequencyEstimatorSlotVariables', ['last_hit_step']) |
| |
| VariablesAndOps = collections.namedtuple('VariablesAndOps', [ |
| 'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops', |
| 'retrieve_ops' |
| ]) |
| |
| |
| class _OptimizationParameters(object): |
| """Parameters common to all optimizations.""" |
| |
| def __init__( |
| self, |
| learning_rate: float, |
| use_gradient_accumulation: bool, |
| clip_weight_min: Optional[float], |
| clip_weight_max: Optional[float], |
| weight_decay_factor: Optional[float], |
| multiply_weight_decay_factor_by_learning_rate: Optional[bool], |
| clip_gradient_min: Optional[float] = None, |
| clip_gradient_max: Optional[float] = None, |
| ): |
| self.learning_rate = learning_rate |
| self.use_gradient_accumulation = use_gradient_accumulation |
| self.clip_weight_min = clip_weight_min |
| self.clip_weight_max = clip_weight_max |
| self.weight_decay_factor = weight_decay_factor |
| self.multiply_weight_decay_factor_by_learning_rate = ( |
| multiply_weight_decay_factor_by_learning_rate) |
| self.clip_gradient_min = clip_gradient_min |
| self.clip_gradient_max = clip_gradient_max |
| |
| if not use_gradient_accumulation and (clip_gradient_min is not None or |
| clip_gradient_max is not None): |
| raise ValueError('When using gradient clipping limits, gradient ' |
| 'accumulation must be enabled.') |
| |
| |
| @tf_export(v1=['tpu.experimental.AdagradParameters']) |
| class AdagradParameters(_OptimizationParameters): |
| """Optimization parameters for Adagrad with TPU embeddings. |
| |
| Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the |
| `optimization_parameters` argument to set the optimizer and its parameters. |
| See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` |
| for more details. |
| |
| ``` |
| estimator = tf.estimator.tpu.TPUEstimator( |
| ... |
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
| ... |
| optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1), |
| ...)) |
| ``` |
| |
| """ |
| |
| def __init__( |
| self, |
| learning_rate: float, |
| initial_accumulator: float = 0.1, |
| use_gradient_accumulation: bool = True, |
| clip_weight_min: Optional[float] = None, |
| clip_weight_max: Optional[float] = None, |
| weight_decay_factor: Optional[float] = None, |
| multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, |
| clip_gradient_min: Optional[float] = None, |
| clip_gradient_max: Optional[float] = None, |
| ): |
| """Optimization parameters for Adagrad. |
| |
| Args: |
| learning_rate: used for updating embedding table. |
| initial_accumulator: initial accumulator for Adagrad. |
| use_gradient_accumulation: setting this to `False` makes embedding |
| gradients calculation less accurate but faster. Please see |
| `optimization_parameters.proto` for details. |
| clip_weight_min: the minimum value to clip by; None means -infinity. |
| clip_weight_max: the maximum value to clip by; None means +infinity. |
| weight_decay_factor: amount of weight decay to apply; None means that the |
| weights are not decayed. |
| multiply_weight_decay_factor_by_learning_rate: if true, |
| `weight_decay_factor` is multiplied by the current learning rate. |
| clip_gradient_min: the minimum value to clip by; None means -infinity. |
| Gradient accumulation must be set to true if this is set. |
| clip_gradient_max: the maximum value to clip by; None means +infinity. |
| Gradient accumulation must be set to true if this is set. |
| """ |
| super(AdagradParameters, self).__init__( |
| learning_rate=learning_rate, |
| use_gradient_accumulation=use_gradient_accumulation, |
| clip_weight_min=clip_weight_min, |
| clip_weight_max=clip_weight_max, |
| weight_decay_factor=weight_decay_factor, |
| multiply_weight_decay_factor_by_learning_rate=( |
| multiply_weight_decay_factor_by_learning_rate), |
| clip_gradient_min=clip_gradient_min, |
| clip_gradient_max=clip_gradient_max, |
| ) |
| if initial_accumulator <= 0: |
| raise ValueError( |
| f'Adagrad initial_accumulator must be greater than zero. ' |
| f'Received: {initial_accumulator}.') |
| self.initial_accumulator = initial_accumulator |
| |
| |
| class AdagradMomentumParameters(_OptimizationParameters): |
| """Optimization parameters for Adagrad + Momentum with TPU embeddings. |
| |
| Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the |
| `optimization_parameters` argument to set the optimizer and its parameters. |
| See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` |
| for more details. |
| |
| ``` |
| estimator = tf.estimator.tpu.TPUEstimator( |
| ... |
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
| ... |
| optimization_parameters=tf.tpu.experimental.AdagradMomentumParameters(0.1), |
| ...)) |
| ``` |
| |
| """ |
| |
| def __init__( |
| self, |
| learning_rate: float, |
| momentum: float, |
| use_nesterov: bool = False, |
| exponent: float = 2, |
| beta2: float = 1, |
| epsilon: float = 1e-10, |
| use_gradient_accumulation: bool = True, |
| clip_weight_min: Optional[float] = None, |
| clip_weight_max: Optional[float] = None, |
| weight_decay_factor: Optional[float] = None, |
| multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, |
| clip_gradient_min: Optional[float] = None, |
| clip_gradient_max: Optional[float] = None, |
| ): |
| """Optimization parameters for Adagrad. |
| |
| Args: |
| learning_rate: used for updating embedding table. |
| momentum: Moving average parameter for the momentum accumulator. |
| use_nesterov: Whether to use the Nesterov variant of momentum. See |
| Sutskever et al., 2013. |
| exponent: Exponent for the Adagrad accumulator. |
| beta2: Moving average parameter for the Adagrad accumulator. |
      epsilon: Initial value of the Adagrad accumulator.
| use_gradient_accumulation: setting this to `False` makes embedding |
| gradients calculation less accurate but faster. Please see |
| `optimization_parameters.proto` for details. |
| clip_weight_min: the minimum value to clip by; None means -infinity. |
| clip_weight_max: the maximum value to clip by; None means +infinity. |
| weight_decay_factor: amount of weight decay to apply; None means that the |
| weights are not decayed. |
| multiply_weight_decay_factor_by_learning_rate: if true, |
| `weight_decay_factor` is multiplied by the current learning rate. |
| clip_gradient_min: the minimum value to clip by; None means -infinity. |
| Gradient accumulation must be set to true if this is set. |
| clip_gradient_max: the maximum value to clip by; None means +infinity. |
| Gradient accumulation must be set to true if this is set. |
| """ |
| super(AdagradMomentumParameters, self).__init__( |
| learning_rate=learning_rate, |
| use_gradient_accumulation=use_gradient_accumulation, |
| clip_weight_min=clip_weight_min, |
| clip_weight_max=clip_weight_max, |
| weight_decay_factor=weight_decay_factor, |
| multiply_weight_decay_factor_by_learning_rate=( |
| multiply_weight_decay_factor_by_learning_rate), |
| clip_gradient_min=clip_gradient_min, |
| clip_gradient_max=clip_gradient_max, |
| ) |
| if epsilon <= 0: |
| raise ValueError('Adagrad momentum: epsilon must be positive') |
| if exponent <= 0: |
      raise ValueError(
          'Adagrad momentum: precondition exponent must be positive')
| self.momentum = momentum |
| self.use_nesterov = use_nesterov |
| self.exponent = exponent |
| self.beta2 = beta2 |
| self.epsilon = epsilon |
| |
| |
| class ProximalAdagradParameters(_OptimizationParameters): |
| """Optimization parameters for ProximalAdagrad with TPU embeddings. |
| |
| Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the |
| `optimization_parameters` argument to set the optimizer and its parameters. |
| See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` |
| for more details. |
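
  A sketch mirroring the examples on the other optimizer classes; here the
  class is referenced through this module rather than a `tf.*` export, and
  the parameter values are illustrative only:

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tpu_embedding.ProximalAdagradParameters(
              learning_rate=0.1, l1_regularization_strength=0.001),
          ...))
  ```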
| """ |
| |
| def __init__( |
| self, |
| learning_rate: float, |
| initial_accumulator: float = 0.1, |
| l1_regularization_strength: float = 0.0, |
| l2_regularization_strength: float = 0.0, |
| use_gradient_accumulation: bool = True, |
| clip_weight_min: Optional[float] = None, |
| clip_weight_max: Optional[float] = None, |
| weight_decay_factor: Optional[float] = None, |
| multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, |
| clip_gradient_min: Optional[float] = None, |
| clip_gradient_max: Optional[float] = None, |
| ): |
| """Optimization parameters for Adagrad. |
| |
| Args: |
| learning_rate: used for updating embedding table. |
| initial_accumulator: initial accumulator for Adagrad. |
| l1_regularization_strength: A float value, must be greater than or equal |
| to zero. |
| l2_regularization_strength: A float value, must be greater than or equal |
| to zero. |
| use_gradient_accumulation: setting this to `False` makes embedding |
| gradients calculation less accurate but faster. Please see |
        `optimization_parameters.proto` for details.
| clip_weight_min: the minimum value to clip by; None means -infinity. |
| clip_weight_max: the maximum value to clip by; None means +infinity. |
| weight_decay_factor: amount of weight decay to apply; None means that the |
| weights are not decayed. |
| multiply_weight_decay_factor_by_learning_rate: if true, |
| `weight_decay_factor` is multiplied by the current learning rate. |
| clip_gradient_min: the minimum value to clip by; None means -infinity. |
| Gradient accumulation must be set to true if this is set. |
| clip_gradient_max: the maximum value to clip by; None means +infinity. |
| Gradient accumulation must be set to true if this is set. |
| """ |
| super(ProximalAdagradParameters, self).__init__( |
| learning_rate=learning_rate, |
| use_gradient_accumulation=use_gradient_accumulation, |
| clip_weight_min=clip_weight_min, |
| clip_weight_max=clip_weight_max, |
| weight_decay_factor=weight_decay_factor, |
| multiply_weight_decay_factor_by_learning_rate=( |
| multiply_weight_decay_factor_by_learning_rate), |
| clip_gradient_min=clip_gradient_min, |
| clip_gradient_max=clip_gradient_max, |
| ) |
| if initial_accumulator <= 0: |
      raise ValueError(f'ProximalAdagrad initial_accumulator must be '
                       f'positive. Received: {initial_accumulator}.')
| if l1_regularization_strength < 0.: |
| raise ValueError('l1_regularization_strength must be greater than or ' |
| 'equal to 0. got {}.'.format(l1_regularization_strength)) |
| |
| if l2_regularization_strength < 0.: |
| raise ValueError('l2_regularization_strength must be greater than or ' |
| 'equal to 0. got {}.'.format(l2_regularization_strength)) |
| |
| self.initial_accumulator = initial_accumulator |
| self.l1_regularization_strength = l1_regularization_strength |
| self.l2_regularization_strength = l2_regularization_strength |
| |
| |
| @tf_export(v1=['tpu.experimental.AdamParameters']) |
| class AdamParameters(_OptimizationParameters): |
| """Optimization parameters for Adam with TPU embeddings. |
| |
| Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the |
| `optimization_parameters` argument to set the optimizer and its parameters. |
| See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` |
| for more details. |
| |
| ``` |
| estimator = tf.estimator.tpu.TPUEstimator( |
| ... |
| embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( |
| ... |
| optimization_parameters=tf.tpu.experimental.AdamParameters(0.1), |
| ...)) |
| ``` |
| |
| """ |
| |
| def __init__( |
| self, |
| learning_rate: float, |
| beta1: float = 0.9, |
| beta2: float = 0.999, |
| epsilon: float = 1e-08, |
| lazy_adam: bool = True, |
| sum_inside_sqrt: bool = True, |
| use_gradient_accumulation: bool = True, |
| clip_weight_min: Optional[float] = None, |
| clip_weight_max: Optional[float] = None, |
| weight_decay_factor: Optional[float] = None, |
| multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, |
| clip_gradient_min: Optional[float] = None, |
| clip_gradient_max: Optional[float] = None, |
| ): |
| """Optimization parameters for Adam. |
| |
| Args: |
| learning_rate: a floating point value. The learning rate. |
| beta1: A float value. The exponential decay rate for the 1st moment |
| estimates. |
| beta2: A float value. The exponential decay rate for the 2nd moment |
| estimates. |
| epsilon: A small constant for numerical stability. |
| lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See |
| `optimization_parameters.proto` for details. |
| sum_inside_sqrt: This improves training speed. Please see |
| `optimization_parameters.proto` for details. |
| use_gradient_accumulation: setting this to `False` makes embedding |
| gradients calculation less accurate but faster. Please see |
| `optimization_parameters.proto` for details. |
| clip_weight_min: the minimum value to clip by; None means -infinity. |
| clip_weight_max: the maximum value to clip by; None means +infinity. |
| weight_decay_factor: amount of weight decay to apply; None means that the |
| weights are not decayed. |
| multiply_weight_decay_factor_by_learning_rate: if true, |
| `weight_decay_factor` is multiplied by the current learning rate. |
| clip_gradient_min: the minimum value to clip by; None means -infinity. |
| Gradient accumulation must be set to true if this is set. |
| clip_gradient_max: the maximum value to clip by; None means +infinity. |
| Gradient accumulation must be set to true if this is set. |
| """ |
| super(AdamParameters, self).__init__( |
| learning_rate=learning_rate, |
| use_gradient_accumulation=use_gradient_accumulation, |
| clip_weight_min=clip_weight_min, |
| clip_weight_max=clip_weight_max, |
| weight_decay_factor=weight_decay_factor, |
| multiply_weight_decay_factor_by_learning_rate=( |
| multiply_weight_decay_factor_by_learning_rate), |
| clip_gradient_min=clip_gradient_min, |
| clip_gradient_max=clip_gradient_max, |
| ) |
| if beta1 < 0. or beta1 >= 1.: |
| raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1)) |
| if beta2 < 0. or beta2 >= 1.: |
| raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2)) |
| if epsilon <= 0.: |
| raise ValueError('epsilon must be positive; got {}.'.format(epsilon)) |
| if not use_gradient_accumulation and not lazy_adam: |
| raise ValueError( |
| 'When disabling Lazy Adam, gradient accumulation must be used.') |
| |
| self.beta1 = beta1 |
| self.beta2 = beta2 |
| self.epsilon = epsilon |
| self.lazy_adam = lazy_adam |
| self.sum_inside_sqrt = sum_inside_sqrt |
| |
| |
| @tf_export(v1=['tpu.experimental.FtrlParameters']) |
| class FtrlParameters(_OptimizationParameters): |
| """Optimization parameters for Ftrl with TPU embeddings. |
| |
| Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the |
| `optimization_parameters` argument to set the optimizer and its parameters. |
| See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` |
| for more details. |
| |
| ``` |
| estimator = tf.estimator.tpu.TPUEstimator( |
| ... |
| embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( |
| ... |
| optimization_parameters=tf.tpu.experimental.FtrlParameters(0.1), |
| ...)) |
| ``` |
| |
| """ |
| |
| def __init__( |
| self, |
| learning_rate: float, |
| learning_rate_power: float = -0.5, |
| initial_accumulator_value: float = 0.1, |
| l1_regularization_strength: float = 0.0, |
| l2_regularization_strength: float = 0.0, |
| use_gradient_accumulation: bool = True, |
| clip_weight_min: Optional[float] = None, |
| clip_weight_max: Optional[float] = None, |
| weight_decay_factor: Optional[float] = None, |
| multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, |
| multiply_linear_by_learning_rate: bool = False, |
| beta: float = 0, |
| allow_zero_accumulator: bool = False, |
| clip_gradient_min: Optional[float] = None, |
| clip_gradient_max: Optional[float] = None, |
| ): |
| """Optimization parameters for Ftrl. |
| |
| Implements FTRL as described in the following [paper]( |
| https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf) |
| |
| Args: |
| learning_rate: a floating point value. The learning rate. |
      learning_rate_power: A float value, must be less than or equal to zero.
| Controls how the learning rate decreases during training. Use zero for a |
| fixed learning rate. See section 3.1 in the |
| [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf). |
| initial_accumulator_value: The starting value for accumulators. Only zero |
| or positive values are allowed. |
| l1_regularization_strength: A float value, must be greater than or equal |
| to zero. |
| l2_regularization_strength: A float value, must be greater than or equal |
| to zero. |
| use_gradient_accumulation: setting this to `False` makes embedding |
| gradients calculation less accurate but faster. Please see |
        `optimization_parameters.proto` for details.
| clip_weight_min: the minimum value to clip by; None means -infinity. |
| clip_weight_max: the maximum value to clip by; None means +infinity. |
| weight_decay_factor: amount of weight decay to apply; None means that the |
| weights are not decayed. |
| multiply_weight_decay_factor_by_learning_rate: if true, |
| `weight_decay_factor` is multiplied by the current learning rate. |
| multiply_linear_by_learning_rate: When true, multiplies the usages of the |
| linear slot in the weight update by the learning rate. This is useful |
| when ramping up learning rate from 0 (which would normally produce |
| NaNs). |
| beta: The beta parameter for FTRL. |
| allow_zero_accumulator: Changes the implementation of the square root to |
| allow for the case of initial_accumulator_value being zero. This will |
| cause a slight performance drop. |
| clip_gradient_min: the minimum value to clip by; None means -infinity. |
| Gradient accumulation must be set to true if this is set. |
| clip_gradient_max: the maximum value to clip by; None means +infinity. |
| Gradient accumulation must be set to true if this is set. |
| """ |
| super(FtrlParameters, self).__init__( |
| learning_rate=learning_rate, |
| use_gradient_accumulation=use_gradient_accumulation, |
| clip_weight_min=clip_weight_min, |
| clip_weight_max=clip_weight_max, |
| weight_decay_factor=weight_decay_factor, |
| multiply_weight_decay_factor_by_learning_rate=( |
| multiply_weight_decay_factor_by_learning_rate), |
| clip_gradient_min=clip_gradient_min, |
| clip_gradient_max=clip_gradient_max, |
| ) |
| if learning_rate_power > 0.: |
| raise ValueError('learning_rate_power must be less than or equal to 0. ' |
| 'got {}.'.format(learning_rate_power)) |
| |
| if initial_accumulator_value < 0.: |
| raise ValueError('initial_accumulator_value must be greater than or equal' |
| ' to 0. got {}.'.format(initial_accumulator_value)) |
| |
| if l1_regularization_strength < 0.: |
| raise ValueError('l1_regularization_strength must be greater than or ' |
| 'equal to 0. got {}.'.format(l1_regularization_strength)) |
| |
| if l2_regularization_strength < 0.: |
| raise ValueError('l2_regularization_strength must be greater than or ' |
| 'equal to 0. got {}.'.format(l2_regularization_strength)) |
| |
| self.learning_rate_power = learning_rate_power |
| self.initial_accumulator_value = initial_accumulator_value |
| self.initial_linear_value = 0.0 |
| self.l1_regularization_strength = l1_regularization_strength |
| self.l2_regularization_strength = l2_regularization_strength |
| self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate |
| self.beta = beta |
| self.allow_zero_accumulator = allow_zero_accumulator |
| |
| |
| class ProximalYogiParameters(_OptimizationParameters): |
| # pylint: disable=line-too-long |
| """Optimization parameters for Proximal Yogi with TPU embeddings. |
| |
| Implements the Yogi optimizer as described in |
| [Adaptive Methods for Nonconvex |
| Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization). |
| |
| Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the |
| `optimization_parameters` argument to set the optimizer and its parameters. |
| See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` |
| for more details. |
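
  A sketch mirroring the examples on the other optimizer classes; the class
  is referenced through this module rather than a `tf.*` export, and the
  learning rate shown is illustrative only:

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tpu_embedding.ProximalYogiParameters(
              learning_rate=0.01),
          ...))
  ```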
| """ |
| |
| # pylint: enable=line-too-long |
| |
| def __init__( |
| self, |
| learning_rate: float = 0.01, |
| beta1: float = 0.9, |
| beta2: float = 0.999, |
| epsilon: float = 1e-3, |
| l1_regularization_strength: float = 0.0, |
| l2_regularization_strength: float = 0.0, |
| initial_accumulator_value: float = 1e-6, |
| use_gradient_accumulation: bool = True, |
| clip_weight_min: Optional[float] = None, |
| clip_weight_max: Optional[float] = None, |
| weight_decay_factor: Optional[float] = None, |
| multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, |
| clip_gradient_min: Optional[float] = None, |
| clip_gradient_max: Optional[float] = None, |
| ): |
| """Optimization parameters for Proximal Yogi. |
| |
| Args: |
| learning_rate: a floating point value. The learning rate. |
| beta1: A float value. The exponential decay rate for the 1st moment |
| estimates. |
| beta2: A float value. The exponential decay rate for the 2nd moment |
| estimates. |
| epsilon: A small constant for numerical stability. |
| l1_regularization_strength: A float value, must be greater than or equal |
| to zero. |
| l2_regularization_strength: A float value, must be greater than or equal |
| to zero. |
| initial_accumulator_value: The starting value for accumulators. Only zero |
| or positive values are allowed. |
| use_gradient_accumulation: setting this to `False` makes embedding |
| gradients calculation less accurate but faster. Please see |
        `optimization_parameters.proto` for details.
| clip_weight_min: the minimum value to clip by; None means -infinity. |
| clip_weight_max: the maximum value to clip by; None means +infinity. |
| weight_decay_factor: amount of weight decay to apply; None means that the |
| weights are not decayed. |
| multiply_weight_decay_factor_by_learning_rate: if true, |
| `weight_decay_factor` is multiplied by the current learning rate. |
| clip_gradient_min: the minimum value to clip by; None means -infinity. |
| Gradient accumulation must be set to true if this is set. |
| clip_gradient_max: the maximum value to clip by; None means +infinity. |
| Gradient accumulation must be set to true if this is set. |
| """ |
| super(ProximalYogiParameters, self).__init__( |
| learning_rate=learning_rate, |
| use_gradient_accumulation=use_gradient_accumulation, |
| clip_weight_min=clip_weight_min, |
| clip_weight_max=clip_weight_max, |
| weight_decay_factor=weight_decay_factor, |
| multiply_weight_decay_factor_by_learning_rate=( |
| multiply_weight_decay_factor_by_learning_rate), |
| clip_gradient_min=clip_gradient_min, |
| clip_gradient_max=clip_gradient_max, |
| ) |
| if beta1 < 0. or beta1 >= 1.: |
| raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1)) |
| if beta2 < 0. or beta2 >= 1.: |
| raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2)) |
| if epsilon <= 0.: |
| raise ValueError('epsilon must be positive; got {}.'.format(epsilon)) |
| if l1_regularization_strength < 0.: |
| raise ValueError('l1_regularization_strength must be greater than or ' |
| 'equal to 0. got {}.'.format(l1_regularization_strength)) |
| if l2_regularization_strength < 0.: |
| raise ValueError('l2_regularization_strength must be greater than or ' |
| 'equal to 0. got {}.'.format(l2_regularization_strength)) |
| |
| self.beta1 = beta1 |
| self.beta2 = beta2 |
| self.epsilon = epsilon |
| self.l1_regularization_strength = l1_regularization_strength |
| self.l2_regularization_strength = l2_regularization_strength |
| self.initial_accumulator_value = initial_accumulator_value |
| |
| |
| class MomentumParameters(_OptimizationParameters): |
| """Optimization parameters for Momentum with TPU embeddings. |
| |
| Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the |
| `optimization_parameters` argument to set the optimizer and its parameters. |
| See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` |
| for more details. |
| |
| ``` |
| estimator = tf.estimator.tpu.TPUEstimator( |
| ... |
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
| ... |
| optimization_parameters=tf.tpu.experimental.MomentumParameters(0.1), |
| ...)) |
| ``` |
| |
| """ |
| |
| def __init__( |
| self, |
| learning_rate: float, |
| momentum: float, |
| use_nesterov: bool = False, |
| use_gradient_accumulation: bool = True, |
| clip_weight_min: Optional[float] = None, |
| clip_weight_max: Optional[float] = None, |
| weight_decay_factor: Optional[float] = None, |
| multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, |
| clip_gradient_min: Optional[float] = None, |
| clip_gradient_max: Optional[float] = None, |
| ): |
| """Optimization parameters for momentum. |
| |
| Args: |
| learning_rate: a floating point value. The learning rate. |
| momentum: a floating point value. The momentum. |
| use_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al., |
| 2013). This implementation always computes gradients at the value of the |
| variable(s) passed to the optimizer. Using Nesterov Momentum makes the |
| variable(s) track the values called `theta_t + mu*v_t` in the paper. |
| This implementation is an approximation of the original formula, valid |
| for high values of momentum. It will compute the "adjusted gradient" in |
| NAG by assuming that the new gradient will be estimated by the current |
| average gradient plus the product of momentum and the change in the |
| average gradient. |
| use_gradient_accumulation: setting this to `False` makes embedding |
| gradients calculation less accurate but faster. Please see |
| `optimization_parameters.proto` for details. |
| clip_weight_min: the minimum value to clip by; None means -infinity. |
| clip_weight_max: the maximum value to clip by; None means +infinity. |
| weight_decay_factor: amount of weight decay to apply; None means that the |
| weights are not decayed. |
| multiply_weight_decay_factor_by_learning_rate: if true, |
| `weight_decay_factor` is multiplied by the current learning rate. |
| clip_gradient_min: the minimum value to clip by; None means -infinity. |
| Gradient accumulation must be set to true if this is set. |
| clip_gradient_max: the maximum value to clip by; None means +infinity. |
| Gradient accumulation must be set to true if this is set. |
| """ |
| super(MomentumParameters, self).__init__( |
| learning_rate=learning_rate, |
| use_gradient_accumulation=use_gradient_accumulation, |
| clip_weight_min=clip_weight_min, |
| clip_weight_max=clip_weight_max, |
| weight_decay_factor=weight_decay_factor, |
| multiply_weight_decay_factor_by_learning_rate=( |
| multiply_weight_decay_factor_by_learning_rate), |
| clip_gradient_min=clip_gradient_min, |
| clip_gradient_max=clip_gradient_max, |
| ) |
| self.momentum = momentum |
| self.use_nesterov = use_nesterov |
| |
| |
| class RMSPropParameters(_OptimizationParameters): |
| """Optimization parameters for RMSProp with TPU embeddings. |
| |
| Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the |
| `optimization_parameters` argument to set the optimizer and its parameters. |
| See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` |
| for more details. |
| |
| ``` |
| estimator = tf.estimator.tpu.TPUEstimator( |
| ... |
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tpu_embedding.RMSPropParameters(
              learning_rate=0.1, rho=0.9, momentum=0.9, epsilon=1e-10),
| ...)) |
| ``` |
| |
| """ |
| |
| def __init__( |
| self, |
| learning_rate: float, |
| rho: float, |
| momentum: float, |
| epsilon: float, |
| use_gradient_accumulation: bool = True, |
| clip_weight_min: Optional[float] = None, |
| clip_weight_max: Optional[float] = None, |
| weight_decay_factor: Optional[float] = None, |
| multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, |
| clip_gradient_min: Optional[float] = None, |
| clip_gradient_max: Optional[float] = None, |
| ): |
| """Optimization parameters for RMS prop. |
| |
| Args: |
| learning_rate: a floating point value. The learning rate. |
      rho: Discounting factor for the history/coming gradient.
| momentum: A scalar tensor. |
| epsilon: Small value to avoid zero denominator. |
| use_gradient_accumulation: setting this to `False` makes embedding |
| gradients calculation less accurate but faster. Please see |
        `optimization_parameters.proto` for details.
| clip_weight_min: the minimum value to clip by; None means -infinity. |
| clip_weight_max: the maximum value to clip by; None means +infinity. |
| weight_decay_factor: amount of weight decay to apply; None means that the |
| weights are not decayed. |
| multiply_weight_decay_factor_by_learning_rate: if true, |
| `weight_decay_factor` is multiplied by the current learning rate. |
| clip_gradient_min: the minimum value to clip by; None means -infinity. |
| Gradient accumulation must be set to true if this is set. |
| clip_gradient_max: the maximum value to clip by; None means +infinity. |
| Gradient accumulation must be set to true if this is set. |
| """ |
| super(RMSPropParameters, self).__init__( |
| learning_rate=learning_rate, |
| use_gradient_accumulation=use_gradient_accumulation, |
| clip_weight_min=clip_weight_min, |
| clip_weight_max=clip_weight_max, |
| weight_decay_factor=weight_decay_factor, |
| multiply_weight_decay_factor_by_learning_rate=( |
| multiply_weight_decay_factor_by_learning_rate), |
| clip_gradient_min=clip_gradient_min, |
| clip_gradient_max=clip_gradient_max, |
| ) |
| self.rho = rho |
| self.momentum = momentum |
| self.epsilon = epsilon |
| |
| |
| @tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters']) |
| class StochasticGradientDescentParameters(_OptimizationParameters): |
| """Optimization parameters for stochastic gradient descent for TPU embeddings. |
| |
| Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the |
| `optimization_parameters` argument to set the optimizer and its parameters. |
| See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` |
| for more details. |
| |
| ``` |
| estimator = tf.estimator.tpu.TPUEstimator( |
| ... |
| embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( |
| ... |
| optimization_parameters=( |
| tf.tpu.experimental.StochasticGradientDescentParameters(0.1)))) |
| ``` |
| |
| """ |
| |
| def __init__( |
| self, |
| learning_rate: float, |
| clip_weight_min: Optional[float] = None, |
| clip_weight_max: Optional[float] = None, |
| weight_decay_factor: Optional[float] = None, |
| multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None, |
| clip_gradient_min: Optional[float] = None, |
| clip_gradient_max: Optional[float] = None, |
| ): |
| """Optimization parameters for stochastic gradient descent. |
| |
| Args: |
| learning_rate: a floating point value. The learning rate. |
| clip_weight_min: the minimum value to clip by; None means -infinity. |
| clip_weight_max: the maximum value to clip by; None means +infinity. |
| weight_decay_factor: amount of weight decay to apply; None means that the |
| weights are not decayed. |
| multiply_weight_decay_factor_by_learning_rate: if true, |
| `weight_decay_factor` is multiplied by the current learning rate. |
| clip_gradient_min: the minimum value to clip by; None means -infinity. |
| clip_gradient_max: the maximum value to clip by; None means +infinity. |
| """ |
| # Gradient accumulation is generally a no-op for SGD, but if gradient |
| # clipping is enabled, then we must also enable gradient accumulation. |
    # In the other optimizers this is up to the user, but we don't give the
    # user the option to turn gradient accumulation on or off for SGD.
| use_gradient_accumulation = False |
| if (clip_gradient_min is not None or clip_gradient_max is not None): |
| use_gradient_accumulation = True |
| super(StochasticGradientDescentParameters, self).__init__( |
| learning_rate=learning_rate, |
| use_gradient_accumulation=use_gradient_accumulation, |
| clip_weight_min=clip_weight_min, |
| clip_weight_max=clip_weight_max, |
| weight_decay_factor=weight_decay_factor, |
| multiply_weight_decay_factor_by_learning_rate=( |
| multiply_weight_decay_factor_by_learning_rate), |
| clip_gradient_min=clip_gradient_min, |
| clip_gradient_max=clip_gradient_max, |
| ) |
| |
| |
| class FrequencyEstimatorParameters(_OptimizationParameters): |
| """Optimization parameters for Frequency Estimator TPU embeddings. |
| |
| This is a non-standard optimizer, which returns the estimated frequency of |
| lookup for the feature passed to it. It should only be used on a table of |
| width 1. The gradient fed back to the TPU embedding should always be zero. |
  This can be accomplished by applying `tf.stop_gradient` to the feature
  before using it.
| |
  You must use the dynamic learning rate mechanism to set the 'learning rate'
  for this table to a float32 cast of the global training step counter.
| |
| See `tensorflow/core/protobuf/tpu/optimization_parameters.proto` for more |
| details on this optimizer. |
| |
| Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the |
| `optimization_parameters` argument to set the optimizer and its parameters. |
| See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec` |
| for more details. |
| |
| ``` |
| estimator = tf.estimator.tpu.TPUEstimator( |
| ... |
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
| ... |
          optimization_parameters=FrequencyEstimatorParameters(
              tau=0.1, max_delta=100.0, outlier_threshold=10.0,
              weight_exponent=1.0),
| ...)) |
| ``` |
| |
| """ |
| |
| def __init__(self, tau: float, max_delta: float, outlier_threshold: float, |
| weight_exponent: float): |
| """Optimization parameters for frequency estimator. |
| |
| Args: |
| tau: Learning rate between (0, 1) that is used to update the array. |
| max_delta: Maximum value of delta, the difference between the current |
| global step and the last global step at which the row was sampled. |
| outlier_threshold: Threshold used to determine whether the current update |
| is an outlier. |
| weight_exponent: The weight exponent used to transform the estimated delta |
| into weights. |
| """ |
| super(FrequencyEstimatorParameters, self).__init__( |
| learning_rate=1.0, |
| use_gradient_accumulation=True, |
| clip_weight_min=None, |
| clip_weight_max=None, |
| weight_decay_factor=None, |
| multiply_weight_decay_factor_by_learning_rate=None, |
| ) |
| self.tau = tau |
| self.max_delta = max_delta |
| self.outlier_threshold = outlier_threshold |
| self.weight_exponent = weight_exponent |
| |
| |
| DeviceConfig = collections.namedtuple('DeviceConfig', |
| ['num_hosts', 'num_cores', 'job_name']) |
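# Example (illustrative values): when `master` and `cluster_def` are both
# None, a DeviceConfig such as DeviceConfig(num_hosts=1, num_cores=8,
# job_name='worker') can be passed to `TPUEmbedding` via `device_config`.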
| |
| |
| class TPUEmbedding(object): |
| """API for using TPU for embedding. |
| |
| Example: |
| ``` |
  table_config_video = tpu_embedding.TableConfig(
      vocabulary_size=8, dimension=4,
      initializer=initializer, combiner='mean')
  table_config_user = tpu_embedding.TableConfig(
      vocabulary_size=4, dimension=2,
      initializer=initializer, combiner='mean')
| table_to_config_dict = {'video': table_config_video, |
| 'user': table_config_user} |
| feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'), |
| 'favorited': tpu_embedding.FeatureConfig('video'), |
| 'friends': tpu_embedding.FeatureConfig('user')} |
| batch_size = 4 |
| num_hosts = 1 |
| optimization_parameters = tpu_embedding.AdagradParameters(1., 1.) |
| mode = tpu_embedding.TRAINING |
| embedding = tpu_embedding.TPUEmbedding( |
| table_to_config_dict, feature_to_config_dict, |
| batch_size, num_hosts, mode, optimization_parameters) |
| |
| batch_size_per_core = embedding.batch_size_per_core |
| sparse_features_list = [] |
| for host in hosts: |
| with ops.device(host): |
| for _ in range(embedding.num_cores_per_host): |
| sparse_features = {} |
| sparse_features['watched'] = sparse_tensor.SparseTensor(...) |
| sparse_features['favorited'] = sparse_tensor.SparseTensor(...) |
| sparse_features['friends'] = sparse_tensor.SparseTensor(...) |
| sparse_features_list.append(sparse_features) |
| |
| enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list) |
| embedding_variables_and_ops = embedding.create_variables_and_ops() |
| |
| def computation(): |
| activations = embedding.get_activations() |
| loss = compute_loss(activations) |
| |
| base_optimizer = gradient_descent.GradientDescentOptimizer( |
| learning_rate=1) |
| cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer( |
| base_optimizer) |
| |
| train_op = cross_shard_optimizer.minimize(loss) |
    gradients = (
        tpu_embedding_gradient.get_gradients_through_compute_gradients(
            cross_shard_optimizer, loss, activations))
| send_gradients_op = embedding.generate_send_gradients_op(gradients) |
| with ops.control_dependencies([train_op, send_gradients_op]): |
| loss = array_ops.identity(loss) |
| |
| loss = tpu.shard(computation, |
| num_shards=embedding.num_cores) |
| |
  with tf.compat.v1.Session() as sess:
| sess.run(tpu.initialize_system(embedding_config= |
| embedding.config_proto)) |
| sess.run(variables.global_variables_initializer()) |
| sess.run(embedding_variables_and_ops.load_ops()) |
| sess.run(enqueue_ops) |
| loss_val = sess.run(loss) |
| ``` |
| |
| Example with weight decay: |
| |
| >>> def learning_rate_fn(global_step): |
| ... return tf.compat.v1.train.polynomial_decay( |
| ... learning_rate=5e-5, |
| ... global_step=global_step, |
| ... decay_steps=100000, |
| ... end_learning_rate=0.0) |
| >>> wordpiece_table_config = TableConfig( |
| ... vocabulary_size=119547, |
| ... dimension=256, |
| ... learning_rate_fn=learning_rate_fn) |
| >>> wordpiece_feature_config = FeatureConfig( |
| ... table_id='bert/embeddings/word_embeddings', |
| ... max_sequence_length=512) |
| >>> optimization_parameters = AdamParameters( |
| ... learning_rate=5e-5, |
| ... epsilon=1e-6, |
| ... weight_decay_factor=0.01, |
| ... multiply_weight_decay_factor_by_learning_rate=True) |
| >>> tpu_embedding = TPUEmbedding( |
| ... table_to_config_dict={ |
| ... 'bert/embeddings/word_embeddings': wordpiece_table_config, |
| ... }, |
| ... feature_to_config_dict={'input_ids': wordpiece_feature_config}, |
| ... batch_size=128, |
| ... mode=TRAINING, |
| ... optimization_parameters=optimization_parameters, |
| ... master='') |
| >>> with tf.Graph().as_default(): |
| ... init_tpu_op = tf.compat.v1.tpu.initialize_system( |
| ... embedding_config=tpu_embedding.config_proto) |
| ... tf.compat.v1.Session().run(init_tpu_op) |
| """ |
| |
| # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that |
| # the feature should not be used to update embedding table (cr/204852758, |
| # cr/204940540). Also, this can support different combiners for different |
| # features within the same table. |
| # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it |
| # to `FeatureConfig`? |
| |
| # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and |
| # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec` |
| # respectively? |
| |
| # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate |
| # for-loops around construction of inputs. |
| |
| # `optimization_parameter` applies to all tables. If the need arises, |
| # we can add `optimization_parameters` to `TableConfig` to override this |
| # global setting. |
| def __init__(self, |
| table_to_config_dict, |
| feature_to_config_dict, |
| batch_size, |
| mode, |
| master=None, |
| optimization_parameters=None, |
| cluster_def=None, |
| pipeline_execution_with_tensor_core=False, |
| partition_strategy='div', |
| profile_data_directory=None, |
| device_config=None, |
| master_job_name=None): |
| """API for using TPU for embedding lookups. |
| |
| Args: |
| table_to_config_dict: A dictionary mapping from string of table name to |
| `TableConfig`. Table refers to an embedding table, e.g. `params` |
| argument to `tf.nn.embedding_lookup_sparse()`. |
| feature_to_config_dict: A dictionary mapping from string of feature name |
| to `FeatureConfig`. Feature refers to ids to lookup in embedding table, |
| e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`. |
| batch_size: An `int` representing the global batch size. |
| mode: `TRAINING` or `INFERENCE`. |
| master: A `string` representing the TensorFlow master to use. |
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`, etc. Must be set in training
        unless all tables specify their own optimizers, and must be `None`
        for inference.
| cluster_def: A ClusterDef object describing the TPU cluster. |
| pipeline_execution_with_tensor_core: setting this to `True` makes training |
        faster, but the trained model will be different if step N and step N+1
| involve the same set of embedding IDs. Please see |
| `tpu_embedding_configuration.proto` for details. |
| partition_strategy: A string, either 'mod' or 'div', specifying how to map |
| the lookup id to the embedding tensor. For more information see |
| `tf.nn.embedding_lookup_sparse`. |
| profile_data_directory: Directory where embedding lookup statistics are |
| stored. These statistics summarize information about the inputs to the |
| embedding lookup operation, in particular, the average number of |
| embedding IDs per example and how well the embedding IDs are load |
| balanced across the system. The lookup statistics are used during TPU |
| initialization for embedding table partitioning. Collection of lookup |
| statistics is done at runtime by profiling the embedding inputs, only a |
| small fraction of input samples are profiled to minimize host CPU |
| overhead. Once a suitable number of samples are profiled, the lookup |
| statistics are saved to table-specific files in the profile data |
| directory generally at the end of a TPU training loop. The filename |
| corresponding to each table is obtained by hashing table specific |
| parameters (e.g., table name and number of features) and global |
| configuration parameters (e.g., sharding strategy and task count). The |
| same profile data directory can be shared among several models to reuse |
| embedding lookup statistics. |
| device_config: A DeviceConfig instance, used when `master` and |
| `cluster_def` are both `None`. |
| master_job_name: if set, overrides the master job name used to schedule |
| embedding ops. |
| |
| Raises: |
| ValueError: if any input is invalid. |
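
A minimal construction sketch follows. The table and feature names, the
vocabulary size, the batch size, the master address, and the
`AdagradParameters` values are made-up placeholders, not recommendations:

```
table_to_config = {
    'video': TableConfig(vocabulary_size=1000000, dimension=64),
}
feature_to_config = {
    'watched': FeatureConfig(table_id='video'),
}
tpu_embedding = TPUEmbedding(
    table_to_config,
    feature_to_config,
    batch_size=1024,
    mode=TRAINING,
    master='grpc://10.0.0.1:8470',  # Placeholder TPU master address.
    optimization_parameters=AdagradParameters(
        learning_rate=0.1, initial_accumulator=0.1))
```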
| """ |
| if partition_strategy not in ('div', 'mod'): |
| raise ValueError(f'partition_strategy must be "div" or "mod". ' |
| f'Received: {partition_strategy}.') |
| self._partition_strategy = partition_strategy |
| |
| self._profile_data_directory = profile_data_directory |
| |
| _validate_table_to_config_dict(table_to_config_dict) |
| # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`. |
| self._table_to_config_dict = _create_ordered_dict(table_to_config_dict) |
| |
| _validate_feature_to_config_dict(table_to_config_dict, |
| feature_to_config_dict) |
| self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict) |
| self._table_to_features_dict, self._table_to_num_features_dict = ( |
| _create_table_to_features_and_num_features_dicts( |
| self._feature_to_config_dict)) |
| self._combiners = _create_combiners(self._table_to_config_dict, |
| self._table_to_features_dict) |
| |
| self._batch_size = batch_size |
| |
| if master is None and cluster_def is None: |
| if device_config is None: |
raise ValueError('When master and cluster_def are both None, '
'device_config must be set but is not.')
| if device_config.num_cores % device_config.num_hosts: |
raise ValueError('num_hosts ({}) should divide num_cores ({}) '
'but does not.'.format(device_config.num_hosts,
device_config.num_cores))
| self._num_hosts = device_config.num_hosts |
| self._num_cores = device_config.num_cores |
| self._num_cores_per_host = self._num_cores // self._num_hosts |
| self._hosts = [ |
| '{}/replica:0/task:{}/device:CPU:0'.format(device_config.job_name, i) |
| for i in range(self._num_hosts) |
| ] |
| else: |
| tpu_system_metadata = ( |
| tpu_system_metadata_lib._query_tpu_system_metadata( # pylint: disable=protected-access |
| master, |
| cluster_def=cluster_def)) |
| if tpu_system_metadata.num_cores == 0: |
| raise ValueError('TPUEmbedding needs TPUs, but master {} does not have ' |
| 'TPUs.'.format(master)) |
| self._num_hosts = tpu_system_metadata.num_hosts |
| if master_job_name is None: |
| try: |
| master_job_name = tpu_system_metadata_lib.master_job( |
| master, cluster_def) |
| except ValueError as e: |
| raise ValueError(str(e) + ' Please specify a master_job_name.') |
| self._hosts = [] |
| for device in tpu_system_metadata.devices: |
| if 'device:CPU:' in device.name and (master_job_name is None or |
| master_job_name in device.name): |
| self._hosts.append(device.name) |
| self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host |
| self._num_cores = tpu_system_metadata.num_cores |
| |
| _validate_batch_size(self._batch_size, self._num_cores) |
| self._batch_size_per_core = self._batch_size // self._num_cores |
| |
| # TODO(shizhiw): remove `mode`? |
| if mode == TRAINING: |
| _validate_optimization_parameters(optimization_parameters, |
| self._table_to_config_dict) |
| self._optimization_parameters = optimization_parameters |
| elif mode == INFERENCE: |
| if optimization_parameters is not None: |
| raise ValueError(f'`optimization_parameters` should be `None` ' |
| f'for inference mode. ' |
| f'Received: {optimization_parameters}.') |
| self._optimization_parameters = (StochasticGradientDescentParameters(1.)) |
| else: |
| raise ValueError('`mode` only supports {} and {}; got {}.'.format( |
| TRAINING, INFERENCE, mode)) |
| self._mode = mode |
| |
| # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler` |
| # and create special handler for inference that inherits from |
| # StochasticGradientDescentHandler with more user-friendly error message |
| # on get_slot(). |
| self._optimizer_handler_dict = self._get_optimizer_handler_by_table() |
| |
| self._pipeline_execution_with_tensor_core = ( |
| pipeline_execution_with_tensor_core) |
| self._learning_rate_fn = list( |
| set(c.learning_rate_fn |
| for c in self._table_to_config_dict.values() |
| if c.learning_rate_fn is not None)) |
self._learning_rate_fn_to_tag = {
fn: tag for tag, fn in enumerate(self._learning_rate_fn)
}
| |
| self._config_proto = self._create_config_proto() |
| |
| @property |
| def hosts(self): |
| """A list of device names for CPU hosts. |
| |
| Returns: |
| A list of device names for CPU hosts. |
| """ |
| return copy.copy(self._hosts) |
| |
| # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and |
| # to be consistent with `tpu_embedding_configuration.proto`. |
| @property |
| def num_cores_per_host(self): |
| """Number of TPU cores on a CPU host. |
| |
| Returns: |
| Number of TPU cores on a CPU host. |
| """ |
| return self._num_cores_per_host |
| |
| @property |
| def num_cores(self): |
| """Total number of TPU cores on all hosts. |
| |
| Returns: |
| Total number of TPU cores on all hosts. |
| """ |
| return self._num_cores |
| |
| @property |
| def batch_size_per_core(self): |
| """Batch size for each TPU core. |
| |
The data passed to `generate_enqueue_ops` via `enqueue_datas_list` must
have a batch dimension equal to this.
| |
| Returns: |
| Batch size for each TPU core. |
| """ |
| return self._batch_size_per_core |
| |
| @property |
| def config_proto(self): |
| """Create embedding config proto for `tpu.initialize_system()`. |
| |
| Returns: |
a `TPUEmbeddingConfiguration` proto describing the desired
| configuration of the hardware embedding lookup tables, which |
| is passed to `tpu.initialize_system()`. |
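
A minimal sketch of how this proto is typically consumed (the surrounding
setup is elided; `tpu_embedding` is an already constructed `TPUEmbedding`
instance):

```
init_tpu_op = tf.compat.v1.tpu.initialize_system(
    embedding_config=tpu_embedding.config_proto)
```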
| """ |
| return self._config_proto |
| |
| @property |
| def table_to_config_dict(self): |
| return copy.copy(self._table_to_config_dict) |
| |
| @property |
| def feature_to_config_dict(self): |
| return copy.copy(self._feature_to_config_dict) |
| |
| @property |
| def table_to_features_dict(self): |
| return copy.copy(self._table_to_features_dict) |
| |
| @property |
| def optimization_parameters(self): |
| return self._optimization_parameters |
| |
| def _create_config_proto(self): |
| """Create `TPUEmbeddingConfiguration`.""" |
| config_proto = elc.TPUEmbeddingConfiguration() |
| for table in self._table_to_config_dict: |
| table_descriptor = config_proto.table_descriptor.add() |
| table_descriptor.name = table |
| |
| table_config = self._table_to_config_dict[table] |
| # For small tables, we pad to the number of hosts so that at least one |
| # id will be assigned to each host. |
| table_descriptor.vocabulary_size = max(table_config.vocabulary_size, |
| len(self.hosts)) |
| table_descriptor.dimension = table_config.dimension |
| |
| table_descriptor.num_features = self._table_to_num_features_dict[table] |
| |
| optimization_parameters = ( |
| self._optimizer_handler_dict[table].get_optimization_parameters()) |
| |
| parameters = table_descriptor.optimization_parameters |
| if table_config.learning_rate: |
| parameters.learning_rate.constant = table_config.learning_rate |
| elif table_config.learning_rate_fn: |
| parameters.learning_rate.dynamic.tag = ( |
| self._learning_rate_fn_to_tag[table_config.learning_rate_fn]) |
| else: |
| parameters.learning_rate.constant = ( |
| optimization_parameters.learning_rate) |
| parameters.gradient_accumulation_status = ( |
| optimization_parameters_pb2.GradientAccumulationStatus.ENABLED |
| if optimization_parameters.use_gradient_accumulation else |
| optimization_parameters_pb2.GradientAccumulationStatus.DISABLED) |
| |
| if optimization_parameters.clip_gradient_min is not None: |
| parameters.gradient_clipping_limits.lower.value = ( |
| optimization_parameters.clip_gradient_min) |
| if optimization_parameters.clip_gradient_max is not None: |
| parameters.gradient_clipping_limits.upper.value = ( |
| optimization_parameters.clip_gradient_max) |
| |
| if optimization_parameters.clip_weight_min is not None: |
| parameters.clipping_limits.lower.value = ( |
| optimization_parameters.clip_weight_min) |
| if optimization_parameters.clip_weight_max is not None: |
| parameters.clipping_limits.upper.value = ( |
| optimization_parameters.clip_weight_max) |
| if optimization_parameters.weight_decay_factor: |
| parameters.weight_decay_factor = ( |
| optimization_parameters.weight_decay_factor) |
| if (optimization_parameters |
| .multiply_weight_decay_factor_by_learning_rate): |
| parameters.multiply_weight_decay_factor_by_learning_rate = True |
| if table_config.hot_id_replication: |
| parameters.hot_id_replication_configuration.status = ( |
| optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED) |
| optimizer_handler = self._optimizer_handler_dict[table] |
| optimizer_handler.set_optimization_parameters(table_descriptor) |
| |
| config_proto.mode = self._mode |
| config_proto.batch_size_per_tensor_core = self._batch_size_per_core |
| config_proto.num_hosts = self._num_hosts |
| config_proto.num_tensor_cores = self._num_cores |
| config_proto.sharding_strategy = ( |
| elc.TPUEmbeddingConfiguration.DIV_DEFAULT if self._partition_strategy |
| == 'div' else elc.TPUEmbeddingConfiguration.MOD) |
| config_proto.pipeline_execution_with_tensor_core = ( |
| self._pipeline_execution_with_tensor_core) |
| if self._profile_data_directory: |
| config_proto.profile_data_directory = self._profile_data_directory |
| |
| return config_proto |
| |
| def create_variables_and_ops(self, |
| embedding_variable_name_by_table=None, |
| slot_variable_names_by_table=None): |
| """Create embedding and slot variables, with ops to load and retrieve them. |
| |
N.B.: the ops that retrieve the embedding variables (including slot
variables) are returned as lambda functions, since the call site might want
to impose control dependencies between the TPU computation and the
retrieval. For example, the following code snippet ensures the TPU
computation finishes first, and only then pulls the variables back from TPU
to CPU.
| |
| ``` |
update_ops = []
| with ops.control_dependencies([loss]): |
| for op_fn in retrieve_parameters_op_fns: |
| update_ops.append(op_fn()) |
| ``` |
| |
| Args: |
| embedding_variable_name_by_table: A dictionary mapping from string of |
| table name to string of embedding variable name. If `None`, defaults |
| from `get_default_slot_variable_names()` will be used. |
| slot_variable_names_by_table: A dictionary mapping from string of table |
| name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If |
| `None`, defaults from `get_default_slot_variable_names()` will be used. |
| |
| Returns: |
| `tpu_embedding.VariablesAndOps` with: |
| A dictionary mapping from string of table name to embedding variables, |
| A dictionary mapping from string of table name to AdagradSlotVariables, |
| AdamSlotVariables etc with slot variables, |
| A function which returns a list of ops to load embedding and slot |
| variables from CPU to TPU. |
| A function which returns a list of ops to retrieve embedding and slot |
| variables from TPU to CPU. |
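
A minimal usage sketch (session setup and variable initialization are
elided; `tpu_embedding` is an already constructed `TPUEmbedding` instance):

```
vars_and_ops = tpu_embedding.create_variables_and_ops()
# After initializing the variables, push the embedding tables to the TPU:
load_ops = vars_and_ops.load_ops()
# After training (or before checkpointing), pull them back to the host:
retrieve_ops = vars_and_ops.retrieve_ops()
```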
| """ |
| embedding_variables_by_table = {} |
| slot_variables_by_table = {} |
| load_op_fns = [] |
| retrieve_op_fns = [] |
| |
| for i, table in enumerate(self._table_to_config_dict): |
| if embedding_variable_name_by_table: |
| embedding_variable_name = embedding_variable_name_by_table[table] |
| else: |
| embedding_variable_name = table |
| if slot_variable_names_by_table: |
| slot_variable_names = slot_variable_names_by_table[table] |
| else: |
| optimizer_handler = self._optimizer_handler_dict[table] |
| slot_variable_names = ( |
| optimizer_handler.get_default_slot_variable_names(table)) |
| |
| # TODO(b/139144091): Multi-host support for mid-level API in |
| # eager context (TF 2.0) |
| # Workaround below allows single-host use case in TF 2.0 |
| if context.executing_eagerly(): |
| device = '' |
| else: |
| device = _create_device_fn(self._hosts) |
| |
| with ops.device(device): |
| table_variables = _create_partitioned_variables( |
| name=embedding_variable_name, |
| num_hosts=self._num_hosts, |
| vocabulary_size=self._table_to_config_dict[table].vocabulary_size, |
| embedding_dimension=self._table_to_config_dict[table].dimension, |
| initializer=self._table_to_config_dict[table].initializer, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES]) |
| embedding_variables_by_table[table] = table_variables |
| |
# Only load the embedding config into the load/retrieve nodes for the first
# table on the first host; the other nodes reuse the config from that node.
| config = None if i else self.config_proto.SerializeToString() |
| slot_variables_for_table, load_ops_fn, retrieve_ops_fn = ( |
| self._optimizer_handler_dict[table].create_variables_and_ops( |
| table, slot_variable_names, self._num_hosts, |
| self._table_to_config_dict[table], table_variables, config)) |
| slot_variables_by_table[table] = slot_variables_for_table |
| load_op_fns.append(load_ops_fn) |
| retrieve_op_fns.append(retrieve_ops_fn) |
| |
| def load_ops(): |
| """Calls and returns the load ops for each embedding table. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| load_ops_list = [] |
| for load_op_fn in load_op_fns: |
| load_ops_list.extend(load_op_fn()) |
| return load_ops_list |
| |
| def retrieve_ops(): |
| """Calls and returns the retrieve ops for each embedding table. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| retrieve_ops_list = [] |
| for retrieve_op_fn in retrieve_op_fns: |
| retrieve_ops_list.extend(retrieve_op_fn()) |
| return retrieve_ops_list |
| |
| return VariablesAndOps(embedding_variables_by_table, |
| slot_variables_by_table, load_ops, retrieve_ops) |
| |
| def generate_enqueue_ops( |
| self, |
| enqueue_datas_list, |
| mode_override=None, |
| ragged=False, |
| ): |
| """Generate enqueue ops. |
| |
| Args: |
| enqueue_datas_list: a list of dictionary mapping from string of feature |
| names to EnqueueData. Each dictionary is for one TPU core. Dictionaries |
| for the same host should be contiguous in the list. |
| mode_override: A string input that overrides the mode specified in the |
| TPUEmbeddingConfiguration. Supported values are {'unspecified', |
| 'inference', 'training', 'backward_pass_only'}. When set to |
| 'unspecified', the mode set in TPUEmbeddingConfiguration is used, |
| otherwise mode_override is used (optional). |
| ragged: If True, creates RaggedTensor enqueue ops rather than |
| SparseTensor. |
| |
| Returns: |
| Ops to enqueue to TPU for embedding. |
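
A minimal sketch for a host with two TPU cores and a single feature
'watched'; the per-core `SparseTensor`s `sp_watched_0` and `sp_watched_1`
are placeholders:

```
enqueue_datas_list = [
    {'watched': EnqueueData.from_sparse_tensor(sp_watched_0)},
    {'watched': EnqueueData.from_sparse_tensor(sp_watched_1)},
]
enqueue_ops = tpu_embedding.generate_enqueue_ops(enqueue_datas_list)
```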
| """ |
| self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list) |
| return [ |
| self._generate_enqueue_op( # pylint: disable=g-complex-comprehension |
| enqueue_datas, |
| device_ordinal=i % self._num_cores_per_host, |
| mode_override=mode_override, |
| ragged=ragged, |
| ) for i, enqueue_datas in enumerate(enqueue_datas_list) |
| ] |
| |
| def _validate_generate_enqueue_ops_enqueue_datas_list(self, |
| enqueue_datas_list): |
| """Validate `enqueue_datas_list`.""" |
| |
| def _check_agreement(data, name, feature, enqueue_data): |
| """Helper function to check device agreement.""" |
| if (data is not None and |
| data.device != enqueue_data.embedding_indices.device): |
raise ValueError('Device of {0} does not agree with that of '
'embedding_indices for feature {1}.'.format(
name, feature))
| |
| feature_set = set(self._feature_to_config_dict.keys()) |
| contiguous_device = None |
| for i, enqueue_datas in enumerate(enqueue_datas_list): |
| used_feature_set = set(enqueue_datas.keys()) |
| |
| # Check features are valid. |
| missing_feature_set = feature_set - used_feature_set |
| if missing_feature_set: |
raise ValueError('`enqueue_datas_list[{}]` is missing a feature that is '
| 'in `feature_to_config_dict`: {}.'.format( |
| i, missing_feature_set)) |
| |
| extra_feature_set = used_feature_set - feature_set |
| if extra_feature_set: |
| raise ValueError('`enqueue_datas_list[{}]` has a feature that is not ' |
| 'in `feature_to_config_dict`: {}.'.format( |
| i, extra_feature_set)) |
| |
| device = None |
| device_feature = None |
| for feature, enqueue_data in six.iteritems(enqueue_datas): |
| combiner = self._table_to_config_dict[ |
| self._feature_to_config_dict[feature].table_id].combiner |
| |
| if isinstance(enqueue_data, EnqueueData): |
| if enqueue_data.sample_indices is None and combiner: |
logging.warn(
'No sample indices set for feature %s in table %s, but '
'combiner is set to %s.', feature,
self._feature_to_config_dict[feature].table_id, combiner)
| _check_agreement(enqueue_data.sample_indices, 'sample_indices', |
| feature, enqueue_data) |
| _check_agreement(enqueue_data.aggregation_weights, |
| 'aggregation_weights', feature, enqueue_data) |
| |
| elif isinstance(enqueue_data, RaggedEnqueueData): |
| if enqueue_data.sample_splits is None and combiner: |
logging.warn(
'No sample splits set for feature %s in table %s, but '
'combiner is set to %s.', feature,
self._feature_to_config_dict[feature].table_id, combiner)
| _check_agreement(enqueue_data.sample_splits, 'sample_splits', feature, |
| enqueue_data) |
| _check_agreement(enqueue_data.aggregation_weights, |
| 'aggregation_weights', feature, enqueue_data) |
| else: |
| raise ValueError( |
| '`enqueue_datas_list[{}]` has a feature that is not mapped to ' |
| '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format( |
| i, feature)) |
| # Check all features are on the same device. |
| if device is None: |
| device = enqueue_data.embedding_indices.device |
| device_feature = feature |
| else: |
| if device != enqueue_data.embedding_indices.device: |
| raise ValueError('Devices are different between features in ' |
| '`enqueue_datas_list[{}]`; ' |
| 'devices: {}, {}; features: {}, {}.'.format( |
| i, device, |
| enqueue_data.embedding_indices.device, feature, |
| device_feature)) |
| |
| if i % self._num_cores_per_host: |
| if device != contiguous_device: |
raise ValueError('We expect the `enqueue_datas` that are on the '
'same host to be contiguous in '
'`enqueue_datas_list`; '
'`enqueue_datas_list[{}]` is on device {} '
'but is expected to be on device {}.'.format(
i, device, contiguous_device))
| else: |
| contiguous_device = device |
| |
| def _generate_enqueue_op(self, |
| enqueue_datas, |
| device_ordinal, |
| mode_override=None, |
| ragged=False): |
| """Creates op for enqueuing batch to TPU.""" |
| enqueue_data0 = list(enqueue_datas.values())[0] |
| with ops.colocate_with(enqueue_data0.embedding_indices): |
| if ragged: |
# Note: the ragged enqueue op is currently identical in behavior to the sparse one.
| return tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch( |
| device_ordinal=device_ordinal, |
| combiners=self._combiners, |
| mode_override=mode_override, |
| **self._format_for_tpu_embedding_ragged_tensor_batch(enqueue_datas)) |
| else: |
| return tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( |
| device_ordinal=device_ordinal, |
| combiners=self._combiners, |
| mode_override=mode_override, |
| **self._format_for_tpu_embedding_sparse_tensor_batch(enqueue_datas)) |
| |
| def _format_for_tpu_embedding_ragged_tensor_batch(self, enqueue_datas): |
| """Format sparse features for `enqueue_tpu_embedding_ragged_tensor_batch()`. |
| |
| Args: |
| enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding. |
| |
| Returns: |
| Dict of arguments for `enqueue_tpu_embedding_ragged_tensor_batch()`. |
| """ |
| |
| kwargs = { |
| 'sample_splits': [], |
| 'embedding_indices': [], |
| 'aggregation_weights': [], |
| 'table_ids': [], |
| 'max_sequence_lengths': [], |
| } |
| int_zeros = array_ops.zeros((0,), dtype=dtypes.int64) |
| float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) |
| for table_id, table in enumerate(self._table_to_features_dict): |
| features = self._table_to_features_dict[table] |
| for feature in features: |
| enqueue_data = enqueue_datas[feature] |
| |
| kwargs['sample_splits'].append( |
| enqueue_data.sample_splits if enqueue_data |
| .sample_splits is not None else int_zeros) |
| |
| kwargs['aggregation_weights'].append( |
| enqueue_data.aggregation_weights if enqueue_data |
| .aggregation_weights is not None else float_zeros) |
| |
| kwargs['embedding_indices'].append(enqueue_data.embedding_indices) |
| |
| kwargs['table_ids'].append(table_id) |
| kwargs['max_sequence_lengths'].append( |
| self._feature_to_config_dict[feature].max_sequence_length) |
| |
| return kwargs |
| |
| def _format_for_tpu_embedding_sparse_tensor_batch(self, enqueue_datas): |
| """Format sparse features for `enqueue_tpu_embedding_sparse_tensor_batch()`. |
| |
| Args: |
| enqueue_datas: a `Dict` of `EnqueueData` objects for embedding. |
| |
| Returns: |
| Dict of arguments for `enqueue_tpu_embedding_sparse_tensor_batch()`. |
| """ |
| kwargs = { |
| 'sample_indices': [], |
| 'embedding_indices': [], |
| 'aggregation_weights': [], |
| 'table_ids': [], |
| 'max_sequence_lengths': [], |
| } |
| int_zeros = array_ops.zeros((0,), dtype=dtypes.int64) |
| float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) |
| for table_id, table in enumerate(self._table_to_features_dict): |
| features = self._table_to_features_dict[table] |
| for feature in features: |
| enqueue_data = enqueue_datas[feature] |
| |
| kwargs['sample_indices'].append( |
| enqueue_data.sample_indices if enqueue_data |
| .sample_indices is not None else int_zeros) |
| |
| kwargs['aggregation_weights'].append( |
| enqueue_data.aggregation_weights if enqueue_data |
| .aggregation_weights is not None else float_zeros) |
| |
| kwargs['embedding_indices'].append(enqueue_data.embedding_indices) |
| |
| kwargs['table_ids'].append(table_id) |
| kwargs['max_sequence_lengths'].append( |
| self._feature_to_config_dict[feature].max_sequence_length) |
| |
| return kwargs |
| |
| def get_activations(self): |
| """Get activations for features. |
| |
| This should be called within `computation` that is passed to |
| `tpu.replicate` and friends. |
| |
| Returns: |
| A dictionary mapping from `String` of feature name to `Tensor` |
| of activation. |
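
A minimal sketch of typical use inside the replicated computation (the
model function `model_fn` is a placeholder):

```
def computation():
  activations = tpu_embedding.get_activations()
  # Each value has shape [batch_size_per_core, dimension] for
  # non-sequence features.
  watched_embedding = activations['watched']
  return model_fn(watched_embedding)
```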
| """ |
| recv_activations = tpu_ops.recv_tpu_embedding_activations( |
| num_outputs=len(self._table_to_config_dict), |
| config=self._config_proto.SerializeToString()) |
| |
| activations = collections.OrderedDict() |
| for table_id, table in enumerate(self._table_to_features_dict): |
| features = self._table_to_features_dict[table] |
| num_features = self._table_to_num_features_dict[table] |
| feature_index = 0 |
| table_activations = array_ops.reshape( |
| recv_activations[table_id], |
| [self.batch_size_per_core, num_features, -1]) |
| for feature in features: |
| seq_length = self._feature_to_config_dict[feature].max_sequence_length |
| if not seq_length: |
| activations[feature] = table_activations[:, feature_index, :] |
| feature_index = feature_index + 1 |
| else: |
| activations[feature] = ( |
| table_activations[:, |
| feature_index:(feature_index + seq_length), :]) |
| feature_index = feature_index + seq_length |
| |
| return activations |
| |
| def generate_send_gradients_op(self, feature_to_gradient_dict, step=None): |
| """Send gradient to TPU embedding. |
| |
| Args: |
| feature_to_gradient_dict: dict mapping feature names to gradient wrt |
| activations. |
| step: the current global step, used for dynamic learning rate. |
| |
| Returns: |
| SendTPUEmbeddingGradients Op. |
| |
| Raises: |
| RuntimeError: If `mode` is not `TRAINING`. |
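
A minimal sketch, assuming `loss` and the `activations` dictionary from
`get_activations()` are available in the training computation:

```
grads = tf.gradients(loss, list(activations.values()))
feature_to_gradient = dict(zip(activations.keys(), grads))
send_op = tpu_embedding.generate_send_gradients_op(
    feature_to_gradient,
    step=tf.compat.v1.train.get_or_create_global_step())
```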
| """ |
| if self._mode != TRAINING: |
raise RuntimeError('Gradients should only be sent to TPU embedding '
'in training mode; got mode {}.'.format(
self._mode))
| if step is None and self._learning_rate_fn: |
| raise ValueError('There are dynamic learning rates but step is None.') |
| |
| gradients = [] |
| for table in self._table_to_features_dict: |
| features = self._table_to_features_dict[table] |
| table_gradients = [] |
| for feature in features: |
| gradient = feature_to_gradient_dict[feature] |
| # Expand dims for non-sequence feature to match sequence features. |
| if gradient.shape.ndims == 2: |
| gradient = array_ops.expand_dims(gradient, 1) |
| table_gradients.append(gradient) |
| interleaved_table_grads = array_ops.reshape( |
| array_ops.concat(table_gradients, axis=1), |
| [-1, array_ops.shape(table_gradients[0])[-1]]) |
| gradients.append(interleaved_table_grads) |
| |
| return tpu_ops.send_tpu_embedding_gradients( |
| inputs=gradients, |
| learning_rates=[ |
| math_ops.cast(fn(step), dtype=dtypes.float32) |
| for fn in self._learning_rate_fn |
| ], |
| config=self.config_proto.SerializeToString()) |
| |
| def _get_optimizer_handler_by_table(self): |
| optimizer_handlers = {} |
| for table, table_config in self.table_to_config_dict.items(): |
| if table_config.optimization_parameters is not None: |
| optimizer = table_config.optimization_parameters |
| else: |
| optimizer = self._optimization_parameters |
| optimizer_handlers[table] = _get_optimization_handler(optimizer) |
| |
| return optimizer_handlers |
| |
| |
| def _validate_table_to_config_dict(table_to_config_dict): |
| """Validate `table_to_config_dict`.""" |
| for k, v in six.iteritems(table_to_config_dict): |
| if not isinstance(v, TableConfig): |
| raise ValueError('Value of `table_to_config_dict` must be of type ' |
| '`TableConfig`, got {} for {}.'.format(type(v), k)) |
| |
| |
| def _validate_feature_to_config_dict(table_to_config_dict, |
| feature_to_config_dict): |
| """Validate `feature_to_config_dict`.""" |
| used_table_set = set( |
| [feature.table_id for feature in feature_to_config_dict.values()]) |
| table_set = set(table_to_config_dict.keys()) |
| |
| unused_table_set = table_set - used_table_set |
| if unused_table_set: |
| raise ValueError( |
'`table_to_config_dict` specifies a table that is not '
| 'used in `feature_to_config_dict`: {}.'.format(unused_table_set)) |
| |
| extra_table_set = used_table_set - table_set |
| if extra_table_set: |
| raise ValueError( |
| '`feature_to_config_dict` refers to a table that is not ' |
| 'specified in `table_to_config_dict`: {}.'.format(extra_table_set)) |
| |
| |
| def _validate_batch_size(batch_size, num_cores): |
| if batch_size % num_cores: |
| raise ValueError('`batch_size` is not a multiple of number of ' |
| 'cores. `batch_size`={}, `_num_cores`={}.'.format( |
| batch_size, num_cores)) |
| |
| |
| def _validate_optimization_parameters(optimization_parameters, |
| table_to_config_dict): |
| """Validate global optimization_parameters and per table optimizers. |
| |
| If global optimizer is `None`, all table optimizers should be non `None`. |
| |
| Args: |
| optimization_parameters: global optimizer provided in `TPUEmbedding` |
| constructor. |
| table_to_config_dict: A dictionary mapping from string of table name to |
| `TableConfig`. |
| """ |
| tbl_optimizer_missing = False |
| for _, table_config in table_to_config_dict.items(): |
| if table_config.optimization_parameters is None: |
| tbl_optimizer_missing = True |
| break |
| |
| if optimization_parameters: |
| if not isinstance(optimization_parameters, _OptimizationParameters): |
| raise ValueError('`optimization_parameters` must inherit from ' |
| '`_OptimizationParameters`. ' |
| '`type(optimization_parameters)`={}'.format( |
| type(optimization_parameters))) |
| else: |
| # Missing global optimization_parameters. |
| if tbl_optimizer_missing: |
raise ValueError('`optimization_parameters` is missing and at least one '
'table does not specify its own optimizer.')
| |
| |
| class _OptimizerHandler(object): |
| """Interface class for handling optimizer specific logic.""" |
| |
| def __init__(self, optimization_parameters): |
| self._optimization_parameters = optimization_parameters |
| |
| def get_optimization_parameters(self): |
| return self._optimization_parameters |
| |
| def set_optimization_parameters(self, table_descriptor): |
| raise NotImplementedError() |
| |
| def get_default_slot_variable_names(self, table): |
| raise NotImplementedError() |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
| raise NotImplementedError() |
| |
| |
| class _AdagradHandler(_OptimizerHandler): |
| """Handles Adagrad specific logic.""" |
| |
| def set_optimization_parameters(self, table_descriptor): |
| table_descriptor.optimization_parameters.adagrad.SetInParent() |
| |
| def get_default_slot_variable_names(self, table): |
| return AdagradSlotVariableNames('{}/{}'.format(table, 'Adagrad')) |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
| accumulator_initializer = init_ops.constant_initializer( |
| self._optimization_parameters.initial_accumulator) |
| accumulator_variables = _create_partitioned_variables( |
| name=slot_variable_names.accumulator, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=accumulator_initializer) |
| slot_variables = AdagradSlotVariables(accumulator_variables) |
| |
| def load_ops_fn(): |
| """Returns the retrieve ops for AdaGrad embedding tables. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| config = config_proto |
| load_op_list = [] |
| for host_id, table_variable, accumulator_variable in zip( |
| range(num_hosts), table_variables, accumulator_variables): |
| with ops.colocate_with(table_variable): |
| load_parameters_op = ( |
| tpu_ops.load_tpu_embedding_adagrad_parameters( |
| parameters=table_variable, |
| accumulators=accumulator_variable, |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| config = None |
| load_op_list.append(load_parameters_op) |
| return load_op_list |
| |
| def retrieve_ops_fn(): |
| """Returns the retrieve ops for AdaGrad embedding tables. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| config = config_proto |
| retrieve_op_list = [] |
| for host_id, table_variable, accumulator_variable in (zip( |
| range(num_hosts), table_variables, accumulator_variables)): |
| with ops.colocate_with(table_variable): |
| retrieved_table, retrieved_accumulator = ( |
| tpu_ops.retrieve_tpu_embedding_adagrad_parameters( |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| retrieve_parameters_op = control_flow_ops.group( |
| state_ops.assign(table_variable, retrieved_table), |
| state_ops.assign(accumulator_variable, retrieved_accumulator)) |
| config = None |
| retrieve_op_list.append(retrieve_parameters_op) |
| return retrieve_op_list |
| |
| return slot_variables, load_ops_fn, retrieve_ops_fn |
| |
| |
| class _AdagradMomentumHandler(_OptimizerHandler): |
| """Handles Adagrad with Momentum specific logic. |
| |
| Creates slot variables and defines their initializers. Defines load/retrieve |
| operations to be used for loading variables into TPU memory (from host memory) |
| and retrieving variables from TPU memory (into host memory). |
| """ |
| |
| def set_optimization_parameters(self, table_descriptor): |
| table_descriptor.optimization_parameters.adagrad_momentum.SetInParent() |
| table_descriptor.optimization_parameters.adagrad_momentum.momentum = ( |
| self._optimization_parameters.momentum) |
| table_descriptor.optimization_parameters.adagrad_momentum.use_nesterov = ( |
| self._optimization_parameters.use_nesterov) |
| table_descriptor.optimization_parameters.adagrad_momentum.exponent = ( |
| self._optimization_parameters.exponent) |
| table_descriptor.optimization_parameters.adagrad_momentum.beta2 = ( |
| self._optimization_parameters.beta2) |
| table_descriptor.optimization_parameters.adagrad_momentum.epsilon = ( |
| self._optimization_parameters.epsilon) |
| |
| def get_default_slot_variable_names(self, table): |
| return AdagradMomentumSlotVariableNames( |
| '{}/{}/Accumulator'.format(table, 'AdagradMomentum'), |
| '{}/{}/Momentum'.format(table, 'AdagradMomentum')) |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
| accumulator_initializer = init_ops.zeros_initializer() |
| accumulator_variables = _create_partitioned_variables( |
| name=slot_variable_names.accumulator, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=accumulator_initializer) |
| momenta_initializer = init_ops.zeros_initializer() |
| momenta_variables = _create_partitioned_variables( |
| name=slot_variable_names.momenta, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=momenta_initializer) |
| slot_variables = AdagradMomentumSlotVariables(accumulator_variables, |
| momenta_variables) |
| |
| def load_ops_fn(): |
| """Returns the load ops for AdaGrad with momentum embedding tables. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| config = config_proto |
| load_op_list = [] |
| for host_id, table_variable, accumulator_variable, momenta_variable in zip( |
| range(num_hosts), table_variables, accumulator_variables, |
| momenta_variables): |
| with ops.colocate_with(table_variable): |
| load_parameters_op = ( |
| tpu_ops.load_tpu_embedding_adagrad_momentum_parameters( |
| parameters=table_variable, |
| accumulators=accumulator_variable, |
| momenta=momenta_variable, |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| config = None |
| load_op_list.append(load_parameters_op) |
| return load_op_list |
| |
| def retrieve_ops_fn(): |
| """Returns the retrieve ops for AdaGrad with momentum embedding tables. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| config = config_proto |
| retrieve_op_list = [] |
| for host_id, table_variable, accumulator_variable, momenta_variable in ( |
| zip( |
| range(num_hosts), table_variables, accumulator_variables, |
| momenta_variables)): |
| with ops.colocate_with(table_variable): |
| retrieved_table, retrieved_accumulator, retrieved_momenta = ( |
| tpu_ops.retrieve_tpu_embedding_adagrad_momentum_parameters( |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| retrieve_parameters_op = control_flow_ops.group( |
| state_ops.assign(table_variable, retrieved_table), |
| state_ops.assign(accumulator_variable, retrieved_accumulator), |
| state_ops.assign(momenta_variable, retrieved_momenta)) |
| config = None |
| retrieve_op_list.append(retrieve_parameters_op) |
| return retrieve_op_list |
| |
| return slot_variables, load_ops_fn, retrieve_ops_fn |
| |
| |
| class _ProximalAdagradHandler(_OptimizerHandler): |
| """Handles ProximalAdagrad specific logic.""" |
| |
| def set_optimization_parameters(self, table_descriptor): |
| table_descriptor.optimization_parameters.proximal_adagrad.SetInParent() |
| table_descriptor.optimization_parameters.proximal_adagrad.l1 = ( |
| self._optimization_parameters.l1_regularization_strength) |
| table_descriptor.optimization_parameters.proximal_adagrad.l2 = ( |
| self._optimization_parameters.l2_regularization_strength) |
| |
| def get_default_slot_variable_names(self, table): |
| return ProximalAdagradSlotVariableNames('{}/{}'.format( |
| table, 'ProximalAdagrad')) |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
| accumulator_initializer = init_ops.constant_initializer( |
| self._optimization_parameters.initial_accumulator) |
| accumulator_variables = _create_partitioned_variables( |
| name=slot_variable_names.accumulator, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=accumulator_initializer) |
| slot_variables = ProximalAdagradSlotVariables(accumulator_variables) |
| |
| def load_ops_fn(): |
| """Returns the retrieve ops for Proximal AdaGrad embedding tables. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| config = config_proto |
| load_op_list = [] |
| for host_id, table_variable, accumulator_variable in zip( |
| range(num_hosts), table_variables, accumulator_variables): |
| with ops.colocate_with(table_variable): |
| load_parameters_op = ( |
| tpu_ops.load_tpu_embedding_proximal_adagrad_parameters( |
| parameters=table_variable, |
| accumulators=accumulator_variable, |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| config = None |
| load_op_list.append(load_parameters_op) |
| return load_op_list |
| |
| def retrieve_ops_fn(): |
| """Returns the retrieve ops for Proximal AdaGrad embedding tables. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| config = config_proto |
| retrieve_op_list = [] |
| for host_id, table_variable, accumulator_variable in (zip( |
| range(num_hosts), table_variables, accumulator_variables)): |
| with ops.colocate_with(table_variable): |
| retrieved_table, retrieved_accumulator = ( |
| tpu_ops.retrieve_tpu_embedding_proximal_adagrad_parameters( |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| retrieve_parameters_op = control_flow_ops.group( |
| state_ops.assign(table_variable, retrieved_table), |
| state_ops.assign(accumulator_variable, retrieved_accumulator)) |
| config = None |
| retrieve_op_list.append(retrieve_parameters_op) |
| return retrieve_op_list |
| |
| return slot_variables, load_ops_fn, retrieve_ops_fn |
| |
| |
| class _AdamHandler(_OptimizerHandler): |
| """Handles Adam specific logic.""" |
| |
| def set_optimization_parameters(self, table_descriptor): |
| table_descriptor.optimization_parameters.adam.beta1 = ( |
| self._optimization_parameters.beta1) |
| table_descriptor.optimization_parameters.adam.beta2 = ( |
| self._optimization_parameters.beta2) |
| table_descriptor.optimization_parameters.adam.epsilon = ( |
| self._optimization_parameters.epsilon) |
| table_descriptor.optimization_parameters.adam.use_non_lazy_adam = ( |
| not self._optimization_parameters.lazy_adam) |
| table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = ( |
| self._optimization_parameters.sum_inside_sqrt) |
| |
| def get_default_slot_variable_names(self, table): |
| return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'), |
| '{}/{}/v'.format(table, 'Adam')) |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
| m_initializer = init_ops.zeros_initializer() |
| m_variables = _create_partitioned_variables( |
| name=slot_variable_names.m, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=m_initializer) |
| v_initializer = init_ops.zeros_initializer() |
| v_variables = _create_partitioned_variables( |
| name=slot_variable_names.v, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=v_initializer) |
| slot_variables = AdamSlotVariables(m_variables, v_variables) |
| |
| def load_ops_fn(): |
| """Returns the retrieve ops for AdaGrad embedding tables. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| load_op_list = [] |
| config = config_proto |
| for host_id, table_variable, m_variable, v_variable in (zip( |
| range(num_hosts), table_variables, m_variables, v_variables)): |
| with ops.colocate_with(table_variable): |
| load_parameters_op = ( |
| tpu_ops.load_tpu_embedding_adam_parameters( |
| parameters=table_variable, |
| momenta=m_variable, |
| velocities=v_variable, |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
# Set config to None so that it is only attached to the first load op;
# the remaining ops do not need it.
| config = None |
| load_op_list.append(load_parameters_op) |
| return load_op_list |
| |
| def retrieve_ops_fn(): |
| """Returns the retrieve ops for Adam embedding tables. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| retrieve_op_list = [] |
| config = config_proto |
| for host_id, table_variable, m_variable, v_variable in (zip( |
| range(num_hosts), table_variables, m_variables, v_variables)): |
| with ops.colocate_with(table_variable): |
| retrieved_table, retrieved_m, retrieved_v = ( |
| tpu_ops.retrieve_tpu_embedding_adam_parameters( |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| retrieve_parameters_op = control_flow_ops.group( |
| state_ops.assign(table_variable, retrieved_table), |
| state_ops.assign(m_variable, retrieved_m), |
| state_ops.assign(v_variable, retrieved_v)) |
| config = None |
| retrieve_op_list.append(retrieve_parameters_op) |
| return retrieve_op_list |
| |
| return slot_variables, load_ops_fn, retrieve_ops_fn |
| |
| |
| class _FtrlHandler(_OptimizerHandler): |
| """Handles Ftrl specific logic.""" |
| |
| def set_optimization_parameters(self, table_descriptor): |
| table_descriptor.optimization_parameters.ftrl.lr_power = ( |
| self._optimization_parameters.learning_rate_power) |
| table_descriptor.optimization_parameters.ftrl.l1 = ( |
| self._optimization_parameters.l1_regularization_strength) |
| table_descriptor.optimization_parameters.ftrl.l2 = ( |
| self._optimization_parameters.l2_regularization_strength) |
| table_descriptor.optimization_parameters.ftrl.multiply_linear_by_lr = ( |
| self._optimization_parameters.multiply_linear_by_learning_rate) |
| table_descriptor.optimization_parameters.ftrl.beta = ( |
| self._optimization_parameters.beta) |
| table_descriptor.optimization_parameters.ftrl.allow_zero_accumulator = ( |
| self._optimization_parameters.allow_zero_accumulator) |
| |
| def get_default_slot_variable_names(self, table): |
| # These match the default slot variable names created by |
| # tf.train.FtrlOptimizer. |
| return FtrlSlotVariableNames( |
| '{}/{}'.format(table, 'Ftrl'), # accumulator |
| '{}/{}'.format(table, 'Ftrl_1')) # linear |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
| accumulator_initializer = init_ops.constant_initializer( |
| self._optimization_parameters.initial_accumulator_value) |
| accumulator_variables = _create_partitioned_variables( |
| name=slot_variable_names.accumulator, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=accumulator_initializer) |
| linear_initializer = init_ops.constant_initializer( |
| self._optimization_parameters.initial_linear_value) |
| linear_variables = _create_partitioned_variables( |
| name=slot_variable_names.linear, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=linear_initializer) |
| slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables) |
| |
| def load_ops_fn(): |
| """Returns the retrieve ops for Ftrl embedding tables. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| config = config_proto |
| load_op_list = [] |
| for host_id, table_variable, accumulator_variable, linear_variable in zip( |
| range(num_hosts), table_variables, accumulator_variables, |
| linear_variables): |
| with ops.colocate_with(table_variable): |
| load_parameters_op = ( |
| tpu_ops.load_tpu_embedding_ftrl_parameters( |
| parameters=table_variable, |
| accumulators=accumulator_variable, |
| linears=linear_variable, |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| config = None |
| load_op_list.append(load_parameters_op) |
| return load_op_list |
| |
| def retrieve_ops_fn(): |
| """Returns the retrieve ops for Ftrl embedding tables. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| config = config_proto |
| retrieve_op_list = [] |
| for host_id, table_variable, accumulator_variable, linear_variable in zip( |
| range(num_hosts), table_variables, accumulator_variables, |
| linear_variables): |
| with ops.colocate_with(table_variable): |
| retrieved_table, retrieved_accumulator, retrieved_linear = ( |
| tpu_ops.retrieve_tpu_embedding_ftrl_parameters( |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| retrieve_parameters_op = control_flow_ops.group( |
| state_ops.assign(table_variable, retrieved_table), |
| state_ops.assign(accumulator_variable, retrieved_accumulator), |
| state_ops.assign(linear_variable, retrieved_linear)) |
| config = None |
| retrieve_op_list.append(retrieve_parameters_op) |
| return retrieve_op_list |
| |
| return slot_variables, load_ops_fn, retrieve_ops_fn |
| |
| |
| class _ProximalYogiHandler(_OptimizerHandler): |
| """Handles Proximal Yogi specific logic.""" |
| |
| def set_optimization_parameters(self, table_descriptor): |
| table_descriptor.optimization_parameters.proximal_yogi.SetInParent() |
| table_descriptor.optimization_parameters.proximal_yogi.beta1 = ( |
| self._optimization_parameters.beta1) |
| table_descriptor.optimization_parameters.proximal_yogi.beta2 = ( |
| self._optimization_parameters.beta2) |
| table_descriptor.optimization_parameters.proximal_yogi.epsilon = ( |
| self._optimization_parameters.epsilon) |
| table_descriptor.optimization_parameters.proximal_yogi.l1 = ( |
| self._optimization_parameters.l1_regularization_strength) |
| table_descriptor.optimization_parameters.proximal_yogi.l2 = ( |
| self._optimization_parameters.l2_regularization_strength) |
| |
| def get_default_slot_variable_names(self, table): |
| return ProximalYogiSlotVariableNames( |
| '{}/{}'.format(table, 'ProximalYogi'), # v |
| '{}/{}_1'.format(table, 'ProximalYogi')) # m |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
| v_initializer = init_ops.constant_initializer( |
| self._optimization_parameters.initial_accumulator_value) |
| v_variables = _create_partitioned_variables( |
| name=slot_variable_names.v, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=v_initializer) |
| m_initializer = init_ops.zeros_initializer() |
| m_variables = _create_partitioned_variables( |
| name=slot_variable_names.m, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=m_initializer) |
| slot_variables = ProximalYogiSlotVariables(v_variables, m_variables) |
| |
| def load_ops_fn(): |
| """Returns the load ops for Proximal Yogi embedding tables. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| load_op_list = [] |
| config = config_proto |
| for host_id, table_variable, v_variable, m_variable in (zip( |
| range(num_hosts), table_variables, v_variables, m_variables)): |
| with ops.colocate_with(table_variable): |
| load_parameters_op = ( |
| tpu_ops.load_tpu_embedding_proximal_yogi_parameters( |
| parameters=table_variable, |
| v=v_variable, |
| m=m_variable, |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
# Set config to None so that it is only attached to the first load op;
# the remaining ops do not need it.
| config = None |
| load_op_list.append(load_parameters_op) |
| return load_op_list |
| |
| def retrieve_ops_fn(): |
| """Returns the retrieve ops for Proximal Yogi embedding tables. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| retrieve_op_list = [] |
| config = config_proto |
| for host_id, table_variable, v_variable, m_variable in (zip( |
| range(num_hosts), table_variables, v_variables, m_variables)): |
| with ops.colocate_with(table_variable): |
| retrieved_table, retrieved_v, retrieved_m = ( |
| tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters( |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| retrieve_parameters_op = control_flow_ops.group( |
| state_ops.assign(table_variable, retrieved_table), |
| state_ops.assign(v_variable, retrieved_v), |
| state_ops.assign(m_variable, retrieved_m)) |
| config = None |
| retrieve_op_list.append(retrieve_parameters_op) |
| return retrieve_op_list |
| |
| return slot_variables, load_ops_fn, retrieve_ops_fn |
| |
| |
| class _MomentumHandler(_OptimizerHandler): |
| """Handles Momentum specific logic.""" |
| |
| def set_optimization_parameters(self, table_descriptor): |
| (table_descriptor.optimization_parameters.momentum.SetInParent()) |
| table_descriptor.optimization_parameters.momentum.momentum = ( |
| self._optimization_parameters.momentum) |
| table_descriptor.optimization_parameters.momentum.use_nesterov = ( |
| self._optimization_parameters.use_nesterov) |
| |
| def get_default_slot_variable_names(self, table): |
| return MomentumSlotVariableNames('{}/{}'.format(table, 'Momentum')) |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
| |
| momenta_initializer = init_ops.zeros_initializer() |
| momenta_variables = _create_partitioned_variables( |
| name=slot_variable_names.momenta, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=momenta_initializer) |
| slot_variables = MomentumSlotVariables(momenta_variables) |
| |
| def load_ops_fn(): |
| """Returns the retrieve ops for Momentum embedding tables. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| load_op_list = [] |
| config = config_proto |
| for host_id, table_variable, momenta_variable in (zip( |
| range(num_hosts), table_variables, momenta_variables)): |
| with ops.colocate_with(table_variable): |
| load_parameters_op = tpu_ops.load_tpu_embedding_momentum_parameters( |
| parameters=table_variable, |
| momenta=momenta_variable, |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config, |
| ) |
| config = None |
| load_op_list.append(load_parameters_op) |
| return load_op_list |
| |
| def retrieve_ops_fn(): |
| """Returns the retrieve ops for Momentum embedding tables. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| retrieve_op_list = [] |
| config = config_proto |
| for host_id, table_variable, momenta_variable in (zip( |
| range(num_hosts), table_variables, momenta_variables)): |
| with ops.colocate_with(table_variable): |
| retrieved_table, retrieved_momenta = ( |
| tpu_ops.retrieve_tpu_embedding_momentum_parameters( |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config, |
| )) |
| retrieve_parameters_op = control_flow_ops.group( |
| state_ops.assign(table_variable, retrieved_table), |
| state_ops.assign(momenta_variable, retrieved_momenta)) |
| config = None |
| retrieve_op_list.append(retrieve_parameters_op) |
| return retrieve_op_list |
| |
| return slot_variables, load_ops_fn, retrieve_ops_fn |
| |
| |
| class _RMSPropHandler(_OptimizerHandler): |
| """Handles RMS prop specific logic.""" |
| |
| def set_optimization_parameters(self, table_descriptor): |
| (table_descriptor.optimization_parameters.rms_prop.SetInParent()) |
| table_descriptor.optimization_parameters.rms_prop.rho = ( |
| self._optimization_parameters.rho) |
| table_descriptor.optimization_parameters.rms_prop.epsilon = ( |
| self._optimization_parameters.epsilon) |
| table_descriptor.optimization_parameters.rms_prop.momentum = ( |
| self._optimization_parameters.momentum) |
| |
| def get_default_slot_variable_names(self, table): |
| return RMSPropSlotVariableNames('{}/{}/ms'.format(table, 'RMSProp'), |
| '{}/{}/mom'.format(table, 'RMSProp')) |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
| |
| ms_variables = _create_partitioned_variables( |
| name=slot_variable_names.ms, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=init_ops.zeros_initializer(), |
| ) |
| mom_variables = _create_partitioned_variables( |
| name=slot_variable_names.mom, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=init_ops.zeros_initializer(), |
| ) |
| slot_variables = RMSPropSlotVariables(ms_variables, mom_variables) |
| |
| def load_ops_fn(): |
| """Returns the retrieve ops for RMS Prop embedding tables. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| load_op_list = [] |
| config = config_proto |
| for host_id, table_variable, ms_variable, mom_variable in (zip( |
| range(num_hosts), table_variables, ms_variables, mom_variables)): |
| with ops.colocate_with(table_variable): |
| load_parameters_op = tpu_ops.load_tpu_embedding_rms_prop_parameters( |
| parameters=table_variable, |
| ms=ms_variable, |
| mom=mom_variable, |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config, |
| ) |
| config = None |
| load_op_list.append(load_parameters_op) |
| return load_op_list |
| |
| def retrieve_ops_fn(): |
| """Returns the retrieve ops for RMS Prop embedding tables. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| retrieve_op_list = [] |
| config = config_proto |
| for host_id, table_variable, ms_variable, mom_variable in (zip( |
| range(num_hosts), table_variables, ms_variables, mom_variables)): |
| with ops.colocate_with(table_variable): |
| retrieved_table, retrieved_ms, retrieved_mom = ( |
| tpu_ops.retrieve_tpu_embedding_rms_prop_parameters( |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config, |
| )) |
| retrieve_parameters_op = control_flow_ops.group( |
| state_ops.assign(table_variable, retrieved_table), |
| state_ops.assign(ms_variable, retrieved_ms), |
| state_ops.assign(mom_variable, retrieved_mom)) |
| config = None |
| retrieve_op_list.append(retrieve_parameters_op) |
| return retrieve_op_list |
| |
| return slot_variables, load_ops_fn, retrieve_ops_fn |
| |
| |
| class _FrequencyEstimatorHandler(_OptimizerHandler): |
| """Handles frequency estimator specific logic.""" |
| |
| def set_optimization_parameters(self, table_descriptor): |
| table_descriptor.optimization_parameters.frequency_estimator.SetInParent() |
| freq = table_descriptor.optimization_parameters.frequency_estimator |
| freq.tau = self._optimization_parameters.tau |
| freq.max_delta = self._optimization_parameters.max_delta |
| freq.outlier_threshold = self._optimization_parameters.outlier_threshold |
| freq.weight_exponent = self._optimization_parameters.weight_exponent |
| |
| def get_default_slot_variable_names(self, table): |
| return FrequencyEstimatorSlotVariableNames( |
| '{}/FrequencyEstimator'.format(table)) |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
| if table_config.dimension != 1: |
| raise ValueError('FrequencyEstimator tables should only have a dimension ' |
| 'of 1. Received dimension {}'.format( |
| table_config.dimension)) |
| |
| last_hit_step_variables = _create_partitioned_variables( |
| name=slot_variable_names.last_hit_step, |
| num_hosts=num_hosts, |
| vocabulary_size=table_config.vocabulary_size, |
| embedding_dimension=table_config.dimension, |
| collections=[ops.GraphKeys.GLOBAL_VARIABLES], |
| initializer=init_ops.zeros_initializer(), |
| ) |
| slot_variables = FrequencyEstimatorSlotVariables(last_hit_step_variables) |
| |
| def load_ops_fn(): |
| """Returns the retrieve ops for Frequency Estimator embedding tables. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| load_op_list = [] |
| config = config_proto |
| for host_id, table_variable, last_hit_step_variable in (zip( |
| range(num_hosts), table_variables, last_hit_step_variables)): |
| with ops.colocate_with(table_variable): |
| load_parameters_op = ( |
| tpu_ops.load_tpu_embedding_frequency_estimator_parameters( |
| parameters=table_variable, |
| last_hit_step=last_hit_step_variable, |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| config = None |
| load_op_list.append(load_parameters_op) |
| return load_op_list |
| |
| def retrieve_ops_fn(): |
| """Returns the retrieve ops for Frequency Estimator embedding tables. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| retrieve_op_list = [] |
| config = config_proto |
| for host_id, table_variable, last_hit_step_variable in (zip( |
| range(num_hosts), table_variables, last_hit_step_variables)): |
| with ops.colocate_with(table_variable): |
| retrieved_table, retrieved_last_hit_step = ( |
| tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters( |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config, |
| )) |
| retrieve_parameters_op = control_flow_ops.group( |
| state_ops.assign(table_variable, retrieved_table), |
| state_ops.assign(last_hit_step_variable, retrieved_last_hit_step)) |
| config = None |
| retrieve_op_list.append(retrieve_parameters_op) |
| return retrieve_op_list |
| |
| return slot_variables, load_ops_fn, retrieve_ops_fn |
| |
| |
| class _StochasticGradientDescentHandler(_OptimizerHandler): |
| """Handles stochastic gradient descent specific logic.""" |
| |
| def set_optimization_parameters(self, table_descriptor): |
| (table_descriptor.optimization_parameters.stochastic_gradient_descent |
| .SetInParent()) |
| |
| def get_default_slot_variable_names(self, table): |
| return None |
| |
| def create_variables_and_ops(self, table, slot_variable_names, num_hosts, |
| table_config, table_variables, config_proto): |
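| # Stochastic gradient descent keeps no slot variables, so the table config |
| # is not needed here. |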
| del table_config |
| |
| def load_ops_fn(): |
| """Returns the retrieve ops for AdaGrad embedding tables. |
| |
| Returns: |
| A list of ops to load embedding and slot variables from CPU to TPU. |
| """ |
| load_op_list = [] |
| config = config_proto |
| for host_id, table_variable in enumerate(table_variables): |
| with ops.colocate_with(table_variable): |
| load_parameters_op = ( |
| tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters( |
| parameters=table_variable, |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| config = None |
| load_op_list.append(load_parameters_op) |
| return load_op_list |
| |
| def retrieve_ops_fn(): |
| """Returns the retrieve ops for SGD embedding tables. |
| |
| Returns: |
| A list of ops to retrieve embedding and slot variables from TPU to CPU. |
| """ |
| retrieve_op_list = [] |
| config = config_proto |
| for host_id, table_variable in enumerate(table_variables): |
| with ops.colocate_with(table_variable): |
| retrieved_table = ( |
| tpu_ops |
| .retrieve_tpu_embedding_stochastic_gradient_descent_parameters( |
| table_name=table, |
| num_shards=num_hosts, |
| shard_id=host_id, |
| config=config)) |
| retrieve_parameters_op = control_flow_ops.group( |
| state_ops.assign(table_variable, retrieved_table)) |
| config = None |
| retrieve_op_list.append(retrieve_parameters_op) |
| return retrieve_op_list |
| |
| return None, load_ops_fn, retrieve_ops_fn |
| |
| |
| def _get_optimization_handler(optimization_parameters): |
| """Gets the optimization handler given the parameter type.""" |
| if isinstance(optimization_parameters, AdagradParameters): |
| return _AdagradHandler(optimization_parameters) |
| elif isinstance(optimization_parameters, AdagradMomentumParameters): |
| return _AdagradMomentumHandler(optimization_parameters) |
| elif isinstance(optimization_parameters, ProximalAdagradParameters): |
| return _ProximalAdagradHandler(optimization_parameters) |
| elif isinstance(optimization_parameters, AdamParameters): |
| return _AdamHandler(optimization_parameters) |
| elif isinstance(optimization_parameters, FtrlParameters): |
| return _FtrlHandler(optimization_parameters) |
| elif isinstance(optimization_parameters, ProximalYogiParameters): |
| return _ProximalYogiHandler(optimization_parameters) |
| elif isinstance(optimization_parameters, StochasticGradientDescentParameters): |
| return _StochasticGradientDescentHandler(optimization_parameters) |
| elif isinstance(optimization_parameters, MomentumParameters): |
| return _MomentumHandler(optimization_parameters) |
| elif isinstance(optimization_parameters, RMSPropParameters): |
| return _RMSPropHandler(optimization_parameters) |
| elif isinstance(optimization_parameters, FrequencyEstimatorParameters): |
| return _FrequencyEstimatorHandler(optimization_parameters) |
| raise NotImplementedError( |
| 'Unsupported optimization parameters: {}.'.format(optimization_parameters)) |
| |
| |
| def _create_ordered_dict(d): |
| """Create an OrderedDict from Dict.""" |
| return collections.OrderedDict((k, d[k]) for k in sorted(d)) |
| |
| |
| def _create_combiners(table_to_config_dict, table_to_features_dict): |
| """Create a per feature list of combiners, ordered by table.""" |
| combiners = [] |
| for table in table_to_config_dict: |
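| # A combiner of None (used for dense inputs) is mapped to 'sum' here. |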
| combiner = table_to_config_dict[table].combiner or 'sum' |
| combiners.extend([combiner] * len(table_to_features_dict[table])) |
| return combiners |
| |
| |
| def _create_table_to_features_and_num_features_dicts(feature_to_config_dict): |
| """Create mapping from table to a list of its features.""" |
| table_to_features_dict_tmp = {} |
| table_to_num_features_dict_tmp = {} |
| for feature, feature_config in six.iteritems(feature_to_config_dict): |
| if feature_config.table_id in table_to_features_dict_tmp: |
| table_to_features_dict_tmp[feature_config.table_id].append(feature) |
| else: |
| table_to_features_dict_tmp[feature_config.table_id] = [feature] |
| table_to_num_features_dict_tmp[feature_config.table_id] = 0 |
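| # A non-sequence feature counts as a single feature, while a sequence |
| # feature counts as max_sequence_length features. |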
| if feature_config.max_sequence_length == 0: |
| table_to_num_features_dict_tmp[feature_config.table_id] = ( |
| table_to_num_features_dict_tmp[feature_config.table_id] + 1) |
| else: |
| table_to_num_features_dict_tmp[feature_config.table_id] = ( |
| table_to_num_features_dict_tmp[feature_config.table_id] + |
| feature_config.max_sequence_length) |
| |
| table_to_features_dict = collections.OrderedDict() |
| table_to_num_features_dict = collections.OrderedDict() |
| for table in sorted(table_to_features_dict_tmp): |
| table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table]) |
| table_to_num_features_dict[table] = table_to_num_features_dict_tmp[table] |
| return table_to_features_dict, table_to_num_features_dict |
| |
| |
| def _create_device_fn(hosts): |
| """Create device_fn() to use with _create_partitioned_variables().""" |
| |
| def device_fn(op): |
| """Returns the `device` for `op`.""" |
| part_match = re.match(r'.*/part_(\d+)(/|$)', op.name) |
| dummy_match = re.match(r'.*dummy_(\d+).*', op.name) |
| if not part_match and not dummy_match: |
| raise RuntimeError( |
| 'Internal Error: Expected {} to contain /part_* or dummy_*'.format( |
| op.name)) |
| |
| if part_match: |
| idx = int(part_match.group(1)) |
| else: |
| idx = int(dummy_match.group(1)) # pytype: disable=attribute-error |
| |
| device = hosts[idx] |
| logging.debug('assigning %s to %s.', op, device) |
| return device |
| |
| return device_fn |
| |
| |
| def _create_partitioned_variables(name, |
| num_hosts, |
| vocabulary_size, |
| embedding_dimension, |
| initializer, |
| collections=None): # pylint: disable=redefined-outer-name |
| """Creates PartitionedVariables based on `num_hosts` for `table`.""" |
| |
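| # Use at most one slice per host; a table with fewer rows than hosts gets |
| # one single-row slice per row. |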
| num_slices = min(vocabulary_size, num_hosts) |
| |
| var_list = list( |
| variable_scope.get_variable( |
| name, |
| shape=(vocabulary_size, embedding_dimension), |
| partitioner=partitioned_variables.fixed_size_partitioner(num_slices), |
| dtype=dtypes.float32, |
| initializer=initializer, |
| collections=collections, |
| trainable=False)) |
| |
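| # If every host received a real slice, no padding variables are needed. |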
| if vocabulary_size >= num_hosts: |
| return var_list |
| |
| # For the padded hosts, define dummy variables to be loaded into the TPU system. |
| for idx in range(num_hosts - vocabulary_size): |
| var_list.append( |
| variable_scope.get_variable( |
| 'dummy_{}_{}'.format(vocabulary_size + idx, name), |
| shape=(1, embedding_dimension), |
| dtype=dtypes.float32, |
| initializer=initializer, |
| collections=[ops.GraphKeys.LOCAL_VARIABLES], |
| trainable=False)) |
| |
| return var_list |