Lapack function implementation #1
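Implements the first batch of master-side THDTensor LAPACK functions (gesv, trtrs, gels, syev, geev, gesvd, gesvd2) and adds the worker-side dispatch handlers that run the corresponding thpp::Tensor calls.
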
diff --git a/torch/lib/THD/master_worker/master/generic/THDTensorLapack.cpp b/torch/lib/THD/master_worker/master/generic/THDTensorLapack.cpp
index 2e42122..b75183e 100644
--- a/torch/lib/THD/master_worker/master/generic/THDTensorLapack.cpp
+++ b/torch/lib/THD/master_worker/master/generic/THDTensorLapack.cpp
@@ -92,20 +92,232 @@
return THDTensor_(cloneColumnMajorNrows)(self, src, src->size[0]);
}
+
/* TODO implement all those */
-void THDTensor_(gesv)(THDTensor *rb, THDTensor *ra, THDTensor *b, THDTensor *a) {}
+
+/* TODO: this may leak temporaries when given invalid data */
+void THDTensor_(gesv)(THDTensor *rb, THDTensor *ra, THDTensor *b, THDTensor *a) {
+ bool free_b = false;
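+  /* a NULL input means the result tensor doubles as the input (in-place call) */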
+ if (a == NULL) a = ra;
+ if (b == NULL) b = rb;
+ THArgCheck(a->nDimension == 2, 2, "A should have 2 dimensions, but has %d",
+ a->nDimension);
+ THArgCheck(b->nDimension == 1 || b->nDimension == 2, 1, "B should have 1 or 2 "
+ "dimensions, but has %d", b->nDimension);
+ THArgCheck(a->size[0] == a->size[1], 2, "A should be square, but is %ldx%ld",
+ a->size[0], a->size[1]);
+ THArgCheck(a->size[0] == b->size[0], 2, "A,B size incompatible - A has %ld "
+ "rows, B has %ld", a->size[0], b->size[0]);
+
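+  /* view a 1-D b as an n x 1 matrix so the worker always sees a 2-D RHS */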
+ if (b->nDimension == 1) {
+ b = THDTensor_(newWithStorage2d)(b->storage, b->storageOffset, b->size[0],
+ b->stride[0], 1, 0);
+ free_b = true;
+ }
+
+ masterCommandChannel->sendMessage(
+ packMessage(Functions::tensorGesv, rb, ra, b, a),
+ THDState::s_current_worker
+ );
+
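+  /* cloneColumnMajor resizes ra/rb to the worker's result shapes;
+     the returned temporaries are dropped immediately */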
+ THDTensor_(free)(THDTensor_(cloneColumnMajor)(ra, a));
+ THDTensor_(free)(THDTensor_(cloneColumnMajor)(rb, b));
+
+ if (free_b) THDTensor_(free)(b);
+}
void THDTensor_(trtrs)(THDTensor *rb, THDTensor *ra, THDTensor *b, THDTensor *a,
- const char *uplo, const char *trans, const char *diag) {}
-void THDTensor_(gels)(THDTensor *rb, THDTensor *ra, THDTensor *b, THDTensor *a) {}
+ const char *uplo, const char *trans, const char *diag) {
+ bool free_b = false;
+ if (a == NULL) a = ra;
+ if (b == NULL) b = rb;
+ THArgCheck(a->nDimension == 2, 2, "A should have 2 dimensions, but has %d",
+ a->nDimension);
+ THArgCheck(b->nDimension == 1 || b->nDimension == 2, 1, "B should have 1 or 2 "
+ "dimensions, but has %d", b->nDimension);
+ THArgCheck(a->size[0] == a->size[1], 2, "A should be square, but is %ldx%ld",
+ a->size[0], a->size[1]);
+ THArgCheck(a->size[0] == b->size[0], 2, "A,B size incompatible - A has %ld "
+ "rows, B has %ld", a->size[0], b->size[0]);
+
+ if (b->nDimension == 1) {
+ b = THDTensor_(newWithStorage2d)(b->storage, b->storageOffset, b->size[0],
+ b->stride[0], 1, 0);
+ free_b = true;
+ }
+
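+  /* only the first character of each mode string (uplo/trans/diag) is packed */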
+ masterCommandChannel->sendMessage(
+ packMessage(Functions::tensorTrtrs, rb, ra, b, a, uplo[0], trans[0], diag[0]),
+ THDState::s_current_worker
+ );
+
+ THDTensor_(free)(THDTensor_(cloneColumnMajor)(ra, a));
+ THDTensor_(free)(THDTensor_(cloneColumnMajor)(rb, b));
+
+ if (free_b) THDTensor_(free)(b);
+}
+
+void THDTensor_(gels)(THDTensor *rb, THDTensor *ra, THDTensor *b, THDTensor *a) {
+  bool free_b = false;
+ if (a == NULL) a = ra;
+ if (b == NULL) b = rb;
+ THArgCheck(a->nDimension == 2, 2, "A should have 2 dimensions, but has %d",
+ a->nDimension);
+ THArgCheck(b->nDimension == 1 || b->nDimension == 2, 1, "B should have 1 or 2 "
+ "dimensions, but has %d", b->nDimension);
+ THArgCheck(a->size[0] == b->size[0], 2, "A,B size incompatible - A has %ld "
+ "rows, B has %ld", a->size[0], b->size[0]);
+
+ if (b->nDimension == 1) {
+ b = THDTensor_(newWithStorage2d)(b->storage, b->storageOffset, b->size[0],
+ b->stride[0], 1, 0);
+ free_b = true;
+ }
+
+ masterCommandChannel->sendMessage(
+ packMessage(Functions::tensorGels, rb, ra, b, a),
+ THDState::s_current_worker
+ );
+
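+  /* replay the worker's shape logic locally so ra/rb get the result sizes;
+     no data is computed on the master */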
+ int m, n, nrhs, ldb;
+
+ THDTensor *ra_ = NULL;
+ THDTensor *rb_ = NULL;
+
+ ra_ = THDTensor_(cloneColumnMajor)(ra, a);
+
+ m = ra_->size[0];
+ n = ra_->size[1];
+ ldb = (m > n) ? m : n;
+
+ rb_ = THDTensor_(cloneColumnMajorNrows)(rb, b, ldb);
+
+ nrhs = rb_->size[1];
+
+ /* rb_ is currently ldb by nrhs; resize it to n by nrhs */
+ rb_->size[0] = n;
+ if (rb_ != rb)
+ THDTensor_(resize2d)(rb, n, nrhs);
+
+ THDTensor_(free)(ra_);
+ THDTensor_(free)(rb_);
+ if (free_b) THDTensor_(free)(b);
+}
void THDTensor_(syev)(THDTensor *re, THDTensor *rv, THDTensor *a,
- const char *jobz, const char *uplo) {}
-void THDTensor_(geev)(THDTensor *re, THDTensor *rv, THDTensor *a, const char *jobvr) {}
+ const char *jobz, const char *uplo) {
+ if (a == NULL) a = rv;
+ THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional");
+  THArgCheck(a->size[0] == a->size[1], 1, "A should be square");
+
+ masterCommandChannel->sendMessage(
+ packMessage(Functions::tensorSyev, re, rv, a, jobz[0], uplo[0]),
+ THDState::s_current_worker
+ );
+
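+  /* rv takes a's (square) shape; re holds the n eigenvalues */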
+ THDTensor *rv_ = THDTensor_(cloneColumnMajor)(rv, a);
+ THDTensor_(resize1d)(re, rv_->size[0]);
+ THDTensor_(free)(rv_);
+}
+
+void THDTensor_(geev)(THDTensor *re, THDTensor *rv, THDTensor *a, const char *jobvr) {
+ int n;
+ THDTensor *a_;
+
+  THDTensor *rv_ = NULL;
+
+ THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional");
+  THArgCheck(a->size[0] == a->size[1], 1, "A should be square");
+
+ masterCommandChannel->sendMessage(
+ packMessage(Functions::tensorGeev, re, rv, a, jobvr[0]),
+ THDState::s_current_worker
+ );
+
+  /* always clone a for geev, even when it is already column-major */
+ a_ = THDTensor_(cloneColumnMajor)(NULL, a);
+
+ n = a_->size[0];
+
+ if (*jobvr == 'V') {
+ THDTensor_(resize2d)(rv, n, n);
+ /* guard against someone passing a correct size, but wrong stride */
+ rv_ = THDTensor_(newTransposedContiguous)(rv);
+ }
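+  /* eigenvalues come back as n (real, imaginary) pairs */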
+ THDTensor_(resize2d)(re, n, 2);
+
+ if (*jobvr == 'V') {
+ THDTensor_(checkTransposed)(rv);
+ }
+
+  if (rv_) THDTensor_(free)(rv_);
+  THDTensor_(free)(a_);
+}
+
void THDTensor_(gesvd)(THDTensor *ru, THDTensor *rs, THDTensor *rv, THDTensor *a,
- const char *jobu) {}
+ const char *jobu) {
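+  /* implemented in terms of gesvd2, with a throwaway ra */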
+ THDTensor *ra = THDTensor_(new)();
+ THDTensor_(gesvd2)(ru, rs, rv, ra, a, jobu);
+ THDTensor_(free)(ra);
+}
+
void THDTensor_(gesvd2)(THDTensor *ru, THDTensor *rs, THDTensor *rv, THDTensor *ra,
- THDTensor *a, const char *jobu) {}
+ THDTensor *a, const char *jobu) {
+ if (a == NULL) a = ra;
+ THArgCheck(a->nDimension == 2, 1, "A should be 2 dimensional");
+
+ masterCommandChannel->sendMessage(
+ packMessage(Functions::tensorGesvd2, ru, rs, rv, ra, a, jobu[0]),
+ THDState::s_current_worker
+ );
+
+ int k, m, n, ldu, ldvt;
+ THDTensor *rvf = THDTensor_(new)();
+
+ THDTensor *ra_ = NULL;
+ THDTensor *ru_ = NULL;
+
+ ra_ = THDTensor_(cloneColumnMajor)(ra, a);
+
+ m = ra_->size[0];
+ n = ra_->size[1];
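+  /* k = min(m, n) is the number of singular values */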
+ k = (m < n ? m : n);
+
+ ldu = m;
+ ldvt = n;
+
+ THDTensor_(resize1d)(rs, k);
+ THDTensor_(resize2d)(rvf, ldvt, n);
+ if (*jobu == 'A')
+ THDTensor_(resize2d)(ru, m, ldu);
+ else
+ THDTensor_(resize2d)(ru, k, ldu);
+
+ THDTensor_(checkTransposed)(ru);
+
+ /* guard against someone passing a correct size, but wrong stride */
+ ru_ = THDTensor_(newTransposedContiguous)(ru);
+
+ if (*jobu == 'S') {
+ THDTensor_(narrow)(rvf, NULL, 1, 0, k);
+ }
+ THDTensor_(resizeAs)(rv, rvf);
+ THDTensor_(free)(rvf);
+  THDTensor_(free)(ru_);
+  THDTensor_(free)(ra_);
+}
+
void THDTensor_(getri)(THDTensor *ra, THDTensor *a) {}
void THDTensor_(potrf)(THDTensor *ra, THDTensor *a, const char *uplo) {}
void THDTensor_(potrs)(THDTensor *rb, THDTensor *b, THDTensor *a, const char *uplo) {}
diff --git a/torch/lib/THD/master_worker/worker/Dispatch.cpp b/torch/lib/THD/master_worker/worker/Dispatch.cpp
index 750994b..c0c59ef 100644
--- a/torch/lib/THD/master_worker/worker/Dispatch.cpp
+++ b/torch/lib/THD/master_worker/worker/Dispatch.cpp
@@ -51,6 +51,8 @@
#include "dispatch/Tensor.cpp"
#include "dispatch/TensorMath.cpp"
#include "dispatch/TensorRandom.cpp"
+#include "dispatch/TensorLapack.cpp"
+#include "dispatch/Communication.cpp"
using dispatch_fn = void (*)(rpc::RPCMessage&);
using Functions = thd::Functions;
@@ -82,6 +84,8 @@
{Functions::tensorSelect, tensorSelect},
{Functions::tensorTranspose, tensorTranspose},
{Functions::tensorUnfold, tensorUnfold},
+  {Functions::tensorSqueeze, tensorSqueeze},
+  {Functions::tensorSqueeze1d, tensorSqueeze1d},
{Functions::tensorFree, tensorFree},
{Functions::tensorAdd, tensorAdd},
@@ -231,6 +235,13 @@
{Functions::tensorLogNormal, tensorLogNormal},
{Functions::tensorMultinomial, tensorMultinomial},
+  {Functions::tensorGesv, tensorGesv},
+  {Functions::tensorTrtrs, tensorTrtrs},
+  {Functions::tensorGels, tensorGels},
+  {Functions::tensorSyev, tensorSyev},
+  {Functions::tensorGeev, tensorGeev},
+  {Functions::tensorGesvd2, tensorGesvd2},
+
{Functions::storageNew, storageNew},
{Functions::storageNewWithSize, storageNewWithSize},
{Functions::storageNewWithSize1, storageNewWithSize1},
diff --git a/torch/lib/THD/master_worker/worker/dispatch/TensorLapack.cpp b/torch/lib/THD/master_worker/worker/dispatch/TensorLapack.cpp
new file mode 100644
index 0000000..7f16a55
--- /dev/null
+++ b/torch/lib/THD/master_worker/worker/dispatch/TensorLapack.cpp
@@ -0,0 +1,151 @@
+
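+/* Worker-side handlers for the LAPACK RPCs. Each handler unpacks its
+   tensors and scalar arguments in the order the master packed them,
+   then invokes the matching thpp::Tensor method. */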
+static void tensorGesv(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *rb = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *b = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ finalize(raw_message);
+ rb->gesv(*ra, *b, *a);
+}
+
+static void tensorTrtrs(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *rb = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *b = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ char uplo = unpackInteger(raw_message);
+ char trans = unpackInteger(raw_message);
+  char diag = unpackInteger(raw_message);
+  finalize(raw_message);
+  rb->trtrs(*ra, *b, *a, &uplo, &trans, &diag);
+}
+
+static void tensorGels(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *rb = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *b = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ finalize(raw_message);
+ rb->gels(*ra, *b, *a);
+}
+
+static void tensorSyev(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *re = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *rv = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ char jobz = unpackInteger(raw_message);
+ char uplo = unpackInteger(raw_message);
+ finalize(raw_message);
+ re->syev(*rv, *a, &jobz, &uplo);
+}
+
+static void tensorGeev(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *re = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *rv = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ char jobvr = unpackInteger(raw_message);
+ finalize(raw_message);
+ re->geev(*rv, *a, &jobvr);
+}
+
+static void tensorGesvd2(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *ru = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *rs = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *rv = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ char jobu = unpackInteger(raw_message);
+ finalize(raw_message);
+ ru->gesvd2(*rs, *rv, *ra, *a, &jobu);
+}
+
+static void tensorGetri(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ finalize(raw_message);
+ ra->getri(*a);
+}
+
+static void tensorPotrf(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ char uplo = unpackInteger(raw_message);
+ finalize(raw_message);
+ ra->potrf(*a, &uplo);
+}
+
+static void tensorPotrs(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *rb = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *b = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ char uplo = unpackInteger(raw_message);
+ finalize(raw_message);
+ rb->potrs(*b, *a, &uplo);
+}
+
+static void tensorPotri(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ char uplo = unpackInteger(raw_message);
+ finalize(raw_message);
+ ra->potri(*a, &uplo);
+}
+
+static void tensorQr(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *rq = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *rr = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ finalize(raw_message);
+ rq->qr(*rr, *a);
+}
+
+static void tensorGeqrf(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *rtau = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ finalize(raw_message);
+ ra->geqrf(*rtau, *a);
+}
+
+static void tensorOrgqr(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *tau = unpackRetrieveTensor(raw_message);
+ finalize(raw_message);
+  ra->orgqr(*a, *tau);
+}
+
+static void tensorOrmqr(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *tau = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *c = unpackRetrieveTensor(raw_message);
+ char side = unpackInteger(raw_message);
+ char trans = unpackInteger(raw_message);
+ finalize(raw_message);
+ ra->ormqr(*a, *tau, *c, &side, &trans);
+}
+
+static void tensorPstrf(rpc::RPCMessage& raw_message) {
+ thpp::Tensor *ra = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *rpiv = unpackRetrieveTensor(raw_message);
+ thpp::Tensor *a = unpackRetrieveTensor(raw_message);
+ char uplo = unpackInteger(raw_message);
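+  /* tol is packed as an integer or a float depending on the tensor's
+     scalar type, so peek at the type before unpacking */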
+ thpp::Type type = peekType(raw_message);
+ if (thpp::isInteger(type)) {
+ auto tol = unpackInteger(raw_message);
+ finalize(raw_message);
+ dynamic_cast<thpp::IntTensor*>(ra)->pstrf(*rpiv, *a, &uplo, tol);
+ } else if (thpp::isFloat(type)) {
+ auto tol = unpackFloat(raw_message);
+ finalize(raw_message);
+ dynamic_cast<thpp::FloatTensor*>(ra)->pstrf(*rpiv, *a, &uplo, tol);
+ } else {
+ throw std::runtime_error("expected scalar type");
+ }
+}
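
For context, calling one of these master-side functions is expected to look the
same as the TH equivalents. A minimal sketch of the gesv round-trip (the
newWithSize* constructors are assumed from the TH-style API, not part of this diff):

  // hypothetical usage on the master process
  THDTensor *a  = THDTensor_(newWithSize2d)(3, 3);  // square system matrix
  THDTensor *b  = THDTensor_(newWithSize1d)(3);     // right-hand side vector
  THDTensor *ra = THDTensor_(new)();                // receives the factorization
  THDTensor *rb = THDTensor_(new)();                // receives the solution
  THDTensor_(gesv)(rb, ra, b, a);  // packs and sends Functions::tensorGesv
  // the worker's tensorGesv handler then runs thpp's gesv on its replicas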