# blob: 8f50072b53323dbe069ec89a25a3dbfeaebf4b70 [file] [log] [blame]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import cnn, workspace, core, rnn_cell
from caffe2.proto import caffe2_pb2
import numpy as np
import unittest
@unittest.skipIf(not workspace.has_gpu_support, "No gpu support.")
class TestLSTMs(unittest.TestCase):
    """Compares rnn_cell.cudnn_LSTM against rnn_cell.LSTM on a CUDA device."""

    def testEqualToCudnn(self):
        """Check that both LSTM implementations agree after three net runs.

        Two models are built on the same input blob: one using
        ``rnn_cell.cudnn_LSTM`` and one using ``rnn_cell.LSTM``. The initial
        parameters of the cuDNN model are extracted and copied into the
        other model, each net is run three times (with parameter updates),
        and the outputs, last hidden states and losses are compared with
        ``np.allclose``.
        """
        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CUDA)):
            T = 8
            batch_size = 4
            input_dim = 8
            hidden_dim = 31

            # Every sequence in the batch uses the full length T.
            workspace.FeedBlob(
                "seq_lengths",
                np.array([T] * batch_size, dtype=np.int32)
            )
            workspace.FeedBlob("target", np.zeros(
                [T, batch_size, hidden_dim], dtype=np.float32
            ))

            # Zero initial states for the rnn_cell.LSTM model.
            workspace.FeedBlob("hidden_init", np.zeros(
                [1, batch_size, hidden_dim], dtype=np.float32
            ))
            workspace.FeedBlob("cell_init", np.zeros(
                [1, batch_size, hidden_dim], dtype=np.float32
            ))

            own_model = cnn.CNNModelHelper(name="own_lstm")

            input_shape = [T, batch_size, input_dim]
            cudnn_model = cnn.CNNModelHelper(name="cudnn_lstm")
            # The input is created in the cuDNN model's init net and the
            # same blob is passed to both LSTMs below.
            input_blob = cudnn_model.param_init_net.UniformFill(
                [], "input", shape=input_shape)

            # Zero initial states for the cuDNN model, fed under the
            # "CUDNN/" prefix that matches scope="CUDNN" used below.
            # NOTE(review): the unprefixed names passed as initial_states
            # are assumed to be resolved inside that scope -- confirm.
            workspace.FeedBlob("CUDNN/hidden_init_cudnn", np.zeros(
                [1, batch_size, hidden_dim], dtype=np.float32
            ))
            workspace.FeedBlob("CUDNN/cell_init_cudnn", np.zeros(
                [1, batch_size, hidden_dim], dtype=np.float32
            ))

            cudnn_output, cudnn_last_hidden, _, param_extract = \
                rnn_cell.cudnn_LSTM(
                    model=cudnn_model,
                    input_blob=input_blob,
                    # (hidden, cell): the cell state must use its own blob,
                    # not the hidden-state blob a second time.
                    initial_states=("hidden_init_cudnn", "cell_init_cudnn"),
                    dim_in=input_dim,
                    dim_out=hidden_dim,
                    scope="CUDNN",
                    return_params=True,
                )
            cudnn_loss = cudnn_model.AveragedLoss(
                cudnn_model.SquaredL2Distance(
                    [cudnn_output, "target"], "CUDNN/dist"
                ), "CUDNN/loss"
            )

            own_output, own_last_hidden, _, _, own_params = rnn_cell.LSTM(
                model=own_model,
                input_blob=input_blob,
                seq_lengths="seq_lengths",
                initial_states=("hidden_init", "cell_init"),
                dim_in=input_dim,
                dim_out=hidden_dim,
                scope="OWN",
                return_params=True,
            )
            own_loss = own_model.AveragedLoss(
                own_model.SquaredL2Distance(
                    [own_output, "target"], "OWN/dist"
                ),
                "OWN/loss"
            )

            # Add gradients
            cudnn_model.AddGradientOperators([cudnn_loss])
            own_model.AddGradientOperators([own_loss])

            # Add parameter updates. LR and ONE live in the cuDNN model's
            # init net and are shared by both models' update operators.
            LR = cudnn_model.param_init_net.ConstantFill(
                [], shape=[1], value=0.01
            )
            ONE = cudnn_model.param_init_net.ConstantFill(
                [], shape=[1], value=1.0
            )
            # param <- 1.0 * param + 0.01 * grad, applied identically to
            # both models so their parameters stay comparable.
            for param in cudnn_model.GetParams():
                cudnn_model.WeightedSum(
                    [param, ONE, cudnn_model.param_to_grad[param], LR],
                    param
                )
            for param in own_model.GetParams():
                own_model.WeightedSum(
                    [param, ONE, own_model.param_to_grad[param], LR],
                    param
                )

            workspace.RunNetOnce(cudnn_model.param_init_net)
            workspace.CreateNet(cudnn_model.net)

            ##
            ## CUDNN LSTM MODEL EXECUTION
            ##

            # Get initial values from CuDNN LSTM so we can feed them
            # to our own. This must happen before the net is run, since
            # running it updates the parameters.
            (param_extract_net, param_extract_mapping) = param_extract
            workspace.RunNetOnce(param_extract_net)
            cudnn_lstm_params = {
                input_type: {
                    k: workspace.FetchBlob(v[0]) for k, v in pars.items()
                }
                for input_type, pars in param_extract_mapping.items()
            }

            # Run the model 3 times, so that some parameter updates are done
            workspace.RunNet(cudnn_model.net.Proto().name, 3)

            ##
            ## OWN LSTM MODEL EXECUTION
            ##

            # Map the cuDNN parameters to our own. The init net runs first
            # so the parameter blobs exist before they are overwritten.
            workspace.RunNetOnce(own_model.param_init_net)
            rnn_cell.InitFromLSTMParams(own_params, cudnn_lstm_params)

            # Run the model 3 times, so that some parameter updates are done
            workspace.CreateNet(own_model.net)
            workspace.RunNet(own_model.net.Proto().name, 3)

            ##
            ## COMPARE RESULTS
            ##

            # Then compare that final results after 3 runs are equal
            own_output_data = workspace.FetchBlob(own_output)
            own_last_hidden_data = workspace.FetchBlob(own_last_hidden)
            own_loss_data = workspace.FetchBlob(own_loss)

            cudnn_output_data = workspace.FetchBlob(cudnn_output)
            cudnn_last_hidden_data = workspace.FetchBlob(cudnn_last_hidden)
            cudnn_loss_data = workspace.FetchBlob(cudnn_loss)

            self.assertTrue(
                np.allclose(own_output_data, cudnn_output_data),
                "LSTM outputs differ from cuDNN outputs"
            )
            self.assertTrue(
                np.allclose(own_last_hidden_data, cudnn_last_hidden_data),
                "LSTM last hidden state differs from cuDNN"
            )
            self.assertTrue(
                np.allclose(own_loss_data, cudnn_loss_data),
                "LSTM loss differs from cuDNN loss"
            )