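Fix THVector cmul AVX bug: the second 8-float load/store used offset +4 instead of +8 (with a loop stride of 8), overlapping the first access and reaching past the loop bound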
diff --git a/generic/THVectorDispatch.c b/generic/THVectorDispatch.c
index 71a81c8..7537f66 100644
--- a/generic/THVectorDispatch.c
+++ b/generic/THVectorDispatch.c
@@ -67,12 +67,25 @@
static void (*THVector_(add_DISPATCHPTR))(real *, const real *, const real, const ptrdiff_t) = &THVector_(add_DEFAULT);
static FunctionDescription THVector_(add_DISPATCHTABLE)[] = {
+ #if defined(__NEON__)
+ #if defined(TH_REAL_IS_FLOAT)
+ FUNCTION_IMPL(THVector_(add_NEON), SIMDExtension_NEON),
+ #endif
+ #endif
+
#if defined(USE_AVX)
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
FUNCTION_IMPL(THVector_(add_AVX), SIMDExtension_AVX),
#endif
#endif
+ #if defined(USE_SSE2) || defined(USE_SSE3) || defined(USE_SSSE3) \
+ || defined(USE_SSE4_1) || defined(USE_SSE4_2)
+ #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
+ FUNCTION_IMPL(THVector_(add_SSE), SIMDExtension_SSE),
+ #endif
+ #endif
+
FUNCTION_IMPL(THVector_(add_DEFAULT), SIMDExtension_DEFAULT)
};
// Dispatch stubs that just call the pointers
@@ -94,6 +107,12 @@
#endif
#endif
+ #if defined(USE_AVX)
+ #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
+ FUNCTION_IMPL(THVector_(cmul_AVX), SIMDExtension_AVX),
+ #endif
+ #endif
+
#if defined(USE_SSE2) || defined(USE_SSE3) || defined(USE_SSSE3) \
|| defined(USE_SSE4_1) || defined(USE_SSE4_2)
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
@@ -109,12 +128,25 @@
static void (*THVector_(mul_DISPATCHPTR))(real *, const real *, const real, const ptrdiff_t) = &THVector_(mul_DEFAULT);
static FunctionDescription THVector_(mul_DISPATCHTABLE)[] = {
+ #if defined(__NEON__)
+ #if defined(TH_REAL_IS_FLOAT)
+ FUNCTION_IMPL(THVector_(mul_NEON), SIMDExtension_NEON),
+ #endif
+ #endif
+
#if defined(USE_AVX)
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
FUNCTION_IMPL(THVector_(mul_AVX), SIMDExtension_AVX),
#endif
#endif
+ #if defined(USE_SSE2) || defined(USE_SSE3) || defined(USE_SSSE3) \
+ || defined(USE_SSE4_1) || defined(USE_SSE4_2)
+ #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
+ FUNCTION_IMPL(THVector_(mul_SSE), SIMDExtension_SSE),
+ #endif
+ #endif
+
FUNCTION_IMPL(THVector_(mul_DEFAULT), SIMDExtension_DEFAULT)
};
void THVector_(mul)(real *y, const real *x, const real c, const ptrdiff_t n) {
@@ -123,12 +155,25 @@
static void (*THVector_(cdiv_DISPATCHPTR))(real *, const real *, const real *, const ptrdiff_t) = &THVector_(cdiv_DEFAULT);
static FunctionDescription THVector_(cdiv_DISPATCHTABLE)[] = {
+ #if defined(__NEON__)
+ #if defined(TH_REAL_IS_FLOAT)
+ FUNCTION_IMPL(THVector_(cdiv_NEON), SIMDExtension_NEON),
+ #endif
+ #endif
+
#if defined(USE_AVX)
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
FUNCTION_IMPL(THVector_(cdiv_AVX), SIMDExtension_AVX),
#endif
#endif
+ #if defined(USE_SSE2) || defined(USE_SSE3) || defined(USE_SSSE3) \
+ || defined(USE_SSE4_1) || defined(USE_SSE4_2)
+ #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
+ FUNCTION_IMPL(THVector_(cdiv_SSE), SIMDExtension_SSE),
+ #endif
+ #endif
+
FUNCTION_IMPL(THVector_(cdiv_DEFAULT), SIMDExtension_DEFAULT)
};
void THVector_(cdiv)(real *z, const real *x, const real *y, const ptrdiff_t n) {
@@ -137,12 +182,25 @@
static void (*THVector_(div_DISPATCHPTR))(real *, const real *, const real, const ptrdiff_t) = &THVector_(div_DEFAULT);
static FunctionDescription THVector_(div_DISPATCHTABLE)[] = {
+ #if defined(__NEON__)
+ #if defined(TH_REAL_IS_FLOAT)
+ FUNCTION_IMPL(THVector_(div_NEON), SIMDExtension_NEON),
+ #endif
+ #endif
+
#if defined(USE_AVX)
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
FUNCTION_IMPL(THVector_(div_AVX), SIMDExtension_AVX),
#endif
#endif
+ #if defined(USE_SSE2) || defined(USE_SSE3) || defined(USE_SSSE3) \
+ || defined(USE_SSE4_1) || defined(USE_SSE4_2)
+ #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
+ FUNCTION_IMPL(THVector_(div_SSE), SIMDExtension_SSE),
+ #endif
+ #endif
+
FUNCTION_IMPL(THVector_(div_DEFAULT), SIMDExtension_DEFAULT)
};
void THVector_(div)(real *y, const real *x, const real c, const ptrdiff_t n) {
diff --git a/vector/AVX.c b/vector/AVX.c
index e87b0d0..a964c88 100644
--- a/vector/AVX.c
+++ b/vector/AVX.c
@@ -177,15 +177,15 @@
static void THFloatVector_cmul_AVX(float *z, const float *x, const float *y, const ptrdiff_t n) {
ptrdiff_t i;
__m256 YMM0, YMM1, YMM2, YMM3;
- for (i=0; i<=((n)-8); i+=8) {
+ for (i=0; i<=((n)-16); i+=16) {
YMM0 = _mm256_loadu_ps(x+i);
- YMM1 = _mm256_loadu_ps(x+i+4);
+ YMM1 = _mm256_loadu_ps(x+i+8);
YMM2 = _mm256_loadu_ps(y+i);
- YMM3 = _mm256_loadu_ps(y+i+4);
+ YMM3 = _mm256_loadu_ps(y+i+8);
YMM2 = _mm256_mul_ps(YMM0, YMM2);
YMM3 = _mm256_mul_ps(YMM1, YMM3);
_mm256_storeu_ps(z+i, YMM2);
- _mm256_storeu_ps(z+i+4, YMM3);
+ _mm256_storeu_ps(z+i+8, YMM3);
}
for (; i<n; i++) {
z[i] = x[i] * y[i];