blob: 42aab34cc97b78517924e61a8d4286e78672cfc5 [file] [log] [blame]
namespace detail {

// Compile-time dispatch helpers used by toDeviceTensorCast.
//
// Each helper family takes a trailing bool template argument that selects
// between two partial specializations:
//   - `true`  defines `make`, guarded by a thc_static_assert on the
//             relationship between Dim and NewDim;
//   - `false` only inherits the `make` declaration from the *Root struct and
//             therefore contains no static assertion.
// This keeps the static assertion out of instantiations in which the
// requested cast direction does not apply.

// Common base of the upcast helpers. It declares the static `make` entry
// point so that every UpcastTHC variant exposes the same signature. No
// definition of this declaration appears in this section.
template <typename T, int Dim, typename IndexT,
          template <typename> class PtrTraits,
          int NewDim, bool Enabled>
struct UpcastTHCRoot {
  static THCDeviceTensor<T, NewDim, IndexT, PtrTraits>
  make(THCState* state, THCudaTensor* t);
};

// Primary template. Both values of the bool argument are covered by the
// partial specializations that follow.
template <typename T, int Dim, typename IndexT,
          template <typename> class PtrTraits,
          int NewDim, bool Enabled>
struct UpcastTHC
    : UpcastTHCRoot<T, Dim, IndexT, PtrTraits, NewDim, Enabled> {};

// Flag is false: nothing beyond the inherited declaration. Per the original
// note, this variant exists for dispatch purposes only and is not meant to be
// invoked.
template <typename T, int Dim, typename IndexT,
          template <typename> class PtrTraits,
          int NewDim>
struct UpcastTHC<T, Dim, IndexT, PtrTraits, NewDim, false>
    : UpcastTHCRoot<T, Dim, IndexT, PtrTraits, NewDim, false> {};

// Flag is true: the real upcast. Builds the Dim-dimensional device tensor for
// `t` and converts it to NewDim dimensions through upcastOuter.
template <typename T, int Dim, typename IndexT,
          template <typename> class PtrTraits,
          int NewDim>
struct UpcastTHC<T, Dim, IndexT, PtrTraits, NewDim, true>
    : UpcastTHCRoot<T, Dim, IndexT, PtrTraits, NewDim, true> {
  static THCDeviceTensor<T, NewDim, IndexT, PtrTraits>
  make(THCState* state, THCudaTensor* t) {
    // An upcast must strictly increase the number of dimensions.
    thc_static_assert(NewDim > Dim);
    return toDeviceTensor<T, Dim, IndexT, PtrTraits>(state, t)
        .template upcastOuter<NewDim>();
  }
};

// Common base of the downcast helpers; mirrors UpcastTHCRoot. It declares the
// static `make` entry point, and no definition of this declaration appears in
// this section.
template <typename T, int Dim, typename IndexT,
          template <typename> class PtrTraits,
          int NewDim, bool Enabled>
struct DowncastTHCRoot {
  static THCDeviceTensor<T, NewDim, IndexT, PtrTraits>
  make(THCState* state, THCudaTensor* t);
};

// Primary template. Both values of the bool argument are covered by the
// partial specializations that follow.
template <typename T, int Dim, typename IndexT,
          template <typename> class PtrTraits,
          int NewDim, bool Enabled>
struct DowncastTHC
    : DowncastTHCRoot<T, Dim, IndexT, PtrTraits, NewDim, Enabled> {};

// Flag is false: nothing beyond the inherited declaration. Per the original
// note, this variant exists for dispatch purposes only and is not meant to be
// invoked.
template <typename T, int Dim, typename IndexT,
          template <typename> class PtrTraits,
          int NewDim>
struct DowncastTHC<T, Dim, IndexT, PtrTraits, NewDim, false>
    : DowncastTHCRoot<T, Dim, IndexT, PtrTraits, NewDim, false> {};

// Flag is true: the real downcast. Builds the Dim-dimensional device tensor
// for `t` and converts it to NewDim dimensions through downcastOuter.
template <typename T, int Dim, typename IndexT,
          template <typename> class PtrTraits,
          int NewDim>
struct DowncastTHC<T, Dim, IndexT, PtrTraits, NewDim, true>
    : DowncastTHCRoot<T, Dim, IndexT, PtrTraits, NewDim, true> {
  static THCDeviceTensor<T, NewDim, IndexT, PtrTraits>
  make(THCState* state, THCudaTensor* t) {
    // A downcast must strictly decrease the number of dimensions.
    thc_static_assert(NewDim < Dim);
    return toDeviceTensor<T, Dim, IndexT, PtrTraits>(state, t)
        .template downcastOuter<NewDim>();
  }
};

}  // namespace detail
// Expands to a single `case i:` of the switch inside toDeviceTensorCast, where
// `i` is the dimension count the tensor has at run time. NewDim and `i` are
// both compile-time constants, so only one arm can be taken per expansion:
//   NewDim == i : direct toDeviceTensor conversion, no cast needed
//   NewDim >  i : detail::UpcastTHC
//   NewDim <  i : detail::DowncastTHC
// The bool handed to the detail helpers repeats the arm's own condition, so
// the specialization containing the static assertion is chosen only when that
// arm can actually be taken. Every arm returns, so no `break` is required.
#define SWITCH_UNROLL_CUDA_CAST_FACTORY(i)                                  \
  case i:                                                                   \
    if (NewDim == i) {                                                      \
      return toDeviceTensor<T, NewDim, IndexT, PtrTraits>(state, t);        \
    } else if (NewDim > i) {                                                \
      return detail::UpcastTHC<                                             \
          T, i, IndexT, PtrTraits, NewDim, (NewDim > i)>::make(state, t);   \
    } else {                                                                \
      return detail::DowncastTHC<                                           \
          T, i, IndexT, PtrTraits, NewDim, (NewDim < i)>::make(state, t);   \
    }

// Produces a THCDeviceTensor with NewDim dimensions from `t`, dispatching on
// the dimension count reported by THCudaTensor_nDimension at run time.
//
// Dimension counts 1 through 10 are handled. Any other count leaves the
// switch without returning and is reported through THError.
template <typename T, int NewDim,
          typename IndexT, template <typename U> class PtrTraits>
THCDeviceTensor<T, NewDim, IndexT, PtrTraits>
toDeviceTensorCast(THCState* state, THCudaTensor* t) {
  switch (THCudaTensor_nDimension(state, t)) {
    SWITCH_UNROLL_CUDA_CAST_FACTORY(1);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(2);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(3);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(4);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(5);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(6);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(7);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(8);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(9);
    SWITCH_UNROLL_CUDA_CAST_FACTORY(10);
    default:
      break;
  }

  // Unsupported dimension count.
  THError("THCDeviceTensor dimension size not supported");

  // Per the original author's note, control is not expected to reach this
  // return; it is present only to appease compiler warnings.
  // NOTE(review): this relies on NULL being implicitly convertible to the
  // THCDeviceTensor return type. The class definition is not visible here, so
  // confirm that this compiles when the template is instantiated.
  return NULL;
}

// The case-generating macro is private to the function above.
#undef SWITCH_UNROLL_CUDA_CAST_FACTORY