syntax = "proto3";

package tensorflow.tpu;

import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";

// The TPUEmbeddingConfiguration contains the specification of TPU Embedding
// lookups and gradient updates, separate from the TF Graph.
message TPUEmbeddingConfiguration {
  // Specifies whether the model is to be run in training or inference.
  enum ModelMode {
    // Zero value; this is what model_mode reads as when it is left unset.
    INVALID = 0;
    TRAINING = 1;
    // In inference mode, gradient updates to embedding tables are not
    // performed.
    INFERENCE = 2;
  }

  // The mode the model is run in; see ModelMode.
  ModelMode model_mode = 1;

  // num_hosts is the number of host CPU systems in the training/inference job.
  // Each embedding table must be sharded into num_hosts separate Variables,
  // placed separately on the num_hosts CPU devices in the cluster. Sharding
  // will be performed equivalently to the 'div' sharding_strategy option of
  // embedding_lookup() and embedding_lookup_sparse().
  int32 num_hosts = 2;

  // The total number of TensorNodes. This is equal to num_hosts times the
  // number of TensorNodes attached to each host.
  int32 num_tensornodes = 3;

  // The number of training examples per TensorNode.
  int32 batch_size = 4;

  // Describes a single embedding table: its name, its shape
  // (num_rows x width), the number of activation vectors it produces per
  // training example, and its optimization parameters.
  message TPUEmbeddingTable {
    // Field number 2 is not used by any field of this message. Keep it
    // reserved so that it is not handed out by accident.
    // NOTE(review): presumably the number of a removed field -- confirm
    // against the file's history.
    reserved 2;

    // Name of the embedding table. This will be used to name Variables in the
    // TensorFlow Graph.
    string name = 1;

    // Number of rows of the embedding table. The Variable created to hold the
    // learned embedding table values will have shape (num_rows, width).
    int32 num_rows = 3;

    // Width of the embedding table. The Variable created to hold the
    // learned embedding table values will have shape (num_rows, width).
    int32 width = 4;

    // Number of distinct embedding activation vectors per training example
    // produced by lookups into this table during model evaluation. For each
    // table, the Graph will receive an activations Tensor of shape
    // (batch_size * table.num_features, table.width).
    // For example, num_features = 1 produces equivalent behavior to a single
    // tf.nn.embedding_lookup() call. In the case of 'multivalent' embeddings
    // (i.e. tf.nn.embedding_lookup_sparse()), which compute weighted averages
    // of embedding table rows, num_features is the number of vectors produced
    // after averaging. In sequence models num_features is typically equal
    // to the sequence length, since each sequence element must be represented
    // separately to the convolutional or recurrent network.
    int32 num_features = 5;

    // Optimization parameters for this table. The OptimizationParameters
    // message is not declared in this file; presumably it comes from the
    // optimization_parameters.proto import above.
    OptimizationParameters optimization_parameters = 6;
  }

  // The embedding tables covered by this configuration, one entry per table.
  repeated TPUEmbeddingTable table_config = 5;
}