check for beta=0 and avoid multiply in sparse mm (#1211)
* check for beta=0 and avoid multiply in sparse mm

When beta == 0, spaddmm now zeroes the result tensor directly instead of
computing THTensor_(mul)(r_, t, beta), so the dense t argument is neither read
nor multiplied in that case. The output resize that previously lived in the
cwrap before_call blocks is moved into the spaddmm/sspaddmm kernels in
THSTensorMath.c, so the bindings no longer have to size r_ themselves.
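
For context, spaddmm computes r_ = beta * t + alpha * (sparse @ dense) over a
CSR view of the sparse operand. The following is a simplified, single-threaded
sketch of the patched control flow (plain C arrays and a hypothetical
spaddmm_sketch helper, not the TH API): with beta == 0 the output is zeroed
outright, so t is never read and the per-element multiply is skipped; otherwise
the existing beta * t initialization is kept.

    /* Sketch only: csr[h]..csr[h+1] delimit row h's nonzeros, col/val hold
     * their column indices and values, dense and r are row-major
     * dim_j x dim_k and dim_i x dim_k buffers. */
    static void spaddmm_sketch(double *r, const double *t, double beta,
                               double alpha, const long *csr, const long *col,
                               const double *val, const double *dense,
                               long dim_i, long dim_k) {
      if (beta == 0) {
        /* new fast path: zero the output and never touch t */
        for (long x = 0; x < dim_i * dim_k; x++) r[x] = 0.0;
      } else {
        for (long x = 0; x < dim_i * dim_k; x++) r[x] = beta * t[x];
      }
      for (long h = 0; h < dim_i; h++)              /* one output row per sparse row */
        for (long i = csr[h]; i < csr[h + 1]; i++)  /* nonzeros of row h */
          for (long k = 0; k < dim_k; k++)          /* scaled accumulate against a dense row */
            r[h * dim_k + k] += alpha * val[i] * dense[col[i] * dim_k + k];
    }

This mirrors the new if (beta == 0) THTensor_(zero)(r_) branch added to
THSTensorMath.c below, which also saves one full pass of multiplies over t.
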
diff --git a/torch/csrc/generic/methods/SparseTensor.cwrap b/torch/csrc/generic/methods/SparseTensor.cwrap
index d89cd6d..3d9cd99 100644
--- a/torch/csrc/generic/methods/SparseTensor.cwrap
+++ b/torch/csrc/generic/methods/SparseTensor.cwrap
@@ -162,10 +162,6 @@
only_stateless: True
cname: spaddmm
return: argument 0
- before_call: |
- long s1 = THSTensor_(size)(LIBRARY_STATE ((THSPTensor*)$arg4)->cdata, 0);
- long s2 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg5)->cdata, 1);
- THTensor_(resize2d)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, s1, s2);
arguments:
- arg: THTensor* result
output: True
@@ -182,11 +178,6 @@
sparse: yes
cname: spaddmm
return: argument 0
- before_call: |
- long s1 = THSTensor_(size)(LIBRARY_STATE ((THSPTensor*)$arg4)->cdata, 0);
- long s2 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg5)->cdata, 1);
- THTensor_(resize2d)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, s1, s2);
- THTensor_(zero)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata);
arguments:
- arg: THTensor* result
output: True
@@ -203,10 +194,6 @@
sparse: yes
cname: sspaddmm
return: argument 0
- before_call: |
- long s1 = THSTensor_(size)(LIBRARY_STATE ((THSPTensor*)$arg4)->cdata, 0);
- long s2 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg5)->cdata, 1);
- THSTensor_(resize2d)(LIBRARY_STATE ((THSPTensor*)$arg0)->cdata, s1, s2);
arguments:
- arg: THSTensor* result
output: True
diff --git a/torch/csrc/generic/methods/TensorMath.cwrap b/torch/csrc/generic/methods/TensorMath.cwrap
index ec498f7..50bdcad 100644
--- a/torch/csrc/generic/methods/TensorMath.cwrap
+++ b/torch/csrc/generic/methods/TensorMath.cwrap
@@ -1469,11 +1469,6 @@
- THTensor* mat2
- cname: spaddmm
sparse: True
- before_call: |
- long s1 = THSTensor_(size)(LIBRARY_STATE ((THSPTensor*)$arg4)->cdata, 0);
- long s2 = THTensor_(size)(LIBRARY_STATE ((THPTensor*)$arg5)->cdata, 1);
- THTensor_(resize2d)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata, s1, s2);
- THTensor_(zero)(LIBRARY_STATE ((THPTensor*)$arg0)->cdata);
arguments:
- arg: THTensor* result
output: True
diff --git a/torch/lib/THS/generic/THSTensorMath.c b/torch/lib/THS/generic/THSTensorMath.c
index 5e2e634..fd0a78d 100644
--- a/torch/lib/THS/generic/THSTensorMath.c
+++ b/torch/lib/THS/generic/THSTensorMath.c
@@ -266,6 +266,8 @@
dim_j = THSTensor_(size)(sparse, 1);
dim_k = THTensor_(size)(dense, 1);
+ THTensor_(resize2d)(r_, dim_i, dim_k);
+
THArgCheck(THTensor_(size)(dense, 0) == dim_j, 3,
"Expected dim 0 size %d, got %d", dim_j, THTensor_(size)(dense, 0));
THArgCheck(THTensor_(size)(t, 0) == dim_i, 1,
@@ -280,8 +282,11 @@
csr = THSTensor_(toCSR)(THLongTensor_data(indices), dim_i, nnz);
// r_ = alpha * sparse * dense
- THTensor_(resize2d)(r_, dim_i, dim_k);
- THTensor_(mul)(r_, t, beta);
+ if (beta == 0) {
+ THTensor_(zero)(r_);
+ } else {
+ THTensor_(mul)(r_, t, beta);
+ }
#pragma omp parallel for private(h, i) schedule(static) if (nnz > 10000)
for (h = 0; h < dim_i; h++) {
long i_start = THTensor_fastGet1d(csr, h);
@@ -322,13 +327,15 @@
"scalar values expected, got %dD values", sparse->nDimensionV);
THArgCheck(dense->nDimension == 2, 2,
"matrices expected, got %dD tensor", dense->nDimension);
-
THSTensor_(contiguous)(sparse);
dim_i = THSTensor_(size)(sparse, 0);
dim_j = THSTensor_(size)(sparse, 1);
dim_k = THTensor_(size)(dense, 1);
+ THSTensor_(resize2d)(r_, dim_i, dim_k);
+
+
THArgCheck(THTensor_(size)(dense, 0) == dim_j, 3,
"Expected dim 0 size %d, got %d", dim_j, THTensor_(size)(dense, 0));
THArgCheck(THSTensor_(size)(t, 0) == dim_i, 1,
@@ -339,7 +346,6 @@
nnz = THSTensor_(nnz)(sparse);
indices = THSTensor_(indices)(sparse);
values = THSTensor_(values)(sparse);
-
csr = THSTensor_(toCSR)(THLongTensor_data(indices), dim_i, nnz);
t_nnz = THSTensor_(nnz)(t);
@@ -390,7 +396,6 @@
}
- THSTensor_(resize2d)(r_, dim_i, dim_k);
// to avoid a clone
r_->indices = newi;
r_-> values = newv;
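
The other half of the change is where the output gets sized: the resize2d (and
zero) calls are dropped from the cwrap before_call blocks, and the kernels in
THSTensorMath.c now resize r_ themselves, next to the dimension checks.
sspaddmm in particular used to resize only at the very end, just before
swapping in the new indices/values. The row loop in spaddmm depends on the row
pointers produced by THSTensor_(toCSR); a hypothetical sketch of that layout
(coo_rows_to_csr is illustrative, not the actual toCSR implementation):

    #include <stdlib.h>

    /* Given COO row indices for nnz entries (assumed grouped by row, as after
     * THSTensor_(contiguous)), build row pointers so that the nonzeros of row
     * h occupy [csr[h], csr[h+1]). Caller frees the returned buffer. */
    static long *coo_rows_to_csr(const long *row, long dim_i, long nnz) {
      long *csr = calloc((size_t)(dim_i + 1), sizeof(long));
      for (long n = 0; n < nnz; n++) csr[row[n] + 1]++;      /* count per row */
      for (long h = 0; h < dim_i; h++) csr[h + 1] += csr[h]; /* prefix sums -> offsets */
      return csr;
    }

With that layout, the i_start = csr[h] read in the loop above, together with
csr[h + 1], bounds the entries that contribute to output row h, which is the
range the omp-parallel loop in spaddmm walks.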