// Source blob: 623add80b26d46193d2ffe4bc55670b6d3985eb6
#include "THCUNN.h"
#include "common.h"
// Out-of-place forward functor for Threshold, applied per element:
//   out = (in > threshold) ? in : val
// A NaN input fails the comparison and therefore maps to val.
struct ThresholdUpdateOutput
{
  const float threshold_;
  const float val_;

  ThresholdUpdateOutput(float threshold, float val)
    : threshold_(threshold)
    , val_(val)
  {}

  // `in` is only read, so it is taken as pointer-to-const. operator() is const
  // to match the gradient functors in this file; it mutates no functor state.
  __device__ __forceinline__ void operator()(float *out, const float *in) const
  {
    float x = *in;
    *out = (x > threshold_) ? x : val_;
  }
};
// In-place forward functor for Threshold. It overwrites each element with
//   x = (x > threshold) ? x : val
// A NaN element fails the comparison and is replaced by val.
struct ThresholdUpdateOutputIP
{
  const float threshold_;
  const float val_;

  ThresholdUpdateOutputIP(float threshold, float val)
    : threshold_(threshold)
    , val_(val)
  {}

  // const to match the gradient functors in this file; only *x is modified.
  __device__ __forceinline__ void operator()(float *x) const
  {
    *x = (*x > threshold_) ? *x : val_;
  }
};
// Forward pass of Threshold: output = (input > threshold) ? input : val, element-wise.
//
// inplace == true : `input` is overwritten, and `output` is then set to share
//                   input's storage via THCudaTensor_set.
// inplace == false: `output` is resized to match `input` and written out of place.
//
// threshold and val are narrowed to float when the functors are constructed.
void THNN_CudaThreshold_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output,
                                     double threshold, double val, bool inplace)
{
  THCUNN_assertSameGPU(state, 2, input, output);
  // NOTE(review): relies on THC_pointwiseApply* returning false (rather than
  // raising) when the apply cannot be performed, as in cutorch's THCApply.cuh.
  // Previously that result was dropped and output could be left unwritten.
  if (inplace)
  {
    bool applied = THC_pointwiseApply1(state, input,
                                       ThresholdUpdateOutputIP(threshold, val));
    THArgCheck(applied, 2, "Threshold: pointwise apply on input failed (too many dimensions?)");
    THCudaTensor_set(state, output, input);
  }
  else
  {
    THCudaTensor_resizeAs(state, output, input);
    bool applied = THC_pointwiseApply2(state, output, input,
                                       ThresholdUpdateOutput(threshold, val));
    THArgCheck(applied, 2, "Threshold: pointwise apply on input failed (too many dimensions?)");
  }
  THCudaCheck(cudaGetLastError());
}
// Out-of-place backward functor for Threshold.
// gradInput takes gradOutput where the forward input exceeded the threshold,
// and 0 elsewhere. A NaN input fails the comparison, so its gradient is 0.
struct ThresholdUpdateGradInput
{
  const float threshold_;

  ThresholdUpdateGradInput(float threshold) : threshold_(threshold) {}

  __device__ __forceinline__ void operator()(
      float *gradInput, float *input, float *gradOutput) const
  {
    if (*input > threshold_)
    {
      *gradInput = *gradOutput;
    }
    else
    {
      *gradInput = 0.0f;
    }
  }
};
// In-place backward functor for Threshold.
// gradOutput is zeroed wherever the forward input did not exceed the
// threshold; every other element keeps its incoming gradient.
struct ThresholdUpdateGradInputIP
{
  const float threshold_;

  ThresholdUpdateGradInputIP(float threshold) : threshold_(threshold) {}

  __device__ __forceinline__ void operator()(
      float *gradOutput, float *input) const
  {
    // Spelled !(a > b) rather than a <= b so a NaN input still zeroes the gradient.
    if (!(*input > threshold_))
    {
      *gradOutput = 0.0f;
    }
  }
};
// Backward pass of Threshold:
//   gradInput = (input > threshold) ? gradOutput : 0, element-wise.
//
// inplace == true : gradOutput is modified in place, and gradInput is then set
//                   to share its storage via THCudaTensor_set.
// inplace == false: gradInput is resized to match input and written out of place.
//
// `val` is accepted for signature symmetry with updateOutput but is unused:
// the gradient depends only on threshold.
//
// NOTE(review): this compares `input` against threshold. If the caller ran the
// forward pass in place, `input` now holds the forward output. With
// val > threshold, elements that were replaced by val would then pass the test
// and leak gradient. Callers should keep val <= threshold for in-place use.
void THNN_CudaThreshold_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput,
                                        THCudaTensor *gradInput, double threshold, double val, bool inplace)
{
  (void) val;  // unused, see header comment
  THCUNN_assertSameGPU(state, 3, input, gradInput, gradOutput);
  // NOTE(review): relies on THC_pointwiseApply* returning false (rather than
  // raising) on failure, e.g. when element counts differ, as in cutorch's
  // THCApply.cuh. Previously a mis-sized gradOutput left gradInput uninitialised.
  if (inplace)
  {
    bool applied = THC_pointwiseApply2(state, gradOutput, input,
                                       ThresholdUpdateGradInputIP(threshold));
    THArgCheck(applied, 3, "Threshold: gradOutput must have the same number of elements as input");
    THCudaTensor_set(state, gradInput, gradOutput);
  }
  else
  {
    THCudaTensor_resizeAs(state, gradInput, input);
    bool applied = THC_pointwiseApply3(state, gradInput, input, gradOutput,
                                       ThresholdUpdateGradInput(threshold));
    THArgCheck(applied, 3, "Threshold: gradOutput must have the same number of elements as input");
  }
  THCudaCheck(cudaGetLastError());
}