| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
| |
| import argparse |
| import collections |
| import logging |
| import math |
| import numpy as np |
| import random |
| import time |
| import sys |
| |
| from itertools import izip |
| |
| import caffe2.proto.caffe2_pb2 as caffe2_pb2 |
| from caffe2.python import core, workspace, recurrent, data_parallel_model |
| from caffe2.python.examples import seq2seq_util |
| |
# Module logger: INFO and above is written to stderr.
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler(sys.stderr))
logger.setLevel(logging.INFO)

# One prepared mini-batch; the fields are filled in by prepare_batch().
Batch = collections.namedtuple(
    'Batch',
    'encoder_inputs encoder_lengths '
    'decoder_inputs decoder_lengths '
    'targets target_weights',
)

# Numeric IDs of the special tokens. gen_vocab() registers PAD, GO and EOS
# first and in this order so that the vocabulary IDs agree with these values.
_PAD_ID, _GO_ID, _EOS_ID = 0, 1, 2

# Surface forms of the special tokens stored in the vocabulary.
PAD, GO, EOS, UNK = '<PAD>', '<GO>', '<EOS>', '<UNK>'
| |
| |
def prepare_batch(batch):
    """Pad a list of (source_ids, target_ids) pairs into a ``Batch``.

    For every pair:
      * the source sequence is reversed and right-padded with ``_PAD_ID`` to
        the longest source length in the batch;
      * the decoder input is ``[_GO_ID] + target`` followed by padding;
      * the target is ``target + [_EOS_ID]`` followed by the same padding;
      * the target weight is 1 for every target position whose ID differs
        from ``_PAD_ID`` and 0 otherwise (all zeros when both the source and
        the target are empty).

    The 2-D fields are transposed before being returned, so their first axis
    is the sequence position and the second one the batch entry.

    Args:
        batch: non-empty sequence of ``(source_ids, target_ids)`` pairs;
            ``target_ids`` must be a list (it is concatenated with lists).

    Returns:
        A ``Batch`` of numpy arrays (int32, except float32 target weights).
    """
    source_lengths = [len(pair[0]) for pair in batch]
    longest_source = max(source_lengths)
    target_lengths_with_go = []
    longest_target = max([len(pair[1]) for pair in batch])

    encoder_rows, decoder_rows, target_rows, weight_rows = [], [], [], []

    for src, tgt in batch:
        encoder_rows.append(
            list(reversed(src)) + [_PAD_ID] * (longest_source - len(src))
        )

        # The same padding tail is shared by decoder inputs and targets.
        tail = [_PAD_ID] * (longest_target - len(tgt))

        go_prefixed = [_GO_ID] + tgt
        target_lengths_with_go.append(len(go_prefixed))
        decoder_rows.append(go_prefixed + tail)

        target_row = tgt + [_EOS_ID] + tail
        target_rows.append(target_row)

        if len(src) + len(tgt) == 0:
            # A completely empty pair contributes nothing to the loss.
            weights = [0] * len(target_row)
        else:
            weights = [0 if token == _PAD_ID else 1 for token in target_row]
        weight_rows.append(weights)

    def time_major(rows, dtype):
        # (batch, position) -> (position, batch)
        return np.array(rows, dtype=dtype).transpose()

    return Batch(
        encoder_inputs=time_major(encoder_rows, np.int32),
        encoder_lengths=np.array(source_lengths, dtype=np.int32),
        decoder_inputs=time_major(decoder_rows, np.int32),
        decoder_lengths=np.array(target_lengths_with_go, dtype=np.int32),
        targets=time_major(target_rows, np.int32),
        target_weights=time_major(weight_rows, np.float32),
    )
| |
| |
class Seq2SeqModelCaffe2(object):
    """Sequence-to-sequence model (RNN encoder, LSTM decoder) on Caffe2.

    The class builds two nets that share parameter blobs: a training net
    (forward pass, gradients and global-norm-clipped gradient-descent
    updates) and a forward-only net.  With ``num_gpus == 0`` everything is
    built directly; otherwise the nets are replicated with
    ``data_parallel_model.Parallelize_GPU``.

    Instances are context managers: leaving the ``with`` block resets the
    Caffe2 workspace.
    """

    def _build_model(
        self,
        init_params,
    ):
        """Build the training model and the forward-only net.

        Stores the training ``ModelHelper`` in ``self.model`` and the
        forward-only net in ``self.forward_net``.

        Args:
            init_params: forwarded to ``seq2seq_util.ModelHelper``.
        """
        model = seq2seq_util.ModelHelper(init_params=init_params)
        self._build_shared(model)
        self._build_embeddings(model)

        forward_model = seq2seq_util.ModelHelper(init_params=init_params)
        self._build_shared(forward_model)
        self._build_embeddings(forward_model)

        if self.num_gpus == 0:
            loss_blobs = self.model_build_fun(model)
            model.AddGradientOperators(loss_blobs)
            self.norm_clipped_grad_update(
                model,
                scope='norm_clipped_grad_update'
            )
            self.forward_model_build_fun(forward_model)

        else:
            assert (self.batch_size % self.num_gpus) == 0

            # A real list on both Python 2 and 3 (``range`` is lazy on 3).
            devices = list(range(self.num_gpus))

            data_parallel_model.Parallelize_GPU(
                forward_model,
                input_builder_fun=lambda m: None,
                forward_pass_builder_fun=self.forward_model_build_fun,
                param_update_builder_fun=None,
                devices=devices,
            )

            def clipped_grad_update_bound(model):
                self.norm_clipped_grad_update(
                    model,
                    scope='norm_clipped_grad_update',
                )

            data_parallel_model.Parallelize_GPU(
                model,
                input_builder_fun=lambda m: None,
                forward_pass_builder_fun=self.model_build_fun,
                param_update_builder_fun=clipped_grad_update_bound,
                devices=devices,
            )
        # NOTE(review): the reviewed copy had lost its indentation.  This
        # call is placed at method level (CPU and GPU modes) because the
        # dense update above skips GradientSlice gradients, so the
        # embeddings would otherwise never be updated in CPU mode -- confirm.
        self.norm_clipped_sparse_grad_update(
            model,
            scope='norm_clipped_sparse_grad_update',
        )
        self.model = model
        self.forward_net = forward_model.net

    def _build_embedding_encoder(
        self,
        model,
        inputs,
        input_lengths,
        vocab_size,
        embeddings,
        embedding_size,
        use_attention,
        num_gpus,
        forward_only=False,
    ):
        """Look up input embeddings and run them through the RNN encoder.

        With ``num_gpus > 0`` the embedding ``Gather`` runs under a CPU
        device scope and its result is copied to the GPU.

        ``vocab_size`` and ``forward_only`` are accepted but not used here.

        Returns:
            Tuple ``(encoder_outputs, weighted_encoder_outputs,
            final_encoder_hidden_state, final_encoder_cell_state,
            encoder_output_dim)``; ``weighted_encoder_outputs`` is always
            ``None``.

        Raises:
            ValueError: if ``self.encoder_type`` is not ``'rnn'``.
        """
        if num_gpus == 0:
            embedded_encoder_inputs = model.net.Gather(
                [embeddings, inputs],
                ['embedded_encoder_inputs'],
            )
        else:
            with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
                embedded_encoder_inputs_cpu = model.net.Gather(
                    [embeddings, inputs],
                    ['embedded_encoder_inputs_cpu'],
                )
            embedded_encoder_inputs = model.CopyCPUToGPU(
                embedded_encoder_inputs_cpu,
                'embedded_encoder_inputs',
            )

        if self.encoder_type == 'rnn':
            assert len(self.encoder_params['encoder_layer_configs']) == 1
            encoder_num_units = (
                self.encoder_params['encoder_layer_configs'][0]['num_units']
            )
            encoder_initial_cell_state = model.param_init_net.ConstantFill(
                [],
                ['encoder_initial_cell_state'],
                shape=[encoder_num_units],
                value=0.0,
            )
            encoder_initial_hidden_state = (
                model.param_init_net.ConstantFill(
                    [],
                    'encoder_initial_hidden_state',
                    shape=[encoder_num_units],
                    value=0.0,
                )
            )
            # Choose corresponding rnn encoder function
            if self.encoder_params['use_bidirectional_encoder']:
                rnn_encoder_func = seq2seq_util.rnn_bidirectional_encoder
                encoder_output_dim = 2 * encoder_num_units
            else:
                rnn_encoder_func = seq2seq_util.rnn_unidirectional_encoder
                encoder_output_dim = encoder_num_units

            (
                encoder_outputs,
                final_encoder_hidden_state,
                final_encoder_cell_state,
            ) = rnn_encoder_func(
                model,
                embedded_encoder_inputs,
                input_lengths,
                encoder_initial_hidden_state,
                encoder_initial_cell_state,
                embedding_size,
                encoder_num_units,
                use_attention,
            )
            weighted_encoder_outputs = None
        else:
            raise ValueError('Unsupported encoder type {}'.format(
                self.encoder_type))

        return (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_state,
            final_encoder_cell_state,
            encoder_output_dim,
        )

    def output_projection(
        self,
        model,
        decoder_outputs,
        decoder_output_size,
        target_vocab_size,
        decoder_softmax_size,
    ):
        """Project decoder outputs to logits over the target vocabulary.

        When ``decoder_softmax_size`` is not ``None`` the decoder outputs
        first go through an FC layer of that size.  The projection weights
        and bias are registered in ``model.params``.

        Args:
            model: model helper the operators are added to.
            decoder_outputs: blob with the flattened decoder outputs.
            decoder_output_size: inner dimension of ``decoder_outputs``.
            target_vocab_size: number of output classes.
            decoder_softmax_size: optional intermediate FC size, or ``None``.

        Returns:
            The ``output_logits`` blob.
        """
        if decoder_softmax_size is not None:
            decoder_outputs = model.FC(
                decoder_outputs,
                'decoder_outputs_scaled',
                dim_in=decoder_output_size,
                dim_out=decoder_softmax_size,
            )
            decoder_output_size = decoder_softmax_size

        # Honour the argument instead of silently reading
        # self.target_vocab_size (the in-class caller passes the same value).
        output_projection_w = model.param_init_net.XavierFill(
            [],
            'output_projection_w',
            shape=[target_vocab_size, decoder_output_size],
        )

        output_projection_b = model.param_init_net.XavierFill(
            [],
            'output_projection_b',
            shape=[target_vocab_size],
        )
        model.params.extend([
            output_projection_w,
            output_projection_b,
        ])
        output_logits = model.net.FC(
            [
                decoder_outputs,
                output_projection_w,
                output_projection_b,
            ],
            ['output_logits'],
        )
        return output_logits

    def _build_shared(self, model):
        """Add the non-trainable bookkeeping params under a CPU device scope.

        Sets ``self.learning_rate``, ``self.global_step`` and
        ``self.start_time`` to the blobs returned by ``model.AddParam``.
        """
        optimizer_params = self.model_params['optimizer_params']
        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
            self.learning_rate = model.AddParam(
                name='learning_rate',
                init_value=float(optimizer_params['learning_rate']),
                trainable=False,
            )
            self.global_step = model.AddParam(
                name='global_step',
                init_value=0,
                trainable=False,
            )
            self.start_time = model.AddParam(
                name='start_time',
                init_value=time.time(),
                trainable=False,
            )

    def _build_embeddings(self, model):
        """Create the encoder/decoder embedding tables under a CPU scope.

        Both tables are filled uniformly in ``[-sqrt(3), sqrt(3)]``, appended
        to ``model.params`` and stored on ``self.encoder_embeddings`` /
        ``self.decoder_embeddings``.
        """
        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
            sqrt3 = math.sqrt(3)
            self.encoder_embeddings = model.param_init_net.UniformFill(
                [],
                'encoder_embeddings',
                shape=[
                    self.source_vocab_size,
                    self.model_params['encoder_embedding_size'],
                ],
                min=-sqrt3,
                max=sqrt3,
            )
            model.params.append(self.encoder_embeddings)
            self.decoder_embeddings = model.param_init_net.UniformFill(
                [],
                'decoder_embeddings',
                shape=[
                    self.target_vocab_size,
                    self.model_params['decoder_embedding_size'],
                ],
                min=-sqrt3,
                max=sqrt3,
            )
            model.params.append(self.decoder_embeddings)

    def model_build_fun(self, model, forward_only=False, loss_scale=None):
        """Add the full forward pass and the loss to ``model``.

        Declares the six ``Batch`` fields as external inputs (prefixed with
        the current name scope), builds encoder, decoder (plain LSTM or LSTM
        with attention depending on ``model_params['attention']``), output
        projection, softmax and the weighted cross-entropy loss.

        Args:
            model: model helper the operators are added to.
            forward_only: only forwarded to ``_build_embedding_encoder``.
            loss_scale: accepted for the data-parallel builder interface;
                not used here.

        Returns:
            A one-element list with the loss blob: the sum of the weighted
            label cross-entropies scaled by ``1 / self.batch_size``.
        """
        encoder_inputs = model.net.AddExternalInput(
            workspace.GetNameScope() + 'encoder_inputs',
        )
        encoder_lengths = model.net.AddExternalInput(
            workspace.GetNameScope() + 'encoder_lengths',
        )
        decoder_inputs = model.net.AddExternalInput(
            workspace.GetNameScope() + 'decoder_inputs',
        )
        decoder_lengths = model.net.AddExternalInput(
            workspace.GetNameScope() + 'decoder_lengths',
        )
        targets = model.net.AddExternalInput(
            workspace.GetNameScope() + 'targets',
        )
        target_weights = model.net.AddExternalInput(
            workspace.GetNameScope() + 'target_weights',
        )
        attention_type = self.model_params['attention']
        assert attention_type in ['none', 'regular']

        (
            encoder_outputs,
            weighted_encoder_outputs,
            final_encoder_hidden_state,
            final_encoder_cell_state,
            encoder_output_dim,
        ) = self._build_embedding_encoder(
            model=model,
            inputs=encoder_inputs,
            input_lengths=encoder_lengths,
            vocab_size=self.source_vocab_size,
            embeddings=self.encoder_embeddings,
            embedding_size=self.model_params['encoder_embedding_size'],
            use_attention=(attention_type != 'none'),
            num_gpus=self.num_gpus,
            forward_only=forward_only,
        )

        assert len(self.model_params['decoder_layer_configs']) == 1
        decoder_num_units = (
            self.model_params['decoder_layer_configs'][0]['num_units']
        )

        if attention_type == 'none':
            # Without attention the decoder starts from a projection of the
            # final encoder states.
            decoder_initial_hidden_state = model.FC(
                final_encoder_hidden_state,
                'decoder_initial_hidden_state',
                encoder_output_dim,
                decoder_num_units,
                axis=2,
            )
            decoder_initial_cell_state = model.FC(
                final_encoder_cell_state,
                'decoder_initial_cell_state',
                encoder_output_dim,
                decoder_num_units,
                axis=2,
            )
        else:
            # With attention the decoder starts from all-zero states.
            decoder_initial_hidden_state = model.param_init_net.ConstantFill(
                [],
                'decoder_initial_hidden_state',
                shape=[decoder_num_units],
                value=0.0,
            )
            decoder_initial_cell_state = model.param_init_net.ConstantFill(
                [],
                'decoder_initial_cell_state',
                shape=[decoder_num_units],
                value=0.0,
            )
            initial_attention_weighted_encoder_context = (
                model.param_init_net.ConstantFill(
                    [],
                    'initial_attention_weighted_encoder_context',
                    shape=[encoder_output_dim],
                    value=0.0,
                )
            )

        if self.num_gpus == 0:
            embedded_decoder_inputs = model.net.Gather(
                [self.decoder_embeddings, decoder_inputs],
                ['embedded_decoder_inputs'],
            )
        else:
            with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
                embedded_decoder_inputs_cpu = model.net.Gather(
                    [self.decoder_embeddings, decoder_inputs],
                    ['embedded_decoder_inputs_cpu'],
                )
            embedded_decoder_inputs = model.CopyCPUToGPU(
                embedded_decoder_inputs_cpu,
                'embedded_decoder_inputs',
            )

        # seq_len x batch_size x decoder_embedding_size
        if attention_type == 'none':
            decoder_outputs, _, _, _ = recurrent.LSTM(
                model=model,
                input_blob=embedded_decoder_inputs,
                seq_lengths=decoder_lengths,
                initial_states=(
                    decoder_initial_hidden_state,
                    decoder_initial_cell_state,
                ),
                dim_in=self.model_params['decoder_embedding_size'],
                dim_out=decoder_num_units,
                scope='decoder',
                outputs_with_grads=[0],
            )
            decoder_output_size = decoder_num_units
        else:
            (
                decoder_outputs, _, _, _,
                attention_weighted_encoder_contexts, _
            ) = recurrent.LSTMWithAttention(
                model=model,
                decoder_inputs=embedded_decoder_inputs,
                decoder_input_lengths=decoder_lengths,
                initial_decoder_hidden_state=decoder_initial_hidden_state,
                initial_decoder_cell_state=decoder_initial_cell_state,
                initial_attention_weighted_encoder_context=(
                    initial_attention_weighted_encoder_context
                ),
                encoder_output_dim=encoder_output_dim,
                encoder_outputs=encoder_outputs,
                decoder_input_dim=self.model_params['decoder_embedding_size'],
                decoder_state_dim=decoder_num_units,
                scope='decoder',
                outputs_with_grads=[0, 4],
            )
            decoder_outputs, _ = model.net.Concat(
                [decoder_outputs, attention_weighted_encoder_contexts],
                [
                    'states_and_context_combination',
                    '_states_and_context_combination_concat_dims',
                ],
                axis=2,
            )
            decoder_output_size = decoder_num_units + encoder_output_dim

        # we do softmax over the whole sequence
        # (max_length in the batch * batch_size) x decoder embedding size
        # -1 because we don't know max_length yet
        decoder_outputs_flattened, _ = model.net.Reshape(
            [decoder_outputs],
            [
                'decoder_outputs_flattened',
                'decoder_outputs_and_contexts_combination_old_shape',
            ],
            shape=[-1, decoder_output_size],
        )
        output_logits = self.output_projection(
            model=model,
            decoder_outputs=decoder_outputs_flattened,
            decoder_output_size=decoder_output_size,
            target_vocab_size=self.target_vocab_size,
            decoder_softmax_size=self.model_params['decoder_softmax_size'],
        )
        targets, _ = model.net.Reshape(
            [targets],
            ['targets', 'targets_old_shape'],
            shape=[-1],
        )
        target_weights, _ = model.net.Reshape(
            [target_weights],
            ['target_weights', 'target_weights_old_shape'],
            shape=[-1],
        )
        output_probs = model.net.Softmax(
            [output_logits],
            ['output_probs'],
            engine=('CUDNN' if self.num_gpus > 0 else None),
        )
        label_cross_entropy = model.net.LabelCrossEntropy(
            [output_probs, targets],
            ['label_cross_entropy'],
        )
        weighted_label_cross_entropy = model.net.Mul(
            [label_cross_entropy, target_weights],
            'weighted_label_cross_entropy',
        )
        total_loss_scalar = model.net.SumElements(
            [weighted_label_cross_entropy],
            'total_loss_scalar',
        )
        total_loss_scalar_weighted = model.net.Scale(
            [total_loss_scalar],
            'total_loss_scalar_weighted',
            scale=1.0 / self.batch_size,
        )
        return [total_loss_scalar_weighted]

    def forward_model_build_fun(self, model, loss_scale=None):
        """Build the forward pass via ``model_build_fun(forward_only=True)``."""
        return self.model_build_fun(
            model=model,
            forward_only=True,
            loss_scale=loss_scale
        )

    def _calc_norm_ratio(self, model, params, scope, ONE):
        """Add ops computing the gradient clipping ratio for ``params``.

        The ratio is ``clip_norm / max(global_norm, clip_norm)`` where
        ``global_norm`` is the square root of the summed squared gradient
        elements of all ``params`` and ``clip_norm`` is
        ``model_params['max_gradient_norm']``.  For sparse gradients
        (``core.GradientSlice``) only the ``values`` blob is used.

        ``ONE`` is accepted for signature symmetry with
        ``_apply_norm_ratio`` but not used.

        Returns:
            The ``norm_ratio`` blob.
        """
        with core.NameScope(scope):
            grad_squared_sums = []
            for i, param in enumerate(params):
                logger.info(param)
                grad = model.param_to_grad[param]
                if isinstance(grad, core.GradientSlice):
                    grad = grad.values
                grad_squared = model.net.Sqr(
                    [grad],
                    'grad_{}_squared'.format(i),
                )
                grad_squared_sum = model.net.SumElements(
                    grad_squared,
                    'grad_{}_squared_sum'.format(i),
                )
                grad_squared_sums.append(grad_squared_sum)

            grad_squared_full_sum = model.net.Sum(
                grad_squared_sums,
                'grad_squared_full_sum',
            )
            global_norm = model.net.Pow(
                grad_squared_full_sum,
                'global_norm',
                exponent=0.5,
            )
            clip_norm = model.param_init_net.ConstantFill(
                [],
                'clip_norm',
                shape=[],
                value=float(self.model_params['max_gradient_norm']),
            )
            max_norm = model.net.Max(
                [global_norm, clip_norm],
                'max_norm',
            )
            norm_ratio = model.net.Div(
                [clip_norm, max_norm],
                'norm_ratio',
            )
            return norm_ratio

    def _apply_norm_ratio(
        self, norm_ratio, model, params, learning_rate, scope, ONE
    ):
        """Add the update ops ``param += (-learning_rate * norm_ratio) * grad``.

        Dense gradients are applied with ``WeightedSum``; sparse gradients
        (``core.GradientSlice``) with ``ScatterWeightedSum`` on the slice's
        indices and values.  ``ONE`` is the weight of the current parameter
        value in both sums.
        """
        for param in params:
            param_grad = model.param_to_grad[param]
            nlr = model.net.Negative(
                [learning_rate],
                'negative_learning_rate',
            )
            with core.NameScope(scope):
                update_coeff = model.net.Mul(
                    [nlr, norm_ratio],
                    'update_coeff',
                    broadcast=1,
                )
            # NOTE(review): indentation was lost in the reviewed copy; the
            # update ops below are assumed to sit outside the NameScope
            # block.  Their outputs are existing blobs, not new names.
            if isinstance(param_grad, core.GradientSlice):
                param_grad_values = param_grad.values

                model.net.ScatterWeightedSum(
                    [
                        param,
                        ONE,
                        param_grad.indices,
                        param_grad_values,
                        update_coeff,
                    ],
                    param,
                )
            else:
                model.net.WeightedSum(
                    [
                        param,
                        ONE,
                        param_grad,
                        update_coeff,
                    ],
                    param,
                )

    def _params_with_grads(self, model, sparse):
        """Return top-scope params that have a gradient of the given kind.

        Args:
            model: model helper providing ``GetParams`` and ``param_to_grad``.
            sparse: ``True`` selects params whose gradient is a
                ``core.GradientSlice``, ``False`` selects all the others.
        """
        selected = []
        for param in model.GetParams(top_scope=True):
            if param not in model.param_to_grad:
                continue
            is_sparse = isinstance(
                model.param_to_grad[param],
                core.GradientSlice,
            )
            if is_sparse == sparse:
                selected.append(param)
        return selected

    def _norm_clipped_update(self, model, params, learning_rate, scope, label):
        """Create the ``ONE`` constant, log ``label`` and add clipped updates."""
        ONE = model.param_init_net.ConstantFill(
            [],
            'ONE',
            shape=[1],
            value=1.0,
        )
        logger.info(label)
        norm_ratio = self._calc_norm_ratio(model, params, scope, ONE)
        self._apply_norm_ratio(
            norm_ratio, model, params, learning_rate, scope, ONE
        )

    def norm_clipped_grad_update(self, model, scope):
        """Add norm-clipped updates for all params with dense gradients.

        In GPU mode the learning rate blob is first copied to the GPU.
        """
        if self.num_gpus == 0:
            learning_rate = self.learning_rate
        else:
            learning_rate = model.CopyCPUToGPU(self.learning_rate, 'LR')

        params = self._params_with_grads(model, sparse=False)
        self._norm_clipped_update(
            model, params, learning_rate, scope,
            'Dense trainable variables: ',
        )

    def norm_clipped_sparse_grad_update(self, model, scope):
        """Add norm-clipped updates for all params with sparse gradients."""
        learning_rate = self.learning_rate

        params = self._params_with_grads(model, sparse=True)
        self._norm_clipped_update(
            model, params, learning_rate, scope,
            'Sparse trainable variables: ',
        )

    def total_loss_scalar(self):
        """Fetch the un-scaled loss of the last run from the workspace.

        In GPU mode the per-GPU ``gpu_<i>/total_loss_scalar`` blobs are
        summed.
        """
        if self.num_gpus == 0:
            return workspace.FetchBlob('total_loss_scalar')
        else:
            total_loss = 0
            for i in range(self.num_gpus):
                name = 'gpu_{}/total_loss_scalar'.format(i)
                gpu_loss = workspace.FetchBlob(name)
                total_loss += gpu_loss
            return total_loss

    def _init_model(self):
        """Run the param init net once and create both nets in the workspace."""
        workspace.RunNetOnce(self.model.param_init_net)

        def create_net(net):
            workspace.CreateNet(
                net,
                # Materialized as a list: ``map`` is lazy on Python 3.
                input_blobs=[str(blob) for blob in net.external_inputs],
            )

        create_net(self.model.net)
        create_net(self.forward_net)

    def __init__(
        self,
        model_params,
        source_vocab_size,
        target_vocab_size,
        num_gpus=1,
        num_cpus=1,
    ):
        """Store the configuration and initialise Caffe2 global flags.

        Args:
            model_params: dict of hyper-parameters; this class reads the keys
                ``encoder_type``, ``batch_size``, ``optimizer_params``,
                ``attention``, ``decoder_layer_configs``,
                ``encoder_embedding_size``, ``decoder_embedding_size``,
                ``decoder_softmax_size`` and ``max_gradient_norm``.
            source_vocab_size: number of rows of the encoder embedding table.
            target_vocab_size: number of rows of the decoder embedding table
                and number of output classes.
            num_gpus: 0 builds a single net without data parallelism;
                otherwise the number of devices given to ``Parallelize_GPU``.
            num_cpus: value passed as ``--caffe2_mkl_num_threads``.
        """
        self.model_params = model_params
        self.encoder_type = 'rnn'
        # NOTE: model_params['encoder_type'] holds the encoder *parameters*
        # (layer configs, bidirectional flag), not a type name.
        self.encoder_params = model_params['encoder_type']
        self.source_vocab_size = source_vocab_size
        self.target_vocab_size = target_vocab_size
        self.num_gpus = num_gpus
        self.num_cpus = num_cpus
        self.batch_size = model_params['batch_size']

        workspace.GlobalInit([
            'caffe2',
            # NOTE: modify log level for debugging purposes
            '--caffe2_log_level=0',
            # NOTE: modify log level for debugging purposes
            '--v=0',
            # Fail gracefully if one of the threads fails
            '--caffe2_handle_executor_threads_exceptions=1',
            '--caffe2_mkl_num_threads=' + str(self.num_cpus),
        ])

    def __enter__(self):
        """Return the model itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Reset the Caffe2 workspace; exceptions are not suppressed."""
        workspace.ResetWorkspace()

    def initialize_from_scratch(self):
        """Build the nets with fresh parameters and initialise them."""
        logger.info('Initializing Seq2SeqModelCaffe2 from scratch: Start')
        self._build_model(init_params=True)
        self._init_model()
        logger.info('Initializing Seq2SeqModelCaffe2 from scratch: Finish')

    def get_current_step(self):
        """Return the first element of the ``global_step`` blob."""
        return workspace.FetchBlob(self.global_step)[0]

    def inc_current_step(self):
        """Feed ``global_step + 1`` back into the ``global_step`` blob."""
        workspace.FeedBlob(
            self.global_step,
            np.array([self.get_current_step() + 1]),
        )

    def step(
        self,
        batch,
        forward_only
    ):
        """Feed one batch and run either the forward or the training net.

        In CPU mode the whole batch is prepared and fed under the plain
        ``Batch`` field names.  In GPU mode GPU ``i`` receives the entries
        ``batch[i::num_gpus]`` under ``gpu_<i>/<field>``; the encoder and
        decoder input IDs are fed to the CPU, the other fields to CUDA
        device ``i``.

        Args:
            batch: list of ``(source_ids, target_ids)`` pairs.
            forward_only: run ``self.forward_net`` when true; otherwise run
                the training net and increment the global step.

        Returns:
            The value of ``self.total_loss_scalar()`` after the run.
        """
        # The built-in zip replaces the Python 2-only itertools.izip; only
        # six (name, value) pairs are produced, so laziness is irrelevant.
        if self.num_gpus < 1:
            batch_obj = prepare_batch(batch)
            for batch_obj_name, batch_obj_value in zip(
                Batch._fields,
                batch_obj,
            ):
                workspace.FeedBlob(batch_obj_name, batch_obj_value)
        else:
            for i in range(self.num_gpus):
                gpu_batch = batch[i::self.num_gpus]
                batch_obj = prepare_batch(gpu_batch)
                for batch_obj_name, batch_obj_value in zip(
                    Batch._fields,
                    batch_obj,
                ):
                    name = 'gpu_{}/{}'.format(i, batch_obj_name)
                    if batch_obj_name in ['encoder_inputs', 'decoder_inputs']:
                        dev = core.DeviceOption(caffe2_pb2.CPU)
                    else:
                        dev = core.DeviceOption(caffe2_pb2.CUDA, i)
                    workspace.FeedBlob(name, batch_obj_value, device_option=dev)

        if forward_only:
            workspace.RunNet(self.forward_net)
        else:
            workspace.RunNet(self.model.net)
            self.inc_current_step()

        return self.total_loss_scalar()
| |
| |
def gen_vocab(corpus, unk_threshold):
    """Build a token -> ID mapping from a whitespace-tokenized text file.

    The special tokens PAD, GO, EOS and UNK are registered first, in that
    order, so they receive IDs 0..3.  Every corpus token that occurs more
    than ``unk_threshold`` times then receives the next free ID.

    Args:
        corpus: path of the text file to read.
        unk_threshold: tokens occurring at most this many times are left out.

    Returns:
        A ``defaultdict`` mapping tokens to IDs.  Its default factory hands
        out a new ID on every missing-key lookup, so callers should test
        membership with ``in`` before indexing.
    """
    vocab = collections.defaultdict(lambda: len(vocab))
    # Adding padding tokens to the vocabulary to maintain consistency with IDs
    for special_token in (PAD, GO, EOS, UNK):
        vocab[special_token]

    counts = collections.defaultdict(int)
    with open(corpus) as corpus_file:
        for line in corpus_file:
            for word in line.strip().split():
                counts[word] += 1

    for word, count in counts.items():
        if count <= unk_threshold:
            continue
        # TODO: Add reverse lookup dict when it becomes necessary
        vocab[word]

    return vocab
| |
| |
def get_numberized_sentence(sentence, vocab):
    """Convert a sentence into a list of vocabulary IDs.

    The sentence is stripped and split on whitespace; tokens that are not in
    ``vocab`` are mapped to ``vocab[UNK]``.

    Args:
        sentence: one line of text.
        vocab: mapping from token to ID.

    Returns:
        List with one ID per token (empty for a blank sentence).
    """
    return [
        vocab[token] if token in vocab else vocab[UNK]
        for token in sentence.strip().split()
    ]
| |
| |
def gen_batches(source_corpus, target_corpus, source_vocab, target_vocab,
                batch_size, max_length):
    """Read a parallel corpus and group it into shuffled, equal-size batches.

    Lines of the two files are paired up, converted to ID lists and kept only
    if both sides are non-empty and (when ``max_length`` is not ``None``) no
    longer than ``max_length`` tokens.  The pairs are sorted by
    (source length, target length), cut into consecutive batches of
    ``batch_size`` pairs, and the batch order is shuffled with ``random``.
    A trailing partial batch is filled up by repeating its last pair.

    Args:
        source_corpus: path of the source-side text file.
        target_corpus: path of the target-side text file.
        source_vocab: token -> ID mapping for the source side.
        target_vocab: token -> ID mapping for the target side.
        batch_size: number of sentence pairs per batch.
        max_length: maximum token count per sentence, or ``None`` for no
            limit.

    Returns:
        List of batches; each batch is a list of ``(source_ids, target_ids)``.
    """
    def usable(ids):
        # Non-empty and within the optional length limit.
        return len(ids) > 0 and (max_length is None or len(ids) <= max_length)

    def by_lengths(pair):
        return (len(pair[0]), len(pair[1]))

    pairs = []
    with open(source_corpus) as source, open(target_corpus) as target:
        for source_line, target_line in zip(source, target):
            source_ids = get_numberized_sentence(source_line, source_vocab)
            target_ids = get_numberized_sentence(target_line, target_vocab)
            if usable(source_ids) and usable(target_ids):
                pairs.append((source_ids, target_ids))
    pairs.sort(key=by_lengths)

    batches = []
    current = []
    for pair in pairs:
        current.append(pair)
        if len(current) < batch_size:
            continue
        batches.append(current)
        current = []

    if current:
        # Fill the last, incomplete batch by repeating its final pair.
        current.extend([current[-1]] * (batch_size - len(current)))
        assert len(current) == batch_size
        batches.append(current)

    random.shuffle(batches)
    return batches
| |
| |
def run_seq2seq_model(args, model_params=None):
    """Train a ``Seq2SeqModelCaffe2`` and report train/eval loss per epoch.

    Vocabularies are generated from the training corpora, batches are built
    for the training and the evaluation corpora, and the model is trained for
    ``args.epochs`` epochs.  After each training epoch the evaluation batches
    are run through the forward-only net, so evaluation neither updates the
    parameters nor advances the global step.

    Args:
        args: parsed command-line arguments; reads ``source_corpus``,
            ``target_corpus``, ``source_corpus_eval``, ``target_corpus_eval``,
            ``unk_threshold``, ``max_length``, ``num_gpus`` and ``epochs``.
        model_params: dict of model hyper-parameters.  Despite the ``None``
            default it is required: ``model_params['batch_size']`` is read
            unconditionally.
    """
    source_vocab = gen_vocab(args.source_corpus, args.unk_threshold)
    target_vocab = gen_vocab(args.target_corpus, args.unk_threshold)
    logger.info('Source vocab size {}'.format(len(source_vocab)))
    logger.info('Target vocab size {}'.format(len(target_vocab)))

    batches = gen_batches(args.source_corpus, args.target_corpus, source_vocab,
                          target_vocab, model_params['batch_size'],
                          args.max_length)
    logger.info('Number of training batches {}'.format(len(batches)))

    batches_eval = gen_batches(args.source_corpus_eval, args.target_corpus_eval,
                               source_vocab, target_vocab,
                               model_params['batch_size'], args.max_length)
    logger.info('Number of eval batches {}'.format(len(batches_eval)))

    with Seq2SeqModelCaffe2(
        model_params=model_params,
        source_vocab_size=len(source_vocab),
        target_vocab_size=len(target_vocab),
        num_gpus=args.num_gpus,
        num_cpus=20,
    ) as model_obj:
        model_obj.initialize_from_scratch()
        for i in range(args.epochs):
            logger.info('Epoch {}'.format(i))
            total_loss = 0
            for batch in batches:
                total_loss += model_obj.step(
                    batch=batch,
                    forward_only=False,
                )
            logger.info('\ttraining loss {}'.format(total_loss))
            total_loss = 0
            for batch in batches_eval:
                # Evaluation must not train: run the forward-only net
                # instead of the net with gradient updates.
                total_loss += model_obj.step(
                    batch=batch,
                    forward_only=True,
                )
            logger.info('\teval loss {}'.format(total_loss))
| |
| |
def run_seq2seq_rnn_unidirection_with_no_attention(args):
    """Translate command-line ``args`` into model params and start training.

    Despite the name, attention and a bidirectional encoder are enabled when
    ``args.use_attention`` / ``args.use_bidirectional_encoder`` are set.
    """
    attention_type = 'regular' if args.use_attention else 'none'

    # Exactly one layer on each side.
    decoder_layers = [dict(num_units=args.decoder_cell_num_units)]
    encoder_layers = [dict(num_units=args.encoder_cell_num_units)]

    # NOTE: the 'encoder_type' key carries the encoder configuration.
    encoder_config = dict(
        encoder_layer_configs=encoder_layers,
        use_bidirectional_encoder=args.use_bidirectional_encoder,
    )
    optimizer_config = dict(learning_rate=args.learning_rate)

    model_params = dict(
        attention=attention_type,
        decoder_layer_configs=decoder_layers,
        encoder_type=encoder_config,
        batch_size=args.batch_size,
        optimizer_params=optimizer_config,
        encoder_embedding_size=args.encoder_embedding_size,
        decoder_embedding_size=args.decoder_embedding_size,
        decoder_softmax_size=args.decoder_softmax_size,
        max_gradient_norm=args.max_gradient_norm,
    )
    run_seq2seq_model(args, model_params=model_params)
| |
| |
def main():
    """Parse the command line and run seq2seq training.

    The ``random`` module is seeded with a fixed value before anything else,
    which makes the batch shuffling reproducible between runs.
    """
    random.seed(31415)
    parser = argparse.ArgumentParser(
        description='Caffe2: Seq2Seq Training'
    )

    # (flag, add_argument keyword arguments), in --help display order.
    argument_specs = [
        ('--source-corpus', dict(
            type=str, default=None,
            help='Path to source corpus in a text file format. Each '
            'line in the file should contain a single sentence',
            required=True,
        )),
        ('--target-corpus', dict(
            type=str, default=None,
            help='Path to target corpus in a text file format',
            required=True,
        )),
        ('--max-length', dict(
            type=int, default=None,
            help='Maximal lengths of train and eval sentences',
        )),
        ('--batch-size', dict(
            type=int, default=32,
            help='Training batch size',
        )),
        ('--epochs', dict(
            type=int, default=10,
            help='Number of iterations over training data',
        )),
        ('--learning-rate', dict(
            type=float, default=0.5,
            help='Learning rate',
        )),
        ('--unk-threshold', dict(
            type=int, default=50,
            help='Threshold frequency under which token becomes '
            'labeled unknown token',
        )),
        ('--max-gradient-norm', dict(
            type=float, default=1.0,
            help='Max global norm of gradients at the end of each '
            'backward pass. We do clipping to match the number.',
        )),
        ('--use-bidirectional-encoder', dict(
            action='store_true',
            help='Set flag to use bidirectional recurrent network '
            'in encoder',
        )),
        ('--use-attention', dict(
            action='store_true',
            help='Set flag to use seq2seq with attention model',
        )),
        ('--source-corpus-eval', dict(
            type=str, default=None,
            help='Path to source corpus for evaluation in a text '
            'file format',
            required=True,
        )),
        ('--target-corpus-eval', dict(
            type=str, default=None,
            help='Path to target corpus for evaluation in a text '
            'file format',
            required=True,
        )),
        ('--encoder-cell-num-units', dict(
            type=int, default=256,
            help='Number of cell units in the encoder layer',
        )),
        ('--decoder-cell-num-units', dict(
            type=int, default=512,
            help='Number of cell units in the decoder layer',
        )),
        ('--encoder-embedding-size', dict(
            type=int, default=256,
            help='Size of embedding in the encoder layer',
        )),
        ('--decoder-embedding-size', dict(
            type=int, default=512,
            help='Size of embedding in the decoder layer',
        )),
        ('--decoder-softmax-size', dict(
            type=int, default=128,
            help='Size of softmax layer in the decoder',
        )),
        ('--num-gpus', dict(
            type=int, default=0,
            help='Number of GPUs for data parallel model',
        )),
    ]
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    run_seq2seq_rnn_unidirection_with_no_attention(args)


if __name__ == '__main__':
    main()