blob: d9cd1c0d501e383a3e76f3fef1fa576391a3c210 [file] [log] [blame]
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THTensorCopy.c"
#else
int THTensor_(copyTransposeValid)(THTensor *tensor, THTensor *src) {
const int MIN_SZ = 60 * 60;
return THTensor_(isContiguous)(tensor) &&
THTensor_(nDimension)(src) == 2 &&
THTensor_(stride)(src, 0) == 1 &&
THTensor_(stride)(src, 1) == THTensor_(size)(src, 0) &&
THTensor_(nElement)(tensor) >= MIN_SZ;
}
// special case copy where tensor is contiguous and src is a transposed matrix
// This can be generalized to most copies, but it's tricker
void THTensor_(copyTranspose)(THTensor *tensor, THTensor *src) {
#define MIN(x, y) (((x) < (y)) ? (x) : (y))
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
#ifdef TH_REAL_IS_BYTE
const int BLOCK_SZ = 120;
#else
const int BLOCK_SZ = 60;
#endif
THTensor *buf = THTensor_(newWithSize2d)(BLOCK_SZ, BLOCK_SZ);
real *sp = THTensor_(data)(src);
real *rp = THTensor_(data)(tensor);
real *bp = THTensor_(data)(buf);
long NR = THTensor_(size)(src, 0);
long NC = THTensor_(size)(src, 1);
for (long R = 0; R < NR; R += BLOCK_SZ) {
for (long C = 0; C < NC; C += BLOCK_SZ) {
real *spo = sp + R + C * NR;
real *rpo = rp + C + R * NC;
int nr = MIN(NR - R, BLOCK_SZ);
int nc = MIN(NC - C, BLOCK_SZ);
// 1. copy columns from src to buf
for (int c = 0; c < nc; c++) {
memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(real));
}
// 2. transpose buf in place
int rc_max = MAX(nr, nc);
int rc_min = MIN(nr, nc);
for (int r = 0; r < rc_max; r++) {
int end = MIN(r, rc_min);
for (int c = 0; c < end; c++) {
real tmp = bp[r + BLOCK_SZ * c];
bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
bp[r * BLOCK_SZ + c] = tmp;
}
}
// 3. copy rows from buf to dst
for (int r = 0; r < nr; r++) {
memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(real));
}
}
}
THTensor_(free)(buf);
#undef MIN
#undef MAX
}
void THTensor_(copy)(THTensor *tensor, THTensor *src)
{
if (tensor == src) return;
if (THTensor_(isContiguous)(tensor) && THTensor_(isContiguous)(src) && THTensor_(nElement)(tensor) == THTensor_(nElement)(src)) {
real *sp = THTensor_(data)(src);
real *rp = THTensor_(data)(tensor);
ptrdiff_t sz = THTensor_(nElement)(tensor);
#ifndef TH_REAL_IS_HALF
THVector_(copy)(rp, sp, sz);
#else
memcpy(rp, sp, sz * sizeof(real));
#endif
#ifndef TH_REAL_IS_HALF
} else if (THTensor_(copyTransposeValid)(tensor, src)) {
THTensor_(copyTranspose)(tensor, src);
#endif
} else {
TH_TENSOR_APPLY2(real, tensor, real, src, *tensor_data = *src_data;)
}
}
#define IMPLEMENT_THTensor_COPY(TYPENAMESRC, TYPE_SRC) \
void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \
{ \
TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)(*src_data);) \
}
#define IMPLEMENT_THTensor_COPY_TO_HALF(TYPENAMESRC, TYPE_SRC) \
void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \
{ \
TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = TH_float2half((float)*src_data);) \
}
#define IMPLEMENT_THTensor_COPY_FROM_HALF(TYPENAMESRC, TYPE_SRC) \
void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \
{ \
TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = (real)TH_half2float(*src_data);) \
}
#define IMPLEMENT_THTensor_COPY_TO_FROM_HALF(TYPENAMESRC, TYPE_SRC) \
void THTensor_(copy##TYPENAMESRC)(THTensor *tensor, TH##TYPENAMESRC##Tensor *src) \
{ \
TH_TENSOR_APPLY2(real, tensor, TYPE_SRC, src, *tensor_data = *src_data;) \
}
#ifndef TH_REAL_IS_HALF
IMPLEMENT_THTensor_COPY(Byte, unsigned char)
IMPLEMENT_THTensor_COPY(Char, char)
IMPLEMENT_THTensor_COPY(Short, short)
IMPLEMENT_THTensor_COPY(Int, int)
IMPLEMENT_THTensor_COPY(Long, long)
IMPLEMENT_THTensor_COPY(Float, float)
IMPLEMENT_THTensor_COPY(Double, double)
IMPLEMENT_THTensor_COPY_FROM_HALF(Half, THHalf)
#else
/* only allow pass-through for Half */
IMPLEMENT_THTensor_COPY_TO_FROM_HALF(Half, THHalf)
IMPLEMENT_THTensor_COPY_TO_HALF(Byte, unsigned char)
IMPLEMENT_THTensor_COPY_TO_HALF(Char, char)
IMPLEMENT_THTensor_COPY_TO_HALF(Short, short)
IMPLEMENT_THTensor_COPY_TO_HALF(Int, int)
IMPLEMENT_THTensor_COPY_TO_HALF(Long, long)
IMPLEMENT_THTensor_COPY_TO_HALF(Float, float)
IMPLEMENT_THTensor_COPY_TO_HALF(Double, double)
#endif /* REAL_IS_HALF */
#endif