Optimize pow for different exponents and add tests
diff --git a/generic/THTensorMath.c b/generic/THTensorMath.c
index 43cbf83..1ed4ee2 100644
--- a/generic/THTensorMath.c
+++ b/generic/THTensorMath.c
@@ -2856,13 +2856,6 @@
TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data);); \
} \
-#define LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(NAME, CFUNC) \
- void THTensor_(NAME)(THTensor *r_, THTensor *t, real value) \
- { \
- THTensor_(resizeAs)(r_, t); \
- TH_TENSOR_APPLY2(real, t, real, r_, *r__data = CFUNC(*t_data, value);); \
- } \
-
#if defined(TH_REAL_IS_LONG)
LAB_IMPLEMENT_BASIC_FUNCTION(abs,labs)
LAB_IMPLEMENT_BASIC_FUNCTION(neg,-)
@@ -2912,7 +2905,6 @@
LAB_IMPLEMENT_BASIC_FUNCTION(tan,TH_MATH_NAME(tan))
LAB_IMPLEMENT_BASIC_FUNCTION(atan,TH_MATH_NAME(atan))
LAB_IMPLEMENT_BASIC_FUNCTION(tanh,TH_MATH_NAME(tanh))
-LAB_IMPLEMENT_BASIC_FUNCTION_VALUE(pow,TH_MATH_NAME(pow))
LAB_IMPLEMENT_BASIC_FUNCTION(sqrt,TH_MATH_NAME(sqrt))
LAB_IMPLEMENT_BASIC_FUNCTION(rsqrt,TH_MATH_NAME(TH_rsqrt))
LAB_IMPLEMENT_BASIC_FUNCTION(ceil,TH_MATH_NAME(ceil))
@@ -2925,6 +2917,35 @@
LAB_IMPLEMENT_BASIC_FUNCTION(cinv, TH_MATH_NAME(1.0) / )
+void THTensor_(pow)(THTensor *r_, THTensor *t, real value)
+{
+ THTensor_(resizeAs)(r_, t);
+ if(value == 1){
+ THTensor_(copy)(r_, t);
+ }
+ else if(value == 2){
+ THTensor_(cmul)(r_, t, t);
+ }
+ else if(value == 3){
+ TH_TENSOR_APPLY2(real, t, real, r_, *r__data = *t_data * *t_data * *t_data;);
+ }
+ else if(value == 0.5){
+ THTensor_(sqrt)(r_, t);
+ }
+ else if(value == -0.5){
+ THTensor_(rsqrt)(r_, t);
+ }
+ else if(value == -1){
+ THTensor_(cinv)(r_, t);
+ }
+ else if(value == -2){
+ TH_TENSOR_APPLY2(real, t, real, r_, *r__data = TH_MATH_NAME(1.0) / (*t_data * *t_data););
+ }
+ else{
+ TH_TENSOR_APPLY2(real, t, real, r_, *r__data = TH_MATH_NAME(pow)(*t_data, value););
+ }
+}
+
void THTensor_(atan2)(THTensor *r_, THTensor *tx, THTensor *ty)
{
THTensor_(resizeAs)(r_, tx);