blob: 776a7252ddc0d22233f7becb0378fa35b696f369 [file] [log] [blame]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core
from caffe2.python.cnn import CNNModelHelper
def recurrent_net(
        net, cell_net, inputs, initial_cell_inputs,
        links, timestep=None, scope=None
):
    '''
    Add a RecurrentNetwork operator to `net` that executes `cell_net`
    recurrently, and return the operator's user-visible outputs.

    net: the main net the operator should be added to
    cell_net: cell_net which is executed in a recurrent fashion
    inputs: sequences to be fed into the recurrent net, as a list of
        (cell_net_input_name, external_blob_with_data) pairs. Currently only
        one input is supported. It has to be in a format T x N x (D1...Dk)
        where T is the length of the sequence, N is the batch size and
        (D1...Dk) are the rest of the dimensions.
    initial_cell_inputs: inputs of the cell_net for timestep 0.
        Format for each input is:
        (cell_net_input_name, external_blob_with_data)
    links: a dictionary mapping cell_net input names at timestep t+1 to
        cell_net output names at timestep t. Currently we assume that each
        output becomes an input for the next timestep.
    timestep: name of the timestep blob to be used. If not provided
        "timestep" is used.
    scope: internal blobs are going to be scoped in a format
        <scope_name>/<blob_name>. If not provided we generate a scope name
        automatically.

    Returns the operator outputs except the internal step-workspaces blob:
    for each entry of initial_cell_inputs, in order, the blobs
    <output>_all and <output>_last, where <output> is
    links[cell_net_input_name].

    Raises AssertionError if `inputs` does not hold exactly one sequence.
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')
    # The same name is handed to the operator below, so compute it once.
    timestep_name = "timestep" if timestep is None else str(timestep)

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine inputs that are considered to be references: those cell_net
    # external inputs not provided via `inputs`, `initial_cell_inputs` or the
    # timestep blob. They are fed to the operator from the outer net.
    # A real set (not a lazy map(), which breaks `+=`/`.append` on Python 3)
    # also gives O(1) membership tests.
    known_inputs = set(input_blobs + initial_input_blobs)
    known_inputs.update(str(x[0]) for x in initial_cell_inputs)
    # The operator is always given a timestep name (defaulting to
    # "timestep"), so that blob must never be mistaken for a reference.
    known_inputs.add(timestep_name)
    references = [
        b for b in cell_net.Proto().external_input
        if b not in known_inputs]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # Compute the backward pass of the cell net.
    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
        cell_net.Proto().op, inner_outputs_map)
    backward_mapping = {str(k): str(v) for k, v in backward_mapping.items()}
    backward_cell_net = core.Net("RecurrentBackwardStep")
    del backward_cell_net.Proto().op[:]
    backward_cell_net.Proto().op.extend(backward_ops)

    # Compute blobs used but not defined in the backward pass.
    ssa, _ = core.get_ssa(backward_cell_net.Proto())
    undefined = core.get_undefined_blobs(ssa)

    # Also add to the backward step's inputs the intermediate outputs of the
    # forward step (blobs written in cell_net, i.e. version > 0) that the
    # backward step uses but does not define.
    _, blob_versions = core.get_ssa(cell_net.Proto())
    scratches = [
        blob for (blob, ver) in blob_versions.items()
        if ver > 0 and
        blob in undefined and
        blob not in cell_net.Proto().external_output]
    backward_cell_net.Proto().external_input.extend(scratches)

    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'
    backward_cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator.
    # Links are in the format blob_name, recurrent_states, offset.
    # At moment t the corresponding data block is at position t + offset
    # in the recurrent_states tensor.
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world.
    # Format (internal_blob, external_blob, offset).
    # A negative offset counts from the end, a positive one from the
    # beginning.
    aliases = []

    # States holding the inputs to the cell net.
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time,
        # or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = str(links[cell_input])
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))
        backward_links.append((cell_input + "_grad", states_grad, 0))
        backward_links.append((cell_output + "_grad", states_grad, 1))

        backward_cell_net.Proto().external_input.append(
            cell_output + "_grad")

        aliases.append((state, cell_output + "_last", -1))
        aliases.append((state, cell_output + "_all", 1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))
        backward_links.append((
            backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
        ))

    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_input)
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_output)

    def unpack_triple(triples):
        # Transpose a list of 3-tuples into three lists (empty lists for
        # empty input) so they can be passed to C++, where they are
        # reassembled.
        if not triples:
            return [], [], []
        first, second, third = zip(*triples)
        return list(first), list(second), list(third)

    link_internal, link_external, link_offset = unpack_triple(forward_links)
    backward_link_internal, backward_link_external, backward_link_offset = \
        unpack_triple(backward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    params = [x for x in references if x in backward_mapping]
    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    # Operator arguments are built as concrete lists: on Python 3 map()
    # would produce lazy iterators instead.
    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        param=[all_inputs.index(p) for p in params],
        alias_src=alias_src,
        alias_dst=[str(b) for b in alias_dst],
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=[
            all_inputs.index(r) for r in recurrent_inputs],
        link_internal=[str(b) for b in link_internal],
        link_external=[str(b) for b in link_external],
        link_offset=link_offset,
        backward_link_internal=[str(b) for b in backward_link_internal],
        backward_link_external=[str(b) for b in backward_link_external],
        backward_link_offset=backward_link_offset,
        step_net=str(cell_net.Proto()),
        backward_step_net=str(backward_cell_net.Proto()),
        timestep=timestep_name,
    )
    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation.
    return results[:-1]
def LSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
         scope):
    '''
    Add a standard LSTM recurrent network operator to a model.

    model: CNNModelHelper object the new operators are added to
    input_blob: the input sequence in a format T x N x D, where T is the
        sequence size, N the batch size and D the input dimension
    seq_lengths: blob containing sequence lengths, passed to the LSTMUnit
        operator
    initial_states: a tuple of (hidden_input_blob, cell_input_blob) which
        are the inputs to the cell net on the first iteration
    dim_in: input dimension
    dim_out: output dimension
    scope: prefix for internal blob names, in a format <scope>/<blob_name>

    Returns a 4-tuple from recurrent_net, in order: hidden states over all
    timesteps, last hidden state, cell states over all timesteps, last cell
    state.
    '''
    def scoped(blob_name):
        # Internal/external blob relationships force us to scope names
        # by hand.
        return "{}/{}".format(str(scope), str(blob_name))

    # One up-front FC over axis 2 projects every timestep of the input to
    # 4 * dim_out features; each step then adds its slice to the gates.
    projected = model.FC(
        input_blob, scoped('i2h'),
        dim_in=dim_in, dim_out=4 * dim_out, axis=2,
    )

    # Build the per-timestep cell net.
    cell_model = CNNModelHelper(name='lstm_cell', param_model=model)
    step_inputs = cell_model.net.AddExternalInputs(
        'input_t', 'timestep', 'cell_t_prev', 'hidden_t_prev')
    x_t, step_blob, c_prev, h_prev = step_inputs

    gates = cell_model.FC(
        h_prev, scoped('gates_t'),
        dim_in=dim_out, dim_out=4 * dim_out, axis=2,
    )
    # Accumulate this timestep's input projection into the gates in place.
    cell_model.net.Sum([gates, x_t], gates)
    h_t, c_t = cell_model.net.LSTMUnit(
        [c_prev, gates, seq_lengths, step_blob],
        [scoped('hidden_t'), scoped('cell_t')],
    )
    cell_model.net.AddExternalOutputs(c_t, h_t)

    # Wire the cell net into a recurrent operator on the main net.
    h_init, c_init = initial_states
    hidden_all, hidden_last, cell_all, cell_last = recurrent_net(
        net=model.net,
        cell_net=cell_model.net,
        inputs=[(x_t, projected)],
        initial_cell_inputs=[
            (h_prev, h_init),
            (c_prev, c_init),
        ],
        links={
            h_prev: h_t,
            c_prev: c_t,
        },
        timestep=step_blob,
        scope=scope,
    )
    return hidden_all, hidden_last, cell_all, cell_last