Renamed baddmm -> addbmm; added baddbmm.
- addbmm is a batch MM + reduce add
- baddbmm is a batch MM + batch add
diff --git a/generic/THTensorMath.c b/generic/THTensorMath.c
index b015732..d5e1818 100644
--- a/generic/THTensorMath.c
+++ b/generic/THTensorMath.c
@@ -658,7 +658,7 @@
THTensor *matrix1 = THTensor_(new)();
THTensor *matrix2 = THTensor_(new)();
- THTensor *result_matrix = THTensor_(new());
+ THTensor *result_matrix = THTensor_(new)();
for (batch = 0; batch < THTensor_(size)(batch1, 0); ++batch) {
THTensor_(select)(matrix1, batch1, 0, batch);
@@ -673,7 +673,7 @@
THTensor_(free)(result_matrix);
}
-void THTensor_(baddmm)(THTensor *result, real beta, THTensor *t, real alpha, THTensor *batch1, THTensor *batch2)
+void THTensor_(addbmm)(THTensor *result, real beta, THTensor *t, real alpha, THTensor *batch1, THTensor *batch2)
{
long batch;
@@ -709,6 +709,46 @@
THTensor_(free)(matrix2);
}
+void THTensor_(baddbmm)(THTensor *result, real beta, THTensor *t, real alpha, THTensor *batch1, THTensor *batch2)
+{
+ long batch;
+
+ THArgCheck(THTensor_(nDimension)(batch1) == 3, 1, "expected 3D tensor");
+ THArgCheck(THTensor_(nDimension)(batch2) == 3, 2, "expected 3D tensor");
+ THArgCheck(THTensor_(size)(batch1, 0) == THTensor_(size)(batch2, 0), 2,
+ "equal number of batches expected");
+ THArgCheck(THTensor_(size)(batch1, 2) == THTensor_(size)(batch2, 1), 2,
+ "wrong matrix size");
+
+ long bs = THTensor_(size)(batch1, 0);
+ long dim1 = THTensor_(size)(batch1, 1);
+ long dim2 = THTensor_(size)(batch2, 2);
+ THArgCheck(THTensor_(size)(t, 0) == bs, 1, "output tensor of incorrect size");
+ THArgCheck(THTensor_(size)(t, 1) == dim1, 1, "output tensor of incorrect size");
+ THArgCheck(THTensor_(size)(t, 2) == dim2, 1, "output tensor of incorrect size");
+
+ if (t != result) {
+ THTensor_(resizeAs)(result, t);
+ THTensor_(copy)(result, t);
+ }
+
+ THTensor *matrix1 = THTensor_(new)();
+ THTensor *matrix2 = THTensor_(new)();
+ THTensor *result_matrix = THTensor_(new)();
+
+ for (batch = 0; batch < THTensor_(size)(batch1, 0); ++batch) {
+ THTensor_(select)(matrix1, batch1, 0, batch);
+ THTensor_(select)(matrix2, batch2, 0, batch);
+ THTensor_(select)(result_matrix, result, 0, batch);
+
+ THTensor_(addmm)(result_matrix, beta, result_matrix, alpha, matrix1, matrix2);
+ }
+
+ THTensor_(free)(matrix1);
+ THTensor_(free)(matrix2);
+ THTensor_(free)(result_matrix);
+}
+
long THTensor_(numel)(THTensor *t)
{
return THTensor_(nElement)(t);
diff --git a/generic/THTensorMath.h b/generic/THTensorMath.h
index aa289b6..0094bd3 100644
--- a/generic/THTensorMath.h
+++ b/generic/THTensorMath.h
@@ -38,7 +38,8 @@
TH_API void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *vec1, THTensor *vec2);
TH_API void THTensor_(bmm)(THTensor *r_, THTensor *batch1, THTensor *batch2);
-TH_API void THTensor_(baddmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *batch1, THTensor *batch2);
+TH_API void THTensor_(addbmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *batch1, THTensor *batch2);
+TH_API void THTensor_(baddbmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor *batch1, THTensor *batch2);
TH_API void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, real gain);