blob: 8885d7c4bcb2f95c39ba8b6652a3b67eeefe28f0 [file] [log] [blame]
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_RNN_LSTM_OPS_H_
#define TENSORFLOW_CORE_KERNELS_RNN_LSTM_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/eigen_activations.h"
#include "tensorflow/core/kernels/rnn/blas_gemm.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class OpKernelContext;
namespace functor {
// Sets every element of the aligned flat view `t` to T(0); the assignment is
// evaluated on device `d`. See TensorUnalignedZero for the unaligned variant.
template <typename Device, typename T>
struct TensorZero {
  void operator()(const Device& d, typename TTypes<T>::Flat t) {
    const T zero(0);
    t.device(d) = t.constant(zero);
  }
};
// Same as TensorZero, but for an `UnalignedFlat` view: writes T(0) into every
// element of `t`, evaluated on device `d`.
template <typename Device, typename T>
struct TensorUnalignedZero {
  void operator()(const Device& d, typename TTypes<T>::UnalignedFlat t) {
    const T zero(0);
    t.device(d) = t.constant(zero);
  }
};
// Element-wise copy of `src` into `dst`, evaluated on device `d`.
// Both views are aligned flat tensors. No size check is done here; the caller
// presumably passes views of equal length.
template <typename Device, typename T>
struct TensorCopy {
  void operator()(const Device& d,
                  typename TTypes<T>::ConstFlat src,
                  typename TTypes<T>::Flat dst) {
    dst.device(d) = src;
  }
};
// Element-wise copy from an *unaligned* flat source view into an aligned flat
// destination view, evaluated on device `d`.
// No size check is done here; equal lengths are presumably the caller's job.
template <typename Device, typename T>
struct TensorCopyUnaligned {
  void operator()(const Device& d,
                  typename TTypes<T>::UnalignedConstFlat src,
                  typename TTypes<T>::Flat dst) {
    dst.device(d) = src;
  }
};
// Element-wise copy from an aligned flat source view into an *unaligned* flat
// destination view, evaluated on device `d`.
// No size check is done here; equal lengths are presumably the caller's job.
template <typename Device, typename T>
struct TensorCopyToUnaligned {
  void operator()(const Device& d,
                  typename TTypes<T>::ConstFlat src,
                  typename TTypes<T>::UnalignedFlat dst) {
    dst.device(d) = src;
  }
};
// Element-wise sum: c[k] = a[k] + b[k] for every k, evaluated on device `d`.
// All three arguments are aligned flat views of (presumably) equal length.
template <typename Device, typename T>
struct TensorAdd {
  void operator()(const Device& d,
                  typename TTypes<T>::ConstFlat a,
                  typename TTypes<T>::ConstFlat b,
                  typename TTypes<T>::Flat c) {
    c.device(d) = a + b;
  }
};
template <typename Device, typename T>
struct TensorZeroPadding {
void operator()(const Device& d, const int64 time_idx,
typename TTypes<int64>::ConstVec seq_len,
typename TTypes<T>::Vec mask, typename TTypes<T>::Matrix m) {
// mask is shape [batch_size].
mask.device(d) = seq_len.constant(time_idx) < seq_len;
// m_shape is [batch_size, 1].
Eigen::array<Eigen::DenseIndex, 2> m_shape({m.dimensions()[0], 1});
// broadcast_shape is [1, units].
Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({1, m.dimensions()[1]});
// m is shape [batch_size, units].
m.device(d) = m * mask.reshape(m_shape).broadcast(broadcast_shape);
}
};
// Common base for the LSTM block-cell functors in this header.
//
// Stores the problem dimensions and returns {row, column} offsets and extents
// for Eigen `slice()` calls on two packed matrices:
//   * the gate matrix ("icfo"): gate offsets are the column multiples
//       i -> 0, c -> cell_size, f -> 2 * cell_size, o -> 3 * cell_size,
//     and each gate block is sliced with cell_extents();
//   * the concatenated input matrix ("xh"): columns [0, input_size) hold x,
//     and the next cell_size columns hold h.
struct LSTMBlockCell {
  LSTMBlockCell(const int batch_size, const int input_size, const int cell_size)
      : batch_size_(batch_size), input_size_(input_size), cell_size_(cell_size) {}

  int batch_size() const { return batch_size_; }
  int input_size() const { return input_size_; }
  int cell_size() const { return cell_size_; }

  // ---- Gate matrix (icfo) layout ----
  Eigen::array<Eigen::DenseIndex, 2> icfo_i_offsets() const { return {0, 0}; }
  Eigen::array<Eigen::DenseIndex, 2> icfo_c_offsets() const {
    return {0, cell_size_};
  }
  Eigen::array<Eigen::DenseIndex, 2> icfo_f_offsets() const {
    return {0, 2 * cell_size_};
  }
  Eigen::array<Eigen::DenseIndex, 2> icfo_o_offsets() const {
    return {0, 3 * cell_size_};
  }
  // [batch_size, cell_size]: the extent of one gate block.
  Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
    return {batch_size_, cell_size_};
  }

  // ---- Concatenated [x, h] matrix (xh) layout ----
  Eigen::array<Eigen::DenseIndex, 2> xh_x_offsets() const { return {0, 0}; }
  Eigen::array<Eigen::DenseIndex, 2> xh_x_extents() const {
    return {batch_size_, input_size_};
  }
  Eigen::array<Eigen::DenseIndex, 2> xh_h_offsets() const {
    return {0, input_size_};
  }
  Eigen::array<Eigen::DenseIndex, 2> xh_h_extents() const {
    return {batch_size_, cell_size_};
  }

 protected:
  const int batch_size_;
  const int input_size_;
  const int cell_size_;
};
// Forward pass of one LSTM block-cell time step.
//
// Only the interface is declared here. The CPUDevice implementation lives in
// lstm_ops.cc and the GPUDevice implementation in lstm_ops_gpu.cu.cc.
// USE_CUBLAS presumably selects the TensorBlasGemm specialization used by
// those implementations (see blas_gemm.h).
template <typename Device, typename T, bool USE_CUBLAS>
struct LSTMBlockCellFprop : public LSTMBlockCell {
  LSTMBlockCellFprop(const int batch_size, const int input_size,
                     const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  void operator()(
      // Execution context and device.
      OpKernelContext* ctx,
      const Device& d,
      // Scalar options; their exact effect is defined by the implementations.
      const float forget_bias,
      const float cell_clip,
      bool use_peephole,
      // Step inputs.
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev,
      // Parameters: weight matrix, peephole weights, bias.
      typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci,
      typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco,
      typename TTypes<T>::ConstVec b,
      // Mutable buffers (presumably all written by the implementation):
      // xh scratch, per-gate activations, packed gates, and the new h.
      typename TTypes<T>::Matrix xh,
      typename TTypes<T>::Matrix i,
      typename TTypes<T>::Matrix cs,
      typename TTypes<T>::Matrix f,
      typename TTypes<T>::Matrix o,
      typename TTypes<T>::Matrix ci,
      typename TTypes<T>::Matrix co,
      typename TTypes<T>::Matrix icfo,
      typename TTypes<T>::Matrix h);
};
// Backward pass of one LSTM block-cell time step (gate-level gradients only).
//
// Only the interface is declared here. The CPUDevice implementation lives in
// lstm_ops.cc and the GPUDevice implementation in lstm_ops_gpu.cu.cc.
// i, cs, f, o, ci, co carry the same names as LSTMBlockCellFprop's outputs.
template <typename Device, typename T, bool USE_CUBLAS>
struct LSTMBlockCellBprop : public LSTMBlockCell {
  LSTMBlockCellBprop(const int batch_size, const int input_size,
                     const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}

  void operator()(
      // Execution context, device and peephole switch.
      OpKernelContext* ctx,
      const Device& d,
      bool use_peephole,
      // Forward-step inputs.
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev,
      // Parameters.
      typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci,
      typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco,
      typename TTypes<T>::ConstVec b,
      // Forward activations of this step.
      typename TTypes<T>::ConstMatrix i,
      typename TTypes<T>::ConstMatrix cs,
      typename TTypes<T>::ConstMatrix f,
      typename TTypes<T>::ConstMatrix o,
      typename TTypes<T>::ConstMatrix ci,
      typename TTypes<T>::ConstMatrix co,
      // Incoming gradients.
      typename TTypes<T>::ConstMatrix cs_grad,
      typename TTypes<T>::ConstMatrix h_grad,
      // Mutable gradient / scratch buffers.
      typename TTypes<T>::Matrix do_,
      typename TTypes<T>::Matrix dcs,
      typename TTypes<T>::Matrix dci,
      typename TTypes<T>::Matrix df,
      typename TTypes<T>::Matrix di,
      typename TTypes<T>::Matrix dicfo,
      typename TTypes<T>::Matrix cs_prev_grad,
      typename TTypes<T>::Vec wci_grad,
      typename TTypes<T>::Vec wcf_grad,
      typename TTypes<T>::Vec wco_grad);
};
// Backward pass of one LSTM block-cell time step, defined inline for every
// Device using Eigen expressions plus TensorBlasGemm for the two matmuls.
//
// Compared with LSTMBlockCellBprop's declared interface, this functor also
// produces x_grad and h_prev_grad and updates w_grad and b_grad.
//
// Argument roles (shapes follow from the slices/broadcasts in the body):
//   x [batch, input], cs_prev / h_prev [batch, cell]   forward-step inputs.
//   w, wci, wcf, wco, b    parameters; peephole vectors have length cell.
//                          `b` is not read by this body.
//   i, cs, f, o, ci, co    forward activations, [batch, cell]. The derivative
//                          formulas treat i, f, o as sigmoid outputs and
//                          ci, co as tanh outputs.
//   cs_grad, h_grad        incoming gradients w.r.t. cs and h.
//   xh, do_, dcs, dci, df, di, dicfo, xh_grad
//                          scratch buffers, overwritten here.
//   cs_prev_grad, x_grad, h_prev_grad
//                          outputs, overwritten here.
//   b_grad, wci_grad, wcf_grad, wco_grad
//                          updated with `+=` (accumulated across calls).
//   w_grad                 GEMM with beta = 1, i.e. presumably accumulated too.
// NOTE(review): accumulated gradients presumably must be zeroed by the caller
// before the first time step; confirm against the kernel.
template <typename Device, typename T, bool USE_CUBLAS>
struct BlockLSTMBprop : public LSTMBlockCell {
  BlockLSTMBprop(const int batch_size, const int input_size,
                 const int cell_size)
      : LSTMBlockCell(batch_size, input_size, cell_size) {}
  void operator()(
      OpKernelContext* ctx, const Device& d, bool use_peephole,
      typename TTypes<T>::ConstMatrix x,
      typename TTypes<T>::ConstMatrix cs_prev,
      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
      typename TTypes<T>::Matrix xh, typename TTypes<T>::ConstMatrix i,
      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,
      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,
      typename TTypes<T>::ConstMatrix co,
      typename TTypes<T>::ConstMatrix cs_grad,
      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
      typename TTypes<T>::Matrix h_prev_grad,
      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,
      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,
      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,
      typename TTypes<T>::Vec b_grad) {
    // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
    // With o already a sigmoid output, sigm' is written as o * (1 - o).
    do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;
    // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
    // tanh'(cs) is written as 1 - co^2, i.e. the formula assumes co == tanh(cs).
    // `cs_grad` carries the gradient flowing in from the next time step.
    dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;
    // Shapes for broadcasting a length-cell peephole vector over the batch:
    // view it as [1, cell] and tile it batch_size times along rows.
    Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_});
    Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1});
    if (use_peephole) {
      // The output-gate peephole reads cs[t], so it contributes do .* wco
      // to dcs. This in-place update must precede dci/df/di, which read dcs.
      dcs.device(d) =
          dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
    }
    // dci[t] = tanh'(ci[t]) dcs[t] i[t]
    dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;
    // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
    df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;
    // di[t] = sigm'(i[t]) dcs[t] ci[t]
    di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
    // Pack the four gate gradients into dicfo using the i|c|f|o column layout.
    dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di;
    dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci;
    dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df;
    dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_;
    // Gradient w.r.t. the previous cell state: through the forget gate, plus
    // (with peepholes) through the input/forget peephole connections.
    cs_prev_grad.device(d) = dcs * f;
    if (use_peephole) {
      cs_prev_grad.device(d) =
          cs_prev_grad +
          di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
          df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
    }
    // xh_grad: GEMM with (transpose_a=false, transpose_b=true, alpha=1,
    // beta=0), presumably xh_grad = dicfo * w^T (overwriting xh_grad); the
    // exact argument semantics are defined in blas_gemm.h.
    typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
                                                dicfo.dimensions());
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, false, true, 1.f, const_dicfo, w, 0.f, xh_grad);
    // xh: rebuild the concatenated [x, h_prev] input for the w_grad GEMM.
    xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
    xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
    typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
    // x_grad and h_prev_grad: split xh_grad back into its x and h columns.
    x_grad.device(d) = xh_grad.slice(xh_x_offsets(), xh_x_extents());
    h_prev_grad.device(d) = xh_grad.slice(xh_h_offsets(), xh_h_extents());
    // w_grad: GEMM with (transpose_a=true, transpose_b=false, alpha=1,
    // beta=1), presumably w_grad += xh^T * dicfo (accumulating over steps).
    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
        ctx, d, true, false, 1.f, const_xh, const_dicfo, 1.f, w_grad);
    // b_grad: accumulate the per-column sum of dicfo over dimension 0 (batch).
    b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));
    if (use_peephole) {
      // Peephole-weight gradients, reduced over the batch and accumulated.
      // The i/f peepholes read cs_prev; the o peephole reads cs.
      wci_grad.device(d) += (di * cs_prev).sum(Eigen::array<int, 1>({0}));
      wcf_grad.device(d) += (df * cs_prev).sum(Eigen::array<int, 1>({0}));
      wco_grad.device(d) += (do_ * cs).sum(Eigen::array<int, 1>({0}));
    }
  }
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_RNN_LSTM_OPS_H_