#include "lstm_unit_op.h"
namespace caffe2 {
namespace detail {
// Using macros here instead of inlined functions; needed for
// performance: the g++ inliner loses 10-20%.
#undef C2_EIGEN_SIGMOID_INLINE
#undef C2_EIGEN_HOST_TANH_INLINE
#define C2_EIGEN_SIGMOID_INLINE(x) (1.0f / ((-(x)).exp() + 1.0))
#define C2_EIGEN_HOST_TANH_INLINE(x) \
  (2.0 * C2_EIGEN_SIGMOID_INLINE(2.0 * (x)) - 1.0)
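// C2_EIGEN_HOST_TANH_INLINE relies on the identity tanh(x) = 2 * sigmoid(2 * x) - 1,
// so both nonlinearities reduce to a single elementwise exp() in Eigen.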
template <typename T, typename Context>
void LSTMUnit(
    int N,
    int D,
    int t,
    const T* H_prev,
    const T* C_prev,
    const T* X,
    const int32_t* seqLengths,
    bool drop_states,
    T* C,
    T* H,
    const T& forget_bias,
    Context* context) {
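  // X packs the four pre-activation gates for every example as contiguous
  // blocks of size D, in [input | forget | output | candidate] order, so
  // each example contributes 4 * D values. seqLengths gives the valid
  // length of each sequence; for timesteps past the end, the state is
  // either zeroed or carried through unchanged, depending on drop_states.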
  for (int n = 0; n < N; ++n) {
    const bool valid = t < seqLengths[n];
    // create data aliases into Eigen vectors
    EigenVectorArrayMap<T> vH(H, D);
    EigenVectorArrayMap<T> vC(C, D);
    ConstEigenVectorArrayMap<T> vH_prev(H_prev, D);
    ConstEigenVectorArrayMap<T> vC_prev(C_prev, D);
    ConstEigenVectorArrayMap<T> vX0(X + 0 * D, D);
    ConstEigenVectorArrayMap<T> vX1(X + 1 * D, D);
    ConstEigenVectorArrayMap<T> vX2(X + 2 * D, D);
    ConstEigenVectorArrayMap<T> vX3(X + 3 * D, D);
    if (!valid) {
      if (drop_states) {
        vH.setConstant((T)(0.0));
        vC.setConstant((T)(0.0));
      } else {
        vH = vH_prev;
        vC = vC_prev;
      }
    } else {
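      // Standard LSTM cell update (no peepholes), with the fused gates
      // packed as i = vX0, f = vX1, o = vX2, g = vX3:
      //   c_t = sigmoid(f + forget_bias) * c_{t-1} + sigmoid(i) * tanh(g)
      //   h_t = sigmoid(o) * tanh(c_t)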
      vC = C2_EIGEN_SIGMOID_INLINE(vX1 + forget_bias) * vC_prev +
          C2_EIGEN_SIGMOID_INLINE(vX0) * C2_EIGEN_HOST_TANH_INLINE(vX3);
      vH = C2_EIGEN_SIGMOID_INLINE(vX2) * C2_EIGEN_HOST_TANH_INLINE(vC);
    }
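    // Advance all pointers to the next example in the batch.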
    H_prev += D;
    C_prev += D;
    X += 4 * D;
    C += D;
    H += D;
  }
}
template <typename T, typename Context>
void LSTMUnitGradient(
    int N,
    int D,
    int t,
    const T* C_prev,
    const T* X,
    const int32_t* seqLengths,
    const T* C,
    const T* H,
    const T* C_diff,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* C_prev_diff,
    T* X_diff,
    const T forget_bias,
    Context* context) {
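  // Backward pass for LSTMUnit: given the gradients flowing into C and H at
  // this timestep, produce gradients w.r.t. the previous hidden state, the
  // previous cell state, and the fused gate pre-activations, using
  // sigmoid'(x) = s * (1 - s) and tanh'(x) = 1 - tanh(x)^2.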
  for (int n = 0; n < N; ++n) {
    const bool valid = t < seqLengths[n];
    // create data aliases into Eigen vectors
    ConstEigenVectorArrayMap<T> vC_prev(C_prev, D);
    ConstEigenVectorArrayMap<T> vX(X, 4 * D);
    ConstEigenVectorArrayMap<T> vX0(X + 0 * D, D);
    ConstEigenVectorArrayMap<T> vX1(X + 1 * D, D);
    ConstEigenVectorArrayMap<T> vX2(X + 2 * D, D);
    ConstEigenVectorArrayMap<T> vX3(X + 3 * D, D);
    ConstEigenVectorArrayMap<T> vC(C, D);
    ConstEigenVectorArrayMap<T> vH(H, D);
    ConstEigenVectorArrayMap<T> vC_diff(C_diff, D);
    ConstEigenVectorArrayMap<T> vH_diff(H_diff, D);
    // Output
    EigenVectorArrayMap<T> vH_prev_diff(H_prev_diff, D);
    EigenVectorArrayMap<T> vC_prev_diff(C_prev_diff, D);
    EigenVectorArrayMap<T> vX_diff(X_diff, 4 * D);
    EigenVectorArrayMap<T> vX0_diff(X_diff + 0 * D, D);
    EigenVectorArrayMap<T> vX1_diff(X_diff + 1 * D, D);
    EigenVectorArrayMap<T> vX2_diff(X_diff + 2 * D, D);
    EigenVectorArrayMap<T> vX3_diff(X_diff + 3 * D, D);
    if (!valid) {
      if (drop_states) {
        vH_prev_diff.setConstant((T)(0.0));
        vC_prev_diff.setConstant((T)(0.0));
      } else {
        vH_prev_diff = vH_diff;
        vC_prev_diff = vC_diff;
      }
      vX_diff.setConstant((T)(0.0));
    } else {
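      // Recompute the gate activations from the saved pre-activations and
      // backpropagate through c_t = f * c_{t-1} + i * g and
      // h_t = o * tanh(c_t).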
      const Eigen::Array<T, Eigen::Dynamic, 1> i = C2_EIGEN_SIGMOID_INLINE(vX0);
      const Eigen::Array<T, Eigen::Dynamic, 1> f =
          C2_EIGEN_SIGMOID_INLINE(vX1 + forget_bias);
      const Eigen::Array<T, Eigen::Dynamic, 1> o = C2_EIGEN_SIGMOID_INLINE(vX2);
      const Eigen::Array<T, Eigen::Dynamic, 1> g =
          C2_EIGEN_HOST_TANH_INLINE(vX3);
      const Eigen::Array<T, Eigen::Dynamic, 1> host_tanh_c =
          C2_EIGEN_HOST_TANH_INLINE(vC);
      const Eigen::Array<T, Eigen::Dynamic, 1> c_term_diff =
          vC_diff + vH_diff * o * (1 - host_tanh_c * host_tanh_c);
      vC_prev_diff = c_term_diff * f;
      vH_prev_diff = 0; // not used in 'valid' case
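      // Gradients w.r.t. the four gate pre-activations, in the same
      // [input | forget | output | candidate] order as the forward pass.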
      vX0_diff = c_term_diff * g * i * (1 - i);
      vX1_diff = c_term_diff * vC_prev * f * (1 - f);
      vX2_diff = vH_diff * host_tanh_c * o * (1 - o);
      vX3_diff = c_term_diff * i * (1 - g * g);
    }
    C_prev += D;
    X += 4 * D;
    C += D;
    H += D;
    C_diff += D;
    H_diff += D;
    X_diff += 4 * D;
    H_prev_diff += D;
    C_prev_diff += D;
  }
}
#undef C2_EIGEN_SIGMOID_INLINE
#undef C2_EIGEN_HOST_TANH_INLINE
} // namespace detail
namespace {
REGISTER_CPU_OPERATOR(LSTMUnit, LSTMUnitOp<float, CPUContext>);
OPERATOR_SCHEMA(LSTMUnit)
    .NumInputs(5)
    .NumOutputs(2)
    .SetDoc(R"DOC(
LSTMUnit computes the activations of a standard LSTM (without peephole
connections), in a sequence-length aware fashion.

Concretely, given the (fused) inputs X (TxNxD), the previous cell
state (NxD), and the sequence lengths (N), computes the LSTM
activations, avoiding computation if the input is invalid (as in, when
t >= seqLengths[n]).
)DOC")
    .Arg(
        "forget_bias",
        "Bias term added to the forget gate before the sigmoid is applied");
REGISTER_CPU_OPERATOR(LSTMUnitGradient, LSTMUnitGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(LSTMUnitGradient).NumInputs(9).NumOutputs(3);
class GetLSTMUnitGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
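    // Feed LSTMUnitGradient the forward op's five inputs, its two outputs,
    // and the two output gradients (nine blobs, matching NumInputs(9) above);
    // it returns gradients for the first three forward inputs, which per
    // lstm_unit_op.h should be the previous hidden state, the previous cell
    // state, and the fused gate input X.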
    return SingleGradientDef(
        "LSTMUnitGradient",
        "",
        vector<string>{I(0), I(1), I(2), I(3), I(4), O(0), O(1), GO(0), GO(1)},
        vector<string>{GI(0), GI(1), GI(2)});
  }
};
REGISTER_GRADIENT(LSTMUnit, GetLSTMUnitGradient);
} // namespace
} // namespace caffe2