blob: 2984672f814157a24c02deb0a5f0c77dbbebd8b5 [file] [log] [blame]
#include "THCUNN.h"
#include "common.h"
#include <thrust/fill.h>
#include <thrust/functional.h>
#include <thrust/device_ptr.h>
#include <thrust/reduce.h>
#include <thrust/inner_product.h>
// Elementwise forward term of the KL-divergence criterion:
//   y * (log(y) - x)
// where x is presumably already in log-space (this matches the
// DistKLDivCriterion contract — confirm against the Lua-side caller)
// and y is the target probability. Targets with y <= 0 contribute 0,
// implementing the convention 0 * log(0) := 0.
struct kl_functor
{
  __host__ __device__ float operator()(const float& x, const float& y) const
  {
    // logf keeps the computation in single precision on the host as well
    // as the device; the previous bare log() promoted to double on host,
    // producing host/device-inconsistent results and a slower path.
    return y > 0.0f ? y * (logf(y) - x) : 0.0f;
  }
};
// Forward pass of the KL-divergence criterion.
// Reduces kl_functor(input_i, target_i) over all elements into a single
// scalar written to output[0]; divides by the element count when
// sizeAverage is set. input and target must have equal element counts.
void THNN_CudaDistKLDivCriterion_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *target, THCudaTensor *output, bool sizeAverage)
{
  THCUNN_assertSameGPU(state, 2, input, target);
  THArgCheck(THCudaTensor_nElement(state, input) == THCudaTensor_nElement(state, target), 2,
             "input and target need to have the same number of elements");

  long n = THCudaTensor_nElement(state, input);

  // Take contiguous views so thrust sees flat, densely packed buffers.
  input = THCudaTensor_newContiguous(state, input);
  target = THCudaTensor_newContiguous(state, target);

  thrust::device_ptr<float> in_begin(THCudaTensor_data(state, input));
  thrust::device_ptr<float> tgt_begin(THCudaTensor_data(state, target));

  // Fused map+reduce: sum_i kl_functor(input_i, target_i).
  float total = thrust::inner_product(in_begin, in_begin + n, tgt_begin,
                                      (float) 0,
                                      thrust::plus<float>(), kl_functor());
  if (sizeAverage)
    total /= n;

  // Release the contiguous views before touching output.
  THCudaTensor_free(state, input);
  THCudaTensor_free(state, target);

  THCudaTensor_set1d(state, output, 0, total);
}
// Elementwise backward term of the KL-divergence criterion.
// For forward term y * (log(y) - x), the derivative w.r.t. x is -y;
// `norm` carries the 1/N size-averaging factor (or 1). Elements with
// y <= 0 get zero gradient, mirroring the forward convention.
struct kl_updateGradInput_functor
{
  const float norm;

  kl_updateGradInput_functor(float norm_)
    : norm(norm_)
  {}

  __host__ __device__ float operator()(const float& x, const float& y) const
  {
    // x (the log-space input) does not appear in the derivative.
    if (y > 0)
      return -norm * y;
    return 0;
  }
};
// Backward pass of the KL-divergence criterion.
// Fills gradInput (resized to match input) with
// kl_updateGradInput_functor(norm)(input_i, target_i), where norm is
// 1/N when sizeAverage is set and 1 otherwise.
void THNN_CudaDistKLDivCriterion_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *target, THCudaTensor *gradInput, bool sizeAverage)
{
  THCUNN_assertSameGPU(state, 3, input, target, gradInput);
  THArgCheck(THCudaTensor_nElement(state, input) == THCudaTensor_nElement(state, target), 2,
             "input and target need to have the same number of elements");

  long n = THCudaTensor_nElement(state, input);
  float scale = (sizeAverage ? 1./n : 1.);

  // Take contiguous views so thrust sees flat, densely packed buffers.
  input = THCudaTensor_newContiguous(state, input);
  target = THCudaTensor_newContiguous(state, target);

  THCudaTensor_resizeAs(state, gradInput, input);

  thrust::device_ptr<float> in_begin(THCudaTensor_data(state, input));
  thrust::device_ptr<float> tgt_begin(THCudaTensor_data(state, target));
  thrust::device_ptr<float> grad_begin(THCudaTensor_data(state, gradInput));

  // gradInput_i = functor(input_i, target_i), elementwise on device.
  thrust::transform(in_begin, in_begin + n, tgt_begin, grad_begin,
                    kl_updateGradInput_functor(scale));

  THCudaTensor_free(state, input);
  THCudaTensor_free(state, target);
}