#include "THCTensorMath.h"
#include "THCGeneral.h"
#include "THCBlas.h"
#include "THCTensorCopy.h"
#include "THCTensorRandom.h"
#include "THCApply.cuh"
#include "THCReduce.cuh"
#include <thrust/device_ptr.h>
#include <thrust/transform_reduce.h>
#include <thrust/functional.h>
#include <thrust/inner_product.h>
#if CUDA_VERSION >= 7000
#include <thrust/system/cuda/execution_policy.h>
#endif
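
// Functor for THC_pointwiseApply*: raises each element to a fixed power.
// The two-argument form writes powf(*in, val) to *out; the one-argument form
// updates the element in place.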
struct TensorPowOp {
TensorPowOp(float v) : val(v) {}
__device__ __forceinline__ void operator()(float* out, float* in) {
*out = powf(*in, val);
}
__device__ __forceinline__ void operator()(float* v) {
*v = powf(*v, val);
}
const float val;
};
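
// self_ = src ^ value, element-wise. When self_ and src alias, the operation is
// applied in place; otherwise self_ is resized to match src first.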
void THCudaTensor_pow(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value)
{
THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
if (self_ == src) {
if (!THC_pointwiseApply1(state, self_, TensorPowOp(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCudaTensor_resizeAs(state, self_, src);
if (!THC_pointwiseApply2(state, self_, src, TensorPowOp(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}
THCudaCheck(cudaGetLastError());
}
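
// Functor for the scalar-base power: out = powf(val, in), i.e. value ^ element.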
struct TensorTPowOp {
TensorTPowOp(float v) : val(v) {}
__device__ __forceinline__ void operator()(float* out, float* in) {
*out = powf(val, *in);
}
__device__ __forceinline__ void operator()(float* v) {
*v = powf(val, *v);
}
const float val;
};
void THCudaTensor_tpow(THCState *state, THCudaTensor *self_, float value, THCudaTensor *src)
{
THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
if (self_ == src) {
if (!THC_pointwiseApply1(state, self_, TensorTPowOp(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCudaTensor_resizeAs(state, self_, src);
if (!THC_pointwiseApply2(state, self_, src, TensorTPowOp(value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}
THCudaCheck(cudaGetLastError());
}
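
// Element-wise atan2: self_[i] = atan2f(tx[i], ty[i]). tx and ty must hold the
// same number of elements; self_ is resized to match tx.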
struct TensorATan2Op {
__device__ __forceinline__ void operator()(float* out, float* a, float* b) {
*out = atan2f(*a, *b);
}
};
void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx, THCudaTensor *ty)
{
THAssert(THCudaTensor_checkGPU(state, 3, self_, tx, ty));
THArgCheck(THCudaTensor_nElement(state, tx) ==
THCudaTensor_nElement(state, ty), 3, "sizes do not match");
THCudaTensor_resizeAs(state, self_, tx);
if (!THC_pointwiseApply3(state, self_, tx, ty, TensorATan2Op())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
THCudaCheck(cudaGetLastError());
}
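
// Functor that clamps each element into [minValue, maxValue].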
struct TensorClampOp {
TensorClampOp(float min, float max) : minValue(min), maxValue(max) {}
__device__ __forceinline__ void operator()(float* out, float* in) {
*out = max(min(*in, maxValue), minValue);
}
__device__ __forceinline__ void operator()(float* v) {
*v = max(min(*v, maxValue), minValue);
}
const float minValue;
const float maxValue;
};
void THCudaTensor_clamp(THCState *state, THCudaTensor *self_, THCudaTensor *src, float min_value,
float max_value)
{
THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
if (self_ == src) {
if (!THC_pointwiseApply1(state, self_, TensorClampOp(min_value, max_value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCudaTensor_resizeAs(state, self_, src);
if (!THC_pointwiseApply2(state, self_, src, TensorClampOp(min_value, max_value))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}
THCudaCheck(cudaGetLastError());
}
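
// Sign functor: maps each element to -1, 0 or +1 via (x > 0) - (x < 0).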
struct TensorSignOp {
__device__ __forceinline__ void operator()(float* out, float* in) {
float orig = *in;
*out = (orig > 0) - (orig < 0);
}
__device__ __forceinline__ void operator()(float* v) {
float orig = *v;
*v = (orig > 0) - (orig < 0);
}
};
void THCudaTensor_sign(THCState *state, THCudaTensor *self_, THCudaTensor *src)
{
THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
if (self_ == src) {
if (!THC_pointwiseApply1(state, self_, TensorSignOp())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
} else {
THCudaTensor_resizeAs(state, self_, src);
if (!THC_pointwiseApply2(state, self_, src, TensorSignOp())) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
}
THCudaCheck(cudaGetLastError());
}
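
// Mean over all elements: sumall / nElement. Requires a non-empty tensor.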
float THCudaTensor_meanall(THCState *state, THCudaTensor *self)
{
THAssert(THCudaTensor_checkGPU(state, 1, self));
THArgCheck(self->nDimension > 0, 1, "empty Tensor");
return THCudaTensor_sumall(state, self)/THCudaTensor_nElement(state, self);
}
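
// Mean along `dim`: sum along the dimension, then divide by that dimension's size in src.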
void
THCudaTensor_mean(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim)
{
THAssert(THCudaTensor_checkGPU(state, 2, self, src));
THCudaTensor_sum(state, self, src, dim);
THCudaTensor_div(state, self, self, THCudaTensor_size(state, src, dim));
}
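
// Linear interpolation with weight w: result = a + w * (b - a). a and b must have
// the same number of elements; result is resized to match a.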
struct TensorLerpOp {
TensorLerpOp(float w) : w(w) {}
__device__ __forceinline__ void operator()(float *out, float *a, float *b) {
*out = *a + w * (*b - *a);
}
const float w;
};
void THCudaTensor_lerp(THCState *state, THCudaTensor *result, THCudaTensor *a, THCudaTensor *b, float w)
{
THAssert(THCudaTensor_checkGPU(state, 3, result, a, b));
THArgCheck(THCudaTensor_nElement(state, a) ==
THCudaTensor_nElement(state, b), 3, "sizes do not match");
THCudaTensor_resizeAs(state, result, a);
if (!THC_pointwiseApply3(state, result, a, b, TensorLerpOp(w))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
THCudaCheck(cudaGetLastError());
}
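
// Squared deviation from the mean, used by the Thrust reduction in THCudaTensor_varall
// below, which returns the unbiased variance over all elements (sum / (n - 1)).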
struct square_functor
{
const float mean;
square_functor(float mean_) : mean(mean_) {}
__host__ __device__ float operator()(const float& x) const
{
return (x-mean)*(x-mean);
}
};
float THCudaTensor_varall(THCState *state, THCudaTensor *self)
{
THAssert(THCudaTensor_checkGPU(state, 1, self));
self = THCudaTensor_newContiguous(state, self);
long size = THCudaTensor_nElement(state, self);
thrust::device_ptr<float> self_data(THCudaTensor_data(state, self));
float mean = THCudaTensor_meanall(state, self);
float result =
thrust::transform_reduce(
#if CUDA_VERSION >= 7000
thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
self_data, self_data+size, square_functor(mean),
(float)0, thrust::plus<float>());
result = result/(THCudaTensor_nElement(state, self)-1);
THCudaTensor_free(state, self);
return result;
}
float THCudaTensor_stdall(THCState *state, THCudaTensor *self)
{
THAssert(THCudaTensor_checkGPU(state, 1, self));
return sqrt(THCudaTensor_varall(state, self));
}
// Given the sum of values and the sum of squares, compute the variance or standard deviation.
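// If flag is set, the biased estimate is returned:  var = sum2/n - (sum/n)^2.
// Otherwise the unbiased estimate is used:          var = (sum2 - n*mean^2) / (n - 1), with mean = sum/n.
// Rounding can make the intermediate negative, so it is clamped to 0; apply_sqrt
// converts the variance into a standard deviation.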
template<bool flag, bool apply_sqrt>
__forceinline__ __device__ float THCudaTensor_computeVar(float sum, float sum2, unsigned row_size) {
if (flag) {
sum /= row_size;
sum2 /= row_size;
sum2 -= sum * sum;
sum2 = (sum2 < 0 ? 0 : sum2);
}
else {
sum /= row_size;
sum2 /= row_size - 1;
sum2 -= ((float)row_size) / ((float)(row_size - 1)) * sum * sum;
sum2 = (sum2 < 0 ? 0 : sum2);
}
if (apply_sqrt)
return sqrt(sum2);
else
return sum2;
}
/* Compute the variance (or standard deviation) along an outer dimension of a tensor.
*
* - num_orows is the size of the flattened outer dimensions;
* - num_irows is the size of the flattened inner dimensions;
* - row_size is the size of the dimension along which to compute the variance;
* - if flag is set, normalize by `row_size` instead of `row_size - 1`
* - if apply_sqrt is set, compute the standard deviation instead of variance
*
* The dimensions to the outside and inside of the specified dimension are considered as flattened.
 * Thread blocks with the same blockIdx.x process an "outer row" (i.e. an element of the flattened
 * outer dimensions, which contains several "inner rows").
* Each thread processes a single inner row at a time.
*/
template<bool flag, bool apply_sqrt>
__global__ void THCudaTensor_kernel_varOuterDim(float *tgt, float *src_, unsigned num_orows, unsigned num_irows, unsigned row_size)
{
for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
float *src = src_ + orow * row_size * num_irows + irow;
float sum = 0, sum2 = 0;
for (unsigned col = 0; col < row_size; ++col) {
float val = *src;
sum += val;
sum2 += val * val;
src += num_irows;
}
tgt[orow * num_irows + irow] = THCudaTensor_computeVar<flag, apply_sqrt>(sum, sum2, row_size);
}
}
}
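
// Host-side launcher: flattens the dimensions outside and inside `dimension`, then launches
// the kernel with grid.x covering outer rows and grid.y * blockDim.x covering inner rows
// (each grid dimension capped at 1024 blocks, up to 512 threads per block).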
template<bool apply_sqrt>
__host__ void THCudaTensor_varOuterDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, long dimension, int flag)
{
unsigned ndim = THCudaTensor_nDimension(state, src);
// Treat all outer dimensions (i.e. dim < dimension) as one.
unsigned num_orows = 1;
for (unsigned dim = 0; dim < dimension; dim++) {
num_orows *= THCudaTensor_size(state, src, dim);
}
unsigned row_size = THCudaTensor_size(state, src, dimension);
// Treat all inner dimensions (i.e. dim > dimension) as one.
unsigned num_irows = 1;
for (unsigned dim = dimension + 1; dim < ndim; dim++) {
num_irows *= THCudaTensor_size(state, src, dim);
}
dim3 threads(min(512, num_irows));
unsigned maxGridDim = 1024;
dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, THCCeilDiv(num_irows, threads.x)));
if (flag) {
THCudaTensor_kernel_varOuterDim<true, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_orows, num_irows, row_size);
} else {
THCudaTensor_kernel_varOuterDim<false, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_orows, num_irows, row_size);
}
cudaError errcode = cudaGetLastError();
if (errcode != cudaSuccess) {
THError(cudaGetErrorString(errcode));
}
}
/* Compute the variance (or standard deviation) of the innermost dimension of a tensor.
*
* - num_rows is the size of the flattened outer dimensions;
* - row_size is the size of the innermost dimension;
* - if flag is set, normalize by `row_size` instead of `row_size - 1`
* - if apply_sqrt is set, compute the standard deviation instead of variance
*
* The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
* considered as having 'num_rows' rows of size 'row_size'.
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
template<bool flag, bool apply_sqrt>
__global__ void THCudaTensor_kernel_varInnermostDim(float *tgt, float *src_, unsigned num_rows, unsigned row_size)
{
__shared__ float ssum[32][16];
__shared__ float ssum2[32][16];
for (unsigned block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) {
unsigned row = block_row + threadIdx.y;
float sum = 0, sum2 = 0;
if (row < num_rows) {
float *src = src_ + row * row_size;
// Sequential reduction within a thread.
for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) {
float val = src[col];
sum += val;
sum2 += val * val;
}
}
ssum[threadIdx.y][threadIdx.x] = sum;
ssum2[threadIdx.y][threadIdx.x] = sum2;
__syncthreads();
// Reduce intermediate values to single value.
for (unsigned s = 8; s > 1; s >>= 1) {
if (row < num_rows && threadIdx.x < s) {
ssum[threadIdx.y][threadIdx.x] += ssum[threadIdx.y][threadIdx.x + s];
ssum2[threadIdx.y][threadIdx.x] += ssum2[threadIdx.y][threadIdx.x + s];
}
__syncthreads();
}
if (row < num_rows && threadIdx.x == 0) {
sum = ssum[threadIdx.y][0] + ssum[threadIdx.y][1];
sum2 = ssum2[threadIdx.y][0] + ssum2[threadIdx.y][1];
tgt[row] = THCudaTensor_computeVar<flag, apply_sqrt>(sum, sum2, row_size);
}
__syncthreads();
}
}
template<bool apply_sqrt>
__host__ void THCudaTensor_varInnermostDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, int flag)
{
unsigned ndim = THCudaTensor_nDimension(state, src);
// Treat all outer dimensions as a single dimension.
unsigned num_rows = 1;
for (unsigned dim = 0; dim < ndim - 1; dim++) {
num_rows *= THCudaTensor_size(state, src, dim);
}
unsigned row_size = THCudaTensor_size(state, src, ndim - 1);
// From limited testing, 16x32 seemed a good compromise for handling both long and short dimensions.
dim3 threads(16, 32);
dim3 grid(min(1024, THCCeilDiv(num_rows, threads.y)));
if (flag) {
THCudaTensor_kernel_varInnermostDim<true, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_rows, row_size);
} else {
THCudaTensor_kernel_varInnermostDim<false, apply_sqrt><<<grid, threads, 0, THCState_getCurrentStream(state)>>>(
THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_rows, row_size);
}
cudaError errcode = cudaGetLastError();
if (errcode != cudaSuccess) {
THError(cudaGetErrorString(errcode));
}
}
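
// Variance of src along `dimension`, written to self_ with the reduced dimension kept as
// size 1. flag != 0 normalizes by n instead of n - 1. The innermost-dimension kernel is
// used when reducing the last dimension; otherwise the outer-dimension kernel is used.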
void THCudaTensor_var(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag)
{
THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
THLongStorage *dim = THCudaTensor_newSizeOf(state, src);
THLongStorage_set(dim, dimension, 1);
THCudaTensor_resize(state, self_, dim, NULL);
THLongStorage_free(dim);
THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
src = THCudaTensor_newContiguous(state, src);
if (dimension == THCudaTensor_nDimension(state, src) - 1) {
THCudaTensor_varInnermostDim<false>(state, self, src, flag);
} else {
THCudaTensor_varOuterDim<false>(state, self, src, dimension, flag);
}
THCudaTensor_free(state, src);
THCudaTensor_freeCopyTo(state, self, self_);
}
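
// Standard deviation along `dimension`: identical to THCudaTensor_var except that the
// kernels are instantiated with apply_sqrt = true.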
void THCudaTensor_std(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag)
{
THAssert(THCudaTensor_checkGPU(state, 2, self_, src));
THLongStorage *dim = THCudaTensor_newSizeOf(state, src);
THLongStorage_set(dim, dimension, 1);
THCudaTensor_resize(state, self_, dim, NULL);
THLongStorage_free(dim);
THCudaTensor *self = THCudaTensor_newContiguous(state, self_);
src = THCudaTensor_newContiguous(state, src);
if (dimension == THCudaTensor_nDimension(state, src) - 1) {
THCudaTensor_varInnermostDim<true>(state, self, src, flag);
} else {
THCudaTensor_varOuterDim<true>(state, self, src, dimension, flag);
}
THCudaTensor_free(state, src);
THCudaTensor_freeCopyTo(state, self, self_);
}
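
// Per-element summand for p-norm reductions: |x| when StaticExp == 1, x * x when
// StaticExp == 2, and |x|^exponent in the generic case (instantiated with StaticExp == -1).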
template <int StaticExp>
struct TensorNormOp
{
TensorNormOp(float exp) : exponent(exp) {}
__host__ __device__ float operator()(float x) const {
if (StaticExp == 1) {
return fabsf(x);
} else if (StaticExp == 2) {
return x * x;
} else {
return powf(fabsf(x), exponent);
}
}
const float exponent;
};
struct TensorNonZeroOp
{
TensorNonZeroOp() {}
__host__ __device__ bool operator()(float lhs) const { return lhs != 0.0f; }
};
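
// p-norm over all elements of self. value == 0 counts non-zero entries (TensorNonZeroOp's
// bool result accumulates as 0 or 1), value == 1 sums absolute values, value == 2 returns
// the square root of the sum of squares, and any other value yields (sum |x|^value)^(1/value).
// Illustrative call, for any CUDA tensor t already resident on the current device:
//   float l2 = THCudaTensor_normall(state, t, 2.0f);  // Euclidean norm of all elements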
float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value)
{
THAssert(THCudaTensor_checkGPU(state, 1, self));
self = THCudaTensor_newContiguous(state, self);
long size = THCudaTensor_nElement(state, self);
thrust::device_ptr<float> self_data(THCudaTensor_data(state, self));
float result;
if (value == 0.0f) {
result = thrust::transform_reduce(
#if CUDA_VERSION >= 7000
thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
self_data, self_data+size, TensorNonZeroOp(),
0.0f, thrust::plus<float>());
} else if (value == 1.0f) {
result = thrust::transform_reduce(
#if CUDA_VERSION >= 7000
thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
self_data, self_data+size, TensorNormOp<1>(value),
0.0f, thrust::plus<float>());
} else if (value == 2.0f) {
result = thrust::transform_reduce(
#if CUDA_VERSION >= 7000
thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
self_data, self_data+size, TensorNormOp<2>(value),
0.0f, thrust::plus<float>());
result = powf(result, 0.5f);
} else {
result = thrust::transform_reduce(
#if CUDA_VERSION >= 7000
thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
self_data, self_data+size, TensorNormOp<-1>(value),
0.0f, thrust::plus<float>());
result = powf(result, 1.0f / value);
}
THCudaTensor_free(state, self);
return result;
}
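
// Same p-norm conventions as THCudaTensor_normall, but reduced along `dimension` via
// THC_reduceDim; the closing 1/value power, where one is needed, is applied with THCudaTensor_pow.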
void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension)
{
THAssert(THCudaTensor_checkGPU(state, 2, self, src));
if (value == 0.0f) {
THC_reduceDim(state, self, src,
TensorNonZeroOp(), thrust::plus<float>(),
0.0f, dimension);
} else if (value == 1.0f) {
THC_reduceDim(state, self, src,
TensorNormOp<1>(value), thrust::plus<float>(),
0.0f, dimension);
} else if (value == 2.0f) {
THC_reduceDim(state, self, src,
TensorNormOp<2>(value), thrust::plus<float>(),
0.0f, dimension);
THCudaTensor_pow(state, self, self, 0.5f);
} else {
THC_reduceDim(state, self, src,
TensorNormOp<-1>(value), thrust::plus<float>(),
0.0f, dimension);
THCudaTensor_pow(state, self, self, 1.0f / value);
}
THCudaCheck(cudaGetLastError());
}
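
// One thread block per slice taken along the renorm dimension (the caller transposes the
// tensor so that dimension comes first). The block's 32 threads accumulate sum |x|^value
// over the slice, reduce it in shared memory, and rescale the slice by
// maxnorm / (norm + 1e-7) whenever its p-norm exceeds maxnorm.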
__global__ void THCudaTensor_kernel_renorm(float *data, const float value, const long size, const float maxnorm)
{
__shared__ float buffer[32];
long tx = threadIdx.x;
long bx = blockIdx.x;
long step = blockDim.x;
float *row = data + size*bx;
buffer[tx] = 0;
// get norm of axis
for (long i=tx; i<size; i+=step)
{
buffer[tx] += pow(fabs(row[i]), value);
}
// add (reduce)
for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1)
{
__syncthreads();
if (tx < stride)
buffer[tx] += buffer[tx+stride];
}
// clip norms
__syncthreads();
float norm = pow(buffer[0], 1/value);
if (norm > maxnorm)
{
norm = maxnorm / (norm + 1e-7);
// renormalize
for (long i=tx; i<size; i+=step)
{
row[i] *= norm;
}
}
}
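
// Clamps the p-norm of every slice of src taken along `dimension` to at most maxnorm.
// The source is transposed so the target dimension is first, cloned to contiguous memory,
// processed by the kernel above, then transposed back and copied into self.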
void THCudaTensor_renorm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension, float maxnorm)
{
THAssert(THCudaTensor_checkGPU(state, 2, self, src));
THCudaTensor *self_;
THCudaTensor *src_ = THCudaTensor_newTranspose(state, src, dimension, 0);
THCudaTensor *data = THCudaTensor_newClone(state, src_);
long size = THCudaTensor_nElement(state, data)/data->size[0];
THArgCheck(dimension >= 0 && dimension < THCudaTensor_nDimension(state, src), 3, "invalid dimension");
THArgCheck(value > 0, 2, "non-positive-norm not supported");
THArgCheck(THCudaTensor_nDimension(state, src) > 1, 1, "need at least 2 dimensions");
dim3 grid(data->size[0]);
dim3 threads(32);
THCudaTensor_kernel_renorm<<<grid, threads, 0, THCState_getCurrentStream(state)>>>(THCudaTensor_data(state, data), value, size, maxnorm);
cudaError errcode = cudaGetLastError();
if(errcode != cudaSuccess)
THError(cudaGetErrorString(errcode));
THCudaTensor_free(state, src_);
self_ = THCudaTensor_newTranspose(state, data, dimension, 0);
THCudaTensor_resizeAs(state, self, self_);
THCudaTensor_freeCopyTo(state, self_, self);
THCudaTensor_free(state, data);
}
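
// Summand for the p-distance between two tensors: |x - y|^exponent. THCudaTensor_dist
// below accumulates it with thrust::inner_product and applies the final 1/value power.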
struct dist_functor
{
const float exponent;
dist_functor(float exponent_) : exponent(exponent_) {}
__host__ __device__ float operator()(const float& x, const float& y) const
{
return pow(fabs(x-y), exponent);
}
};
float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value)
{
THAssert(THCudaTensor_checkGPU(state, 2, self, src));
self = THCudaTensor_newContiguous(state, self);
long size = THCudaTensor_nElement(state, self);
src = THCudaTensor_newContiguous(state, src);
thrust::device_ptr<float> self_data(THCudaTensor_data(state, self));
thrust::device_ptr<float> src_data(THCudaTensor_data(state, src));
float result = thrust::inner_product(
#if CUDA_VERSION >= 7000
thrust::cuda::par.on(THCState_getCurrentStream(state)),
#endif
self_data, self_data+size, src_data, (float) 0,
thrust::plus<float>(), dist_functor(value));
THCudaTensor_free(state, src);
THCudaTensor_free(state, self);
return pow(result, (float)1.0/value);
}
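
// rand/randn: resize r_ to `size` and fill it with uniform(0, 1) or normal(mean 0, stdv 1)
// samples respectively (see THCTensorRandom).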
void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size)
{
THAssert(THCudaTensor_checkGPU(state, 1, r_));
THCudaTensor_resize(state, r_, size, NULL);
THCudaTensor_uniform(state, r_, 0, 1);
}
void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size)
{
THAssert(THCudaTensor_checkGPU(state, 1, r_));
THCudaTensor_resize(state, r_, size, NULL);
THCudaTensor_normal(state, r_, 0, 1);
}
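
// Computes one 3-component cross product; sx, sy and so are the strides of x, y and the
// output along the cross-product dimension, so the functor can reach all three entries
// from a pointer to the first one.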
struct TensorCrossOp {
TensorCrossOp(long sx, long sy, long so) : sx(sx), sy(sy), so(so) {}
__device__ __forceinline__ void operator()(float* out, float* x, float* y) {
out[0 * so] = x[1 * sx] * y[2 * sy] - x[2 * sx] * y[1 * sy];
out[1 * so] = x[2 * sx] * y[0 * sy] - x[0 * sx] * y[2 * sy];
out[2 * so] = x[0 * sx] * y[1 * sy] - x[1 * sx] * y[0 * sy];
}
const long sx, sy, so;
};
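
// Cross product of x and y along `dimension`, which must have size 3. If dimension < 0,
// the first dimension of size 3 is used. The tensors are narrowed to a single slice along
// that dimension and TensorCrossOp indexes the remaining entries through the saved strides.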
THC_API void THCudaTensor_cross(THCState *state, THCudaTensor *self, THCudaTensor *x, THCudaTensor *y, int dimension)
{
THAssert(THCudaTensor_checkGPU(state, 3, self, x, y));
int i;
long nd = THCudaTensor_nDimension(state, x);
long nelem = THCudaTensor_nElement(state, x);
THArgCheck(nd == THCudaTensor_nDimension(state, y), 1, "tensors must have same number of dimensions");
for (i = 0; i < nd; i++) {
THArgCheck(THCudaTensor_size(state, x, i) == THCudaTensor_size(state, y, i), 1, "dimension %i of x and y does not match", i);
if (dimension < 0 && THCudaTensor_size(state, x, i) == 3) {
dimension = i;
}
}
THArgCheck(dimension >= 0 && dimension < nd, 3, "dimension %d out of range", dimension+1);
THArgCheck(THCudaTensor_size(state, x, dimension) == 3, 3,
"dimension %d does not have size 3", dimension+1);
THCudaTensor_resizeAs(state, self, x);
long sx = THCudaTensor_stride(state, x, dimension);
long sy = THCudaTensor_stride(state, y, dimension);
long so = THCudaTensor_stride(state, self, dimension);
THCudaTensor *nx = THCudaTensor_newNarrow(state, x, dimension, 0, 1);
THCudaTensor *ny = THCudaTensor_newNarrow(state, y, dimension, 0, 1);
THCudaTensor *nself = THCudaTensor_newNarrow(state, self, dimension, 0, 1);
if (!THC_pointwiseApply3(state, nself, nx, ny, TensorCrossOp(sx, sy, so))) {
THArgCheck(false, 2, CUTORCH_DIM_WARNING);
}
THCudaTensor_free(state, nx);
THCudaTensor_free(state, ny);
THCudaTensor_free(state, nself);
}