// blob: 3476cc89534efb7fe05640935d1387d02737f240
syntax = "proto3";
package tensorflow.tpu;
import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
// TPUEmbeddingConfiguration describes the TPU Embedding lookups and gradient
// updates. This specification is kept separate from the TF Graph.
message TPUEmbeddingConfiguration {
  // Execution mode of the model: training or inference.
  enum ModelMode {
    // Zero value; this is what model_mode reads as when it is left unset.
    INVALID = 0;
    // The model is run in training mode.
    TRAINING = 1;
    // The model is run in inference mode. Gradient updates to embedding tables
    // are not performed in this mode.
    INFERENCE = 2;
  }

  // Whether the model is to be run in training or inference.
  ModelMode model_mode = 1;

  // Number of host CPU systems in the training/inference job.
  //
  // Every embedding table must be sharded into num_hosts separate Variables,
  // each placed on a different one of the num_hosts CPU devices in the
  // cluster. The sharding is equivalent to the 'div' sharding_strategy option
  // of embedding_lookup() and embedding_lookup_sparse().
  int32 num_hosts = 2;

  // Total number of TensorNodes, i.e. num_hosts multiplied by the number of
  // TensorNodes attached to each host.
  int32 num_tensornodes = 3;

  // Number of training examples per TensorNode.
  int32 batch_size = 4;

  // Configuration of a single embedding table.
  //
  // NOTE(review): field number 2 is not used in this message. If it was
  // assigned in an earlier revision it should be declared `reserved` --
  // confirm against the file history.
  message TPUEmbeddingTable {
    // Name of the embedding table; used to name Variables in the Tensorflow
    // Graph.
    string name = 1;

    // Row count of the embedding table. The Variable holding the learned
    // embedding table values has shape (num_rows, width).
    int32 num_rows = 3;

    // Width of the embedding table. The Variable holding the learned
    // embedding table values has shape (num_rows, width).
    int32 width = 4;

    // How many distinct embedding activation vectors each training example
    // gets from lookups into this table during model evaluation. For each
    // table the Graph receives an activations Tensor of shape
    // (batch_size * table.num_features, table.width).
    //
    // Examples:
    //  * num_features = 1 behaves like a single tf.nn.embedding_lookup()
    //    call.
    //  * For 'multivalent' embeddings (i.e. tf.nn.embedding_lookup_sparse()),
    //    which compute weighted averages of embedding table rows,
    //    num_features counts the vectors produced after averaging.
    //  * In sequence models num_features is typically the sequence length,
    //    because each sequence element must be represented separately to the
    //    convolutional or recurrent network.
    int32 num_features = 5;

    // Optimization parameters for this table. The OptimizationParameters type
    // is not declared in this file; presumably it comes from the imported
    // optimization_parameters.proto -- confirm there for its fields.
    OptimizationParameters optimization_parameters = 6;
  }

  // The embedding tables covered by this configuration, one entry per table.
  repeated TPUEmbeddingTable table_config = 5;
}