blob: dbf4cf0a418947d53983881c0697cb17afb53f86 [file] [log] [blame]
#ifndef THC_TENSORMATH_COMPARET_CUH
#define THC_TENSORMATH_COMPARET_CUH
#include "THCTensorMath.h"
#include "THCGeneral.h"
#include "THCTensorCopy.h"
#include "THCApply.cuh"
#include "THCNumerics.cuh"
#include "THCReduce.cuh"
// Less-than comparison functor for elementwise apply.
// Stores THCNumerics<T>::lt(*a, *b), converted from bool to TOut via
// ScalarConvert, into *out.
template <typename T, typename TOut>
struct TensorLTOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    const bool isLess = THCNumerics<T>::lt(*a, *b);
    *out = ScalarConvert<bool, TOut>::to(isLess);
  }
};
// Greater-than comparison functor for elementwise apply.
// Stores THCNumerics<T>::gt(*a, *b), converted from bool to TOut via
// ScalarConvert, into *out.
template <typename T, typename TOut>
struct TensorGTOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    const bool isGreater = THCNumerics<T>::gt(*a, *b);
    *out = ScalarConvert<bool, TOut>::to(isGreater);
  }
};
// Less-than-or-equal comparison functor for elementwise apply.
// Stores THCNumerics<T>::le(*a, *b), converted from bool to TOut via
// ScalarConvert, into *out.
template <typename T, typename TOut>
struct TensorLEOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    const bool isLessOrEqual = THCNumerics<T>::le(*a, *b);
    *out = ScalarConvert<bool, TOut>::to(isLessOrEqual);
  }
};
// Greater-than-or-equal comparison functor for elementwise apply.
// Stores THCNumerics<T>::ge(*a, *b), converted from bool to TOut via
// ScalarConvert, into *out.
template <typename T, typename TOut>
struct TensorGEOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    const bool isGreaterOrEqual = THCNumerics<T>::ge(*a, *b);
    *out = ScalarConvert<bool, TOut>::to(isGreaterOrEqual);
  }
};
// Equality comparison functor for elementwise apply.
// Stores THCNumerics<T>::eq(*a, *b), converted from bool to TOut via
// ScalarConvert, into *out.
template <typename T, typename TOut>
struct TensorEQOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    const bool isEqual = THCNumerics<T>::eq(*a, *b);
    *out = ScalarConvert<bool, TOut>::to(isEqual);
  }
};
// Inequality comparison functor for elementwise apply.
// Stores THCNumerics<T>::ne(*a, *b), converted from bool to TOut via
// ScalarConvert, into *out.
template <typename T, typename TOut>
struct TensorNEOp {
  __device__ inline void operator()(TOut* out, T* a, T* b) {
    const bool isNotEqual = THCNumerics<T>::ne(*a, *b);
    *out = ScalarConvert<bool, TOut>::to(isNotEqual);
  }
};
// Elementwise tensor-tensor comparison: applies `op` to corresponding
// elements of self_, src1 and src2 via THC_pointwiseApply3.
//
// Parameters:
//   state - THC state handle, forwarded to every TensorUtils / apply call.
//   self_ - output tensor; resized to the sizes of src1.
//   src1  - first input; its sizes determine the output's sizes.
//   src2  - second input; must have the same NUMBER of elements as src1
//           (only element counts are compared, not shapes).
//   op    - functor invoked as op(TOut* out, T* a, T* b), e.g. the
//           Tensor{LT,GT,LE,GE,EQ,NE}Op functors defined above.
//
// Errors:
//   - argument 3 error "sizes do not match" if the element counts differ;
//   - argument 2 error (CUTORCH_DIM_WARNING) if THC_pointwiseApply3 returns
//     false;
//   - THCudaCheck on cudaGetLastError() after the apply.
//
// The element-count check is performed BEFORE self_ is resized. When
// TensorTypeOut is the same type as TensorType, self_ may be the very same
// tensor as src2; resizing first would reshape src2 to src1's sizes and let
// mismatched inputs pass the check. Checking first also leaves self_
// untouched when the call fails.
template<typename TensorType, typename TensorTypeOut, typename Op>
void THC_logicalTensor(THCState *state,
                       TensorTypeOut *self_,
                       TensorType *src1,
                       TensorType *src2,
                       Op op) {
  // Validate the inputs before mutating the output.
  THArgCheck(TensorUtils<TensorType>::getNumElements(state, src1) ==
             TensorUtils<TensorType>::getNumElements(state, src2), 3,
             "sizes do not match");

  // Give the output the same sizes as src1.
  THLongStorage* st = TensorUtils<TensorType>::newSizeOf(state, src1);
  TensorUtils<TensorTypeOut>::resize(state, self_, st, NULL);
  THLongStorage_free(st);

  if (!THC_pointwiseApply3(state, self_, src1, src2, op)) {
    THArgCheck(false, 2, CUTORCH_DIM_WARNING);
  }

  // Surface any error left pending by the apply (e.g. a failed launch).
  THCudaCheck(cudaGetLastError());
}
#endif // THC_TENSORMATH_COMPARET_CUH