Fix computation of the L0-norm in CudaTensor.norm
Uses the same scheme as commit 9b1049fc2c8e0b087d3420a37ee32977f9b281cd
diff --git a/THCTensorMath.cu b/THCTensorMath.cu
index 2eb4c7f..3771f25 100644
--- a/THCTensorMath.cu
+++ b/THCTensorMath.cu
@@ -996,22 +996,33 @@
}
};
+
float THCudaTensor_normall(THCudaTensor *self, float value)
{
self = THCudaTensor_newContiguous(self);
long size = THCudaTensor_nElement(self);
thrust::device_ptr<float> self_data(THCudaTensor_data(self));
- float result = thrust::transform_reduce(self_data, self_data+size, norm_functor(value), (float)0, thrust::plus<float>());
+ float result; // assigned by exactly one of the two branches below
+ if(value == 0.0f) { // L0 case: sum partial_not_equal_functor(0.0f) over all elements (presumably 1 per non-zero element -- confirm against the functor), no root taken
+ result = thrust::transform_reduce(self_data, self_data+size, partial_not_equal_functor(0.0f), (float)0, thrust::plus<float>());
+ } else { // value != 0: sum norm_functor(value) over all elements, then raise the sum to 1/value
+ result = thrust::transform_reduce(self_data, self_data+size, norm_functor(value), (float)0, thrust::plus<float>());
+ result = pow(result, (float)1.0/value); // moved here from the return statement so the L0 branch skips it (1/value would divide by zero there)
+ }
THCudaTensor_free(self);
- return pow(result, (float)1.0/value);
+ return result; // raw sum for value == 0, otherwise the 1/value power of the sum
}
void THCudaTensor_norm(THCudaTensor* self, THCudaTensor* src, float value, long dimension)
{
- THCudaTensor_transformReduceDim(self, src, dimension, norm_functor(value), (float)0, thrust::plus<float>());
+ if(value == 0.0f) { // L0 case: reduce src along `dimension` with partial_not_equal_functor(0.0f), mirroring THCudaTensor_normall
+ THCudaTensor_transformReduceDim(self, src, dimension, partial_not_equal_functor(0.0f), (float)0, thrust::plus<float>());
+ } else { // value != 0: reduce src along `dimension` with norm_functor(value)
+ THCudaTensor_transformReduceDim(self, src, dimension, norm_functor(value), (float)0, thrust::plus<float>()); // NOTE(review): unlike THCudaTensor_normall, no 1/value power is applied to the reduced values here -- confirm whether that is intended
+ }
}