blob: ed438054ae630a7aa2b8d29740da37b3eeb6a941 [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_ARG_OPS_EIGEN_H_
#define CAFFE2_OPERATORS_ARG_OPS_EIGEN_H_
#include "caffe2/core/context.h"
#include "caffe2/core/types.h"
#include "Eigen/Core"
#if EIGEN_VERSION_AT_LEAST(3, 3, 0)
#include "unsupported/Eigen/CXX11/Tensor"
namespace caffe2 {
namespace arg_ops_eigen {
template <typename T>
using EigenTensorMap1D = Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>>;
template <typename T>
using EigenTensorMap2D = Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>>;
template <typename T>
using EigenTensorMap3D = Eigen::TensorMap<Eigen::Tensor<T, 3, Eigen::RowMajor>>;
template <class Device, typename T>
void ComputeArgMaxEigen(
const Device& device,
const T* X,
const TIndex prev_size,
const TIndex next_size,
const TIndex n,
TIndex* Y) {
if (next_size == 1) {
EigenTensorMap1D<TIndex>(Y, prev_size).device(device) =
EigenTensorMap2D<T>(const_cast<T*>(X), prev_size, n)
.argmax(1)
.template cast<TIndex>();
} else if (prev_size == 1) {
EigenTensorMap1D<TIndex>(Y, next_size).device(device) =
EigenTensorMap2D<T>(const_cast<T*>(X), n, next_size)
.argmax(0)
.template cast<TIndex>();
} else {
EigenTensorMap2D<TIndex>(Y, prev_size, next_size).device(device) =
EigenTensorMap3D<T>(const_cast<T*>(X), prev_size, n, next_size)
.argmax(1)
.template cast<TIndex>();
}
}
template <class Device, typename T>
void ComputeArgMinEigen(
const Device& device,
const T* X,
const TIndex prev_size,
const TIndex next_size,
const TIndex n,
TIndex* Y) {
if (next_size == 1) {
EigenTensorMap1D<TIndex>(Y, prev_size).device(device) =
EigenTensorMap2D<T>(const_cast<T*>(X), prev_size, n)
.argmin(1)
.template cast<TIndex>();
} else if (prev_size == 1) {
EigenTensorMap1D<TIndex>(Y, next_size).device(device) =
EigenTensorMap2D<T>(const_cast<T*>(X), n, next_size)
.argmin(0)
.template cast<TIndex>();
} else {
EigenTensorMap2D<TIndex>(Y, prev_size, next_size).device(device) =
EigenTensorMap3D<T>(const_cast<T*>(X), prev_size, n, next_size)
.argmin(1)
.template cast<TIndex>();
}
}
} // namespace arg_ops_eigen
} // namespace caffe2
#endif // EIGEN_VERSION_AT_LEAST(3, 3, 0)
#endif // CAFFE2_OPERATORS_ARG_OPS_EIGEN_H_