""" A bunch of util functions to build Seq2Seq models with Caffe2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import recurrent
from caffe2.python.cnn import CNNModelHelper
class ModelHelper(CNNModelHelper):
def __init__(self, init_params=True):
super(ModelHelper, self).__init__(
order='NCHW', # this is only relevant for convolutional networks
init_params=init_params,
)
self.non_trainable_params = []
def AddParam(self, name, init=None, init_value=None, trainable=True):
"""Adds a parameter to the model's net and it's initializer if needed
Args:
init: a tuple (<initialization_op_name>, <initialization_op_kwargs>)
init_value: int, float or str. Can be used instead of `init` as a
simple constant initializer
trainable: bool, whether to compute gradient for this param or not
"""
if init_value is not None:
assert init is None
assert type(init_value) in [int, float, str]
init = ('ConstantFill', dict(
shape=[1],
value=init_value,
))
if self.init_params:
            param = getattr(self.param_init_net, init[0])(
[],
name,
**init[1]
)
else:
param = self.net.AddExternalInput(name)
if trainable:
self.params.append(param)
else:
self.non_trainable_params.append(param)
return param
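

# A minimal usage sketch for AddParam (the blob names, shapes and values below
# are hypothetical, not taken from the original code):
#
#   model = ModelHelper(init_params=True)
#   bias = model.AddParam(
#       'decoder_b',
#       init=('ConstantFill', dict(shape=[128], value=0.0)),
#   )
#   scale = model.AddParam('output_scale', init_value=1.0, trainable=False)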
def rnn_unidirectional_encoder(
model,
embedded_inputs,
input_lengths,
initial_hidden_state,
initial_cell_state,
embedding_size,
encoder_num_units,
use_attention
):
""" Unidirectional (forward pass) LSTM encoder."""
outputs, final_hidden_state, _, final_cell_state = recurrent.LSTM(
model=model,
input_blob=embedded_inputs,
seq_lengths=input_lengths,
initial_states=(initial_hidden_state, initial_cell_state),
dim_in=embedding_size,
dim_out=encoder_num_units,
scope='encoder',
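        # When attention is used, gradients must flow back through every
        # per-step encoder output (index 0 of recurrent.LSTM's outputs);
        # otherwise only the final hidden and cell states (indices 1 and 3)
        # need gradients.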
outputs_with_grads=([0] if use_attention else [1, 3]),
)
return outputs, final_hidden_state, final_cell_state
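

# A minimal call sketch (assuming `model` is a ModelHelper and the listed
# input blobs already exist; names and sizes are illustrative only):
#
#   outputs, h_final, c_final = rnn_unidirectional_encoder(
#       model=model,
#       embedded_inputs='embedded_inputs',    # shape (T, N, embedding_size)
#       input_lengths='input_lengths',        # shape (N,)
#       initial_hidden_state='encoder_h0',    # shape (1, N, encoder_num_units)
#       initial_cell_state='encoder_c0',      # shape (1, N, encoder_num_units)
#       embedding_size=256,
#       encoder_num_units=512,
#       use_attention=True,
#   )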
def rnn_bidirectional_encoder(
model,
embedded_inputs,
input_lengths,
initial_hidden_state,
initial_cell_state,
embedding_size,
encoder_num_units,
use_attention
):
""" Bidirectional (forward pass and backward pass) LSTM encoder."""
# Forward pass
(
outputs_fw,
final_hidden_state_fw,
_,
final_cell_state_fw,
) = recurrent.LSTM(
model=model,
input_blob=embedded_inputs,
seq_lengths=input_lengths,
initial_states=(initial_hidden_state, initial_cell_state),
dim_in=embedding_size,
dim_out=encoder_num_units,
scope='forward_encoder',
outputs_with_grads=([0] if use_attention else [1, 3]),
)
# Backward pass
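    # ReversePackedSegs flips each sequence along the time axis up to its
    # length (padding positions are left untouched), so the backward LSTM
    # consumes the tokens in reverse order.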
reversed_embedded_inputs = model.net.ReversePackedSegs(
[embedded_inputs, input_lengths],
['reversed_embedded_inputs'],
)
(
outputs_bw,
final_hidden_state_bw,
_,
final_cell_state_bw,
) = recurrent.LSTM(
model=model,
input_blob=reversed_embedded_inputs,
seq_lengths=input_lengths,
initial_states=(initial_hidden_state, initial_cell_state),
dim_in=embedding_size,
dim_out=encoder_num_units,
scope='backward_encoder',
outputs_with_grads=([0] if use_attention else [1, 3]),
)
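    # Reverse the backward outputs again so their time steps line up with the
    # forward outputs before concatenation.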
outputs_bw = model.net.ReversePackedSegs(
[outputs_bw, input_lengths],
['outputs_bw'],
)
# Concatenate forward and backward results
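    # (along the feature axis, so each concatenated blob ends up with
    # 2 * encoder_num_units features)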
outputs, _ = model.net.Concat(
[outputs_fw, outputs_bw],
['outputs', 'outputs_dim'],
axis=2,
)
final_hidden_state, _ = model.net.Concat(
[final_hidden_state_fw, final_hidden_state_bw],
['final_hidden_state', 'final_hidden_state_dim'],
axis=2,
)
final_cell_state, _ = model.net.Concat(
[final_cell_state_fw, final_cell_state_bw],
['final_cell_state', 'final_cell_state_dim'],
axis=2,
)
return outputs, final_hidden_state, final_cell_state
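

# Note: since the forward and backward results are concatenated along the
# feature axis, the blobs returned by rnn_bidirectional_encoder carry
# 2 * encoder_num_units features; downstream decoder or attention layers
# must be sized accordingly.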