fixing inconsistent API in addcmul, addcdiv
diff --git a/THCTensorMath.cu b/THCTensorMath.cu
index 9fabe7e..915707d 100644
--- a/THCTensorMath.cu
+++ b/THCTensorMath.cu
@@ -189,8 +189,13 @@
}
-void THCudaTensor_addcmul(THCudaTensor *self_, float value, THCudaTensor *src1, THCudaTensor *src2)
+void THCudaTensor_addcmul(THCudaTensor *self_, THCudaTensor *t, float value, THCudaTensor *src1, THCudaTensor *src2)
{
+ if(self_ != t)
+ {
+ THCudaTensor_resizeAs(self_, t);
+ THCudaTensor_copy(self_, t);
+ }
THCudaTensor_resizeAs(self_, src1);
THArgCheck(THCudaTensor_nElement(src1) == THCudaTensor_nElement(src2), 3, "size do not match");
{
@@ -225,8 +230,14 @@
}
-void THCudaTensor_addcdiv(THCudaTensor *self_, float value, THCudaTensor *src1, THCudaTensor *src2)
+void THCudaTensor_addcdiv(THCudaTensor *self_, THCudaTensor *t, float value, THCudaTensor *src1, THCudaTensor *src2)
{
+ if(self_ != t)
+ {
+ THCudaTensor_resizeAs(self_, t);
+ THCudaTensor_copy(self_, t);
+ }
+
THCudaTensor_resizeAs(self_, src1);
THArgCheck(THCudaTensor_nElement(src1) == THCudaTensor_nElement(src2), 3, "size do not match");
{
diff --git a/THCTensorMath.h b/THCTensorMath.h
index 915e269..7bd76d7 100644
--- a/THCTensorMath.h
+++ b/THCTensorMath.h
@@ -16,8 +16,8 @@
THC_API void THCudaTensor_cmul(THCudaTensor *self, THCudaTensor *src1, THCudaTensor *src2);
THC_API void THCudaTensor_cdiv(THCudaTensor *self, THCudaTensor *src1, THCudaTensor *src2);
-THC_API void THCudaTensor_addcmul(THCudaTensor *self, float value, THCudaTensor *src1, THCudaTensor *src2);
-THC_API void THCudaTensor_addcdiv(THCudaTensor *self, float value, THCudaTensor *src1, THCudaTensor *src2);
+THC_API void THCudaTensor_addcmul(THCudaTensor *self, THCudaTensor* t, float value, THCudaTensor *src1, THCudaTensor *src2);
+THC_API void THCudaTensor_addcdiv(THCudaTensor *self, THCudaTensor* t, float value, THCudaTensor *src1, THCudaTensor *src2);
THC_API float THCudaTensor_dot(THCudaTensor *self, THCudaTensor *src);