[Relanding] Implemented torch.linalg.multi_dot (#52859)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52859
This reverts commit 92a4ee1cf6092dd941591f80885eb7fef5b2c0d8.
Added bfloat16 support for CUDA 11 and removed the fast path for empty input tensors that was affecting the autograd graph.
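
For reference, a minimal usage sketch of the new API (shapes mirror the cost example in the docstring added below; illustrative only, not part of the diff):

    >>> import torch
    >>> A = torch.randn(10, 100)
    >>> B = torch.randn(100, 5)
    >>> C = torch.randn(5, 50)
    >>> torch.linalg.multi_dot([A, B, C]).shape
    torch.Size([10, 50])
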
Test Plan: Imported from OSS
Reviewed By: H-Huang
Differential Revision: D27402390
Pulled By: heitorschueroff
fbshipit-source-id: 73c5ccf54f3da3d29eb63c9ed3601e2fe6951034
diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index b204e7d..1a36659 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -282,6 +282,7 @@
KERNEL(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16)
KERNEL(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), fp16)
KERNEL(ADD_NS(chain_matmul), "chain_matmul", Tensor (TensorList), fp16)
+ KERNEL(ADD_NS(linalg_multi_dot), "linalg_multi_dot", Tensor (TensorList), fp16)
// The macro doesn't like these (I think it chokes on commas inside <>) so write them manually
m.impl(TORCH_SELECTIVE_NAME("aten::_thnn_fused_lstm_cell"),
TORCH_FN((&WrapFunction<CastPolicy::fp16,
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 449dd57..461afa1 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1,32 +1,27 @@
#include <ATen/ATen.h>
-#include <ATen/core/grad_mode.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/LegacyTHFunctionsCPU.h>
#include <ATen/NamedTensorUtils.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Parallel.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/Utils.h>
+#include <ATen/core/grad_mode.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/LinearAlgebraUtils.h>
+#include <ATen/native/ReduceOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
-#include <ATen/NativeFunctions.h>
-#include <ATen/native/LinearAlgebra.h>
-#include <ATen/native/IndexingUtils.h>
-#include <ATen/native/ReduceOps.h>
-#include <ATen/TensorUtils.h>
-#include <ATen/Parallel.h>
-#include <ATen/TensorUtils.h>
-#include <ATen/Utils.h>
#include <c10/util/accumulate.h>
+#include <c10/util/irange.h>
#include <c10/util/variant.h>
-
#include <functional>
#include <limits>
#include <numeric>
-#include <vector>
-
namespace at {
namespace native {
@@ -359,6 +354,269 @@
return at::linalg_matrix_rank(self, c10::nullopt, symmetric);
}
+// multi_dot helper functions
+namespace {
+
+/**
+ * @brief Computes the optimal matrix chain multiplication order
+ *
+ * Follows the dynamic programming algorithm from Cormen et al.,
+ * "Introduction to Algorithms, Third Edition", Chapter 15.2,
+ * p. 370-378. Note that the book uses 1-based indexing.
+ *
+ * The cost of multiplying two matrices with sizes p x q and q x r
+ * is defined here as p * q * r. The optimal multiplication order
+ * is the one that minimizes the total cost.
+ *
+ * @param tensors list of 2D tensors
+ * @return a 2D vector s used by #matrix_chain_multiplication to construct
+ * the optimal matrix multiplication order. The optimal multiplication
+ * order for multiplying tensors i...j is to first multiply tensors i...s[i, j]
+ * and tensors (s[i, j] + 1)...j and then multiply those two results.
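+ *
+ * For example, for matrices of sizes 10 x 100, 100 x 5 and 5 x 50 (the shapes
+ * used in the multi_dot docstring), p = {10, 100, 5, 50} and the algorithm
+ * selects s[0][2] = 1, i.e. (A B) C with cost 10*100*5 + 10*5*50 = 7500
+ * rather than A (B C) with cost 75000.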
+ */
+std::vector<std::vector<int64_t>> matrix_chain_order(TensorList tensors) {
+ const size_t n = tensors.size();
+
+ // Tensor i has dimensions p[i] x p[i + 1]
+ std::vector<int64_t> p(n + 1);
+ for (const auto i : c10::irange(n)) {
+ p[i] = tensors[i].size(0);
+ }
+ p[n] = tensors[n - 1].size(1);
+
+ // m[i, j] is the minimum cost of multiplying tensors i...j
+ std::vector<std::vector<int64_t>> m(n, std::vector<int64_t>(n, 0));
+
+ // s[i, j] = k where k is the index at which to split the list such that
+ // first multiplying matrices i...k and matrices (k + 1)...j and then
+ // multiplying those two results is the optimal order for multiplying matrices i...j.
+ std::vector<std::vector<int64_t>> s(n, std::vector<int64_t>(n));
+
+ // Compute the optimal multiplication order
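+ // using the recurrence
+ //   m[i][j] = min_{i <= k < j} (m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1])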
+ for (const auto l : c10::irange(1, n)) {
+ for (const auto i : c10::irange(n - l)) {
+ const auto j = i + l;
+ m[i][j] = std::numeric_limits<int64_t>::max();
+ for (const auto k : c10::irange(i, j)) {
+ const auto q = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1];
+ if (q < m[i][j]) {
+ m[i][j] = q;
+ s[i][j] = k;
+ }
+ }
+ }
+ }
+
+ return s;
+}
+
+/**
+ * @brief Recursively multiplies the tensors i...j using the given order
+ *
+ * @param tensors matrices to multiply together
+ * @param order optimal chain multiplication order from #matrix_chain_order
+ * @param i index of first tensor to be multiplied
+ * @param j index of last tensor to be multiplied
+ * @return Tensor result of multiplying tensors[i...j] together.
+ */
+Tensor matrix_chain_multiplication(
+ TensorList tensors,
+ const std::vector<std::vector<int64_t>>& order,
+ int64_t i,
+ int64_t j) {
+ if (i == j) {
+ return tensors[i];
+ }
+ return at::mm(
+ matrix_chain_multiplication(tensors, order, i, order[i][j]),
+ matrix_chain_multiplication(tensors, order, order[i][j] + 1, j));
+}
+
+// Implements torch.linalg.multi_dot
+Tensor multi_dot_impl(TensorList _tensors, c10::optional<Tensor> _out) {
+ const size_t n = _tensors.size();
+ TORCH_CHECK(n >= 2, "multi_dot(): expected at least 2 tensors but got ", n);
+
+ std::vector<int64_t> out_shape;
+ std::vector<Tensor> tensors(n);
+
+ // If the first tensor is 1D of size n view it as a row vector (1, n)
+ if (_tensors[0].dim() == 1) {
+ tensors[0] = _tensors[0].unsqueeze(0);
+ } else if (_tensors[0].dim() == 2) {
+ tensors[0] = _tensors[0];
+ out_shape.emplace_back(tensors[0].size(0));
+ } else {
+ TORCH_CHECK(
+ false,
+ "multi_dot(): the first tensor must be 1D or 2D but got ",
+ _tensors[0].dim(),
+ "D");
+ }
+
+ // If the last tensor is 1D of size n view it as a column vector (n, 1)
+ if (_tensors[n - 1].dim() == 1) {
+ tensors[n - 1] = _tensors[n - 1].unsqueeze(-1);
+ } else if (_tensors[n - 1].dim() == 2) {
+ tensors[n - 1] = _tensors[n - 1];
+ out_shape.emplace_back(tensors[n - 1].size(1));
+ } else {
+ TORCH_CHECK(
+ false,
+ "multi_dot(): the last tensor must be 1D or 2D but got ",
+ _tensors[n - 1].dim(),
+ "D");
+ }
+
+ // Ensure middle tensors are 2D
+ for (const auto i : c10::irange(1, n - 1)) {
+ TORCH_CHECK(
+ _tensors[i].dim() == 2,
+ "multi_dot(): tensor ",
+ i,
+ " must be 2D but got ",
+ _tensors[i].dim(),
+ "D");
+ tensors[i] = _tensors[i];
+ }
+
+ // Ensure all tensors have the same device and dtype and check
+ // that the shapes can be multiplied
+ const auto dtype = tensors[0].dtype();
+ const auto device = tensors[0].device();
+ for (const auto i : c10::irange(1, n)) {
+ TORCH_CHECK(
+ tensors[i].dtype() == dtype,
+ "multi_dot(): all tensors must have be the same dtype but tensor 0 is ",
+ dtype,
+ " and tensor ",
+ i,
+ " ",
+ tensors[i].dtype());
+ TORCH_CHECK(
+ tensors[i].device() == device,
+ "multi_dot(): all tensors must be on the same device but tensor 0 is on ",
+ device,
+ " and tensor ",
+ i,
+ " on ",
+ tensors[i].device());
+ TORCH_CHECK(
+ tensors[i - 1].size(-1) == tensors[i].size(0),
+ "multi_dot(): tensors ",
+ i - 1,
+ " and ",
+ i,
+ " with shapes ",
+ _tensors[i - 1].sizes(),
+ " and ",
+ _tensors[i].sizes(),
+ " cannot be multiplied")
+ }
+
+ Tensor result;
+
+ if (_out.has_value()) {
+ auto out = *_out;
+ TORCH_CHECK(
+ dtype == out.dtype(),
+ "multi_dot(): expected out tensor to have dtype ",
+ dtype,
+ " but got ",
+ out.dtype());
+ TORCH_CHECK(
+ device == out.device(),
+ "multi_dot(): expected out tensor to be on device ",
+ device,
+ " but got ",
+ out.device());
+
+ // If the first and last tensors have shapes (a, b) and (b, c) the
+ // output has shape (a, c). If either the first or last tensor is 1D
+ // then the a and/or c dimension is implicitly of size 1 and is omitted
+ // from the output, e.g. for inputs (a, b) x (b) the output has shape (a,).
+ at::native::resize_output(out, out_shape);
+
+ // View output as 2D for simplicity of computation.
+ result = out.view({tensors[0].size(0), tensors.back().size(-1)});
+ }
+
+ // The resize_output and view calls ensure that the output shape
+ // respects the original dimensionality of the first and last
+ // tensors, which were viewed as 2D above
+
+ if (tensors.size() == 2) {
+ return _out.has_value() ? at::mm_out(result, tensors[0], tensors[1])
+ : at::mm(tensors[0], tensors[1]).view(out_shape);
+ }
+
+ // Why the separate implementation for 3 matrices?
+ // The logic for three matrices is much faster when done directly
+ // Requires 1 comparison instead of 4 and fewer arithmetic operations
+ if (tensors.size() == 3) {
+ const auto a = tensors[0].size(0);
+ const auto b = tensors[1].size(0);
+ const auto c = tensors[2].size(0);
+ const auto d = tensors[2].size(1);
+
+ // The matrices are of size (a x b), (b x c), (c x d)
+ // cost_1 is the cost of parenthesizing (a x b) and (b x c) and then
+ // combining (c x d); cost_2 is the cost of parenthesizing (b x c) and
+ // (c x d) and then combining (a x b)
+ const auto cost_1 = (a * c) * (b + d);
+ const auto cost_2 = (b * d) * (a + c);
+
+ if (cost_1 > cost_2) {
+ return _out.has_value()
+ ? at::mm_out(result, tensors[0], at::mm(tensors[1], tensors[2]))
+ : at::mm(tensors[0], at::mm(tensors[1], tensors[2])).view(out_shape);
+ } else {
+ return _out.has_value()
+ ? at::mm_out(result, at::mm(tensors[0], tensors[1]), tensors[2])
+ : at::mm(at::mm(tensors[0], tensors[1]), tensors[2]).view(out_shape);
+ }
+ }
+
+ // Algorithm for multiplying 4 or more matrices
+ const auto order = matrix_chain_order(tensors);
+ const int64_t i = 0;
+ const int64_t j = n - 1;
+
+ if (_out.has_value()) {
+ // We manually implement the first recursive layer here so we can use mm_out
+ // for the final multiplication
+ return at::mm_out(
+ result,
+ matrix_chain_multiplication(tensors, order, i, order[i][j]),
+ matrix_chain_multiplication(tensors, order, order[i][j] + 1, j));
+ }
+ return matrix_chain_multiplication(tensors, order, i, j).view(out_shape);
+}
+
+} // namespace
+
+Tensor linalg_multi_dot(TensorList tensors) {
+ return multi_dot_impl(tensors, c10::nullopt);
+}
+
+Tensor& linalg_multi_dot_out(TensorList tensors, Tensor& result) {
+ multi_dot_impl(tensors, result);
+ return result;
+}
+
+Tensor chain_matmul(TensorList matrices) {
+ checkAllSameDim(matrices, 2);
+
+ TORCH_CHECK(
+ matrices.size() > 0, "chain_matmul(): Expected one or more matrices");
+
+ if (matrices.size() == 1) {
+ return matrices[0].clone();
+ }
+
+ return at::native::linalg_multi_dot(matrices);
+}
+
static void check_1d(const Tensor& t, const char* arg, const char* fn) {
TORCH_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D");
}
@@ -2236,93 +2494,6 @@
return result;
}
-static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
- if (i == j)
- return matrices[i];
- else
- return at::mm(_chain_matmul_general(matrices, order, i, order[i][j]), _chain_matmul_general(matrices, order, order[i][j] + 1, j));
-}
-
-// Why the separate implementation for 3 matrices?
-// The logic for three matrices is much faster when done directly
-// Requires 1 comparison to 4 comparisons and lesser arithmetic operations
-static inline Tensor _chain_matmul_three_matrices(TensorList matrices) {
- int64_t a = matrices[0].size(0); // This is the first dimension
- int64_t b = matrices[1].size(0); // This is the common dimension between the first two matrices
- int64_t c = matrices[2].size(0); // This is the common dimension between the last two matrices
- int64_t d = matrices[2].size(1); // This is the last dimension
-
- // The matrices are of size (a x b), (b x c), (c x d)
- // cost_1 is the cost of parenthesizing (a x b) and (b x c) and then combining (c x d)
- // cost_2 is the cost of parenthesizing (b x c) and (c x d) and then combining (a x b)
- int64_t cost_1 = (a * c) * (b + d);
- int64_t cost_2 = (b * d) * (a + c);
-
- if (cost_1 > cost_2) {
- return at::mm(matrices[0], at::mm(matrices[1], matrices[2]));
- } else {
- return at::mm(at::mm(matrices[0], matrices[1]), matrices[2]);
- }
-}
-
-Tensor chain_matmul(TensorList matrices) {
- checkAllSameDim(matrices, 2);
-
- TORCH_CHECK(matrices.size() > 0, "chain_matmul: Expected one or more matrices");
- if (matrices.size() == 1) {
- return matrices[0];
- } else if (matrices.size() == 2) {
- return at::mm(matrices[0], matrices[1]);
- } else if (matrices.size() == 3) {
- return _chain_matmul_three_matrices(matrices);
- } else {
-
- // Following the algorithm in Chapter 15.2 : Introduction to Algorithms, Cormen et al.
- // Minor modifications have been made to accommodate zero-indexing
- auto n = matrices.size();
-
- // Dim vector - the length of which is n + 1. Note that for matrix multiplication, there
- // needs to a common dimension between the multiplicands, hence for n matrices, there are
- // n + 1 values. The values p_{i} and p_{i + 1} correspond to the dimensions of matrix i in
- // the chain (zero-indexed)
- std::vector<int64_t> p;
- p.push_back(matrices[0].size(0));
- for (size_t i = 0; i < n; i++) {
- p.push_back(matrices[i].size(1));
- }
-
- // Cost matrix - an element m[i, j] of this matrix corresponds to the minimum cost of
- // parenthesizing matrices A_{i} to A_{j}. By this definition m[i, i] = 0 for all i
- // m[i, j] is filled using the substructure property of the algorithm, meaning:
- // m[i, j] = min_{i <= k < j} m[i, k] + m[k, j] + p_{i-1}p_{k}p_{j}
- std::vector<std::vector<int64_t>> m(n, std::vector<int64_t>(n, 0));
-
- // Auxiliary table for constructing the order
- // s[i, j] stores the index k at which the optimal split is obtained
- std::vector<std::vector<int64_t>> s(n, std::vector<int64_t>(n));
-
- // j and q are used repetitively in the algorithm below
- int64_t j, q;
-
- for (int64_t l = 1; l < n; l++) {
- for (int64_t i = 0; i < n - l; i++) {
- j = i + l;
- m[i][j] = std::numeric_limits<int64_t>::max();
- for (int64_t k = i; k < j; k++) {
- q = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1];
- if (q < m[i][j]) {
- m[i][j] = q;
- s[i][j] = k;
- }
- }
- }
- }
-
- // We use the result from the algorithm to compute the matrix chain product via recursion
- return _chain_matmul_general(matrices, s, 0, n - 1);
- }
-}
-
/*
Calculates the Kronecker product between two Tensors.
*/
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 5d7a3b1..d945a8d 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8759,6 +8759,12 @@
python_module: linalg
variants: function
+- func: linalg_multi_dot(Tensor[] tensors) -> Tensor
+ python_module: linalg
+
+- func: linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+ python_module: linalg
+
## Functions that are only for testing
# It is undocumented and should not be used outside of tests.
- func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
diff --git a/docs/source/amp.rst b/docs/source/amp.rst
index 865bfe8..bf93d95 100644
--- a/docs/source/amp.rst
+++ b/docs/source/amp.rst
@@ -104,6 +104,7 @@
``baddbmm``,
``bmm``,
``chain_matmul``,
+``multi_dot``,
``conv1d``,
``conv2d``,
``conv3d``,
diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst
index e723b2f..3558441 100644
--- a/docs/source/linalg.rst
+++ b/docs/source/linalg.rst
@@ -24,6 +24,7 @@
.. autofunction:: eigvalsh
.. autofunction:: matrix_power
.. autofunction:: matrix_rank
+.. autofunction:: multi_dot
.. autofunction:: norm
.. autofunction:: vector_norm
.. autofunction:: pinv
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 0d397f5..7328097 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -2680,6 +2680,12 @@
self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn)
@unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
+ def test_autocast_linalg_fp16(self):
+ with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+ for op, args in self.autocast_lists.linalg_fp16:
+ self._run_autocast_outofplace(op, args, torch.float16, module=torch._C._linalg)
+
+ @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
def test_autocast_methods_fp16(self):
with torch.backends.cudnn.flags(enabled=True, deterministic=True):
for op, args in self.autocast_lists.methods_fp16:
diff --git a/test/test_fx.py b/test/test_fx.py
index d978ae1..771d857 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -2177,7 +2177,7 @@
@onlyCPU
@ops(op_db, allowed_dtypes=(torch.float,))
def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
- known_no_schema = {'stack', 'hstack', 'vstack', 'dstack', 'repeat', '__getitem__'}
+ known_no_schema = {'stack', 'hstack', 'vstack', 'dstack', 'repeat', '__getitem__', 'linalg.multi_dot'}
try:
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 0b82703..439a454 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -3659,6 +3659,71 @@
self.assertEqual(torch.matrix_rank(aaT, True), np.linalg.matrix_rank(aaT.cpu().numpy(), True))
self.assertEqual(torch.matrix_rank(aaT, 0.01, True), np.linalg.matrix_rank(aaT.cpu().numpy(), 0.01, True))
+ @onlyOnCPUAndCUDA
+ @dtypes(torch.double, torch.cdouble)
+ def test_multi_dot(self, device, dtype):
+ def check(*shapes, discontiguous=False):
+ tensors = [make_tensor(shape, device, dtype, discontiguous=discontiguous) for shape in shapes]
+ np_arrays = [tensor.cpu().numpy() for tensor in tensors]
+ res = torch.linalg.multi_dot(tensors).cpu()
+ ref = torch.from_numpy(np.array(np.linalg.multi_dot(np_arrays)))
+ self.assertEqual(res, ref)
+
+ # test for inputs with empty dimensions
+ check([0], [0])
+ check([2], [2, 0])
+ check([1, 0], [0])
+ check([0, 2], [2, 1])
+ check([2, 2], [2, 0])
+ check([2, 0], [0, 3])
+ check([0, 0], [0, 1])
+ check([4, 2], [2, 0], [0, 3], [3, 2])
+
+ # test variable output shapes
+ check([2], [2])
+ check([1, 2], [2])
+ check([2], [2, 1])
+ check([1, 2], [2, 1])
+ check([3, 2], [2, 4])
+
+ # test multiple input tensors
+ check([3], [3, 4], [4, 2], [2, 5], [5])
+ check([1, 2], [2, 2], [2, 3], [3, 1])
+
+ # test large tensors
+ check([10, 100], [100, 5], [5, 50])
+
+ # test discontiguous input
+ check([3, 2], [2, 2], [2, 3], [3, 4], discontiguous=True)
+
+ @onlyOnCPUAndCUDA
+ @dtypes(torch.float)
+ def test_multi_dot_errors(self, device, dtype):
+ def check(tensors, out, msg):
+ with self.assertRaisesRegex(RuntimeError, msg):
+ torch.linalg.multi_dot(tensors, out=out)
+
+ a = make_tensor(2, device, dtype)
+
+ check([], None, "expected at least 2 tensors")
+ check([a], None, "expected at least 2 tensors")
+
+ check([torch.tensor(1, device=device, dtype=dtype), a], None, "the first tensor must be 1D or 2D")
+ check([a, torch.tensor(1, device=device, dtype=dtype)], None, "the last tensor must be 1D or 2D")
+
+ check([a, a, a], None, "tensor 1 must be 2D")
+ check([a, make_tensor((2, 2, 2), device, dtype), a], None, "tensor 1 must be 2D")
+
+ check([a, make_tensor(2, device, torch.double)], None, "all tensors must have the same dtype")
+ check([a, a], torch.empty(0, device=device, dtype=torch.double), "expected out tensor to have dtype")
+
+ if self.device_type == 'cuda':
+ check([a, make_tensor(2, 'cpu', dtype)], None, "all tensors must be on the same device")
+ check([a, a], torch.empty(0, dtype=dtype), "expected out tensor to be on device")
+
+ check([a, make_tensor(3, device, dtype)], None, "cannot be multiplied")
+ check([a, make_tensor((3, 2), device, dtype), a], None, "cannot be multiplied")
+
@precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6})
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@@ -6139,7 +6204,7 @@
run_test([10, 20, 30, 5])
run_test([15, 5, 10, 20, 25])
- with self.assertRaisesRegex(RuntimeError, "chain_matmul: Expected one or more matrices"):
+ with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
torch.chain_matmul()
@skipCUDAIfNoMagma
diff --git a/tools/autograd/templates/python_linalg_functions.cpp b/tools/autograd/templates/python_linalg_functions.cpp
index d361e74..52eeca2 100644
--- a/tools/autograd/templates/python_linalg_functions.cpp
+++ b/tools/autograd/templates/python_linalg_functions.cpp
@@ -17,6 +17,7 @@
using at::MemoryFormat;
using at::Generator;
using at::IntArrayRef;
+using at::TensorList;
using namespace torch::autograd::utils;
diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h
index 4f0962c..73e7237 100644
--- a/torch/csrc/api/include/torch/linalg.h
+++ b/torch/csrc/api/include/torch/linalg.h
@@ -96,6 +96,14 @@
return torch::linalg_matrix_rank_out(result, input, tol, hermitian);
}
+inline Tensor multi_dot(TensorList tensors) {
+ return torch::linalg_multi_dot(tensors);
+}
+
+inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) {
+ return torch::linalg_multi_dot_out(result, tensors);
+}
+
inline Tensor pinv(const Tensor& input, double rcond, bool hermitian) {
return torch::linalg_pinv(input, rcond, hermitian);
}
@@ -245,6 +253,15 @@
return detail::matrix_rank_out(result, input, tol, hermitian);
}
+/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.multi_dot
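+///
+/// Example (illustrative):
+/// ```
+/// auto A = torch::randn({3, 4});
+/// auto B = torch::randn({4, 5});
+/// auto C = torch::linalg::multi_dot({A, B});
+/// ```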
+inline Tensor multi_dot(TensorList tensors) {
+ return torch::linalg_multi_dot(tensors);
+}
+
+inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) {
+ return torch::linalg_multi_dot_out(result, tensors);
+}
+
/// Computes pseudo-inverse
///
/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.pinv
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 383d538..47855fa 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -696,6 +696,72 @@
tensor(5.4345)
""")
+multi_dot = _add_docstr(_linalg.linalg_multi_dot, r"""
+linalg.multi_dot(tensors, *, out=None)
+
+Efficiently multiplies two or more matrices given by :attr:`tensors` by ordering the
+multiplications so that the fewest arithmetic operations are performed.
+
+Every tensor in :attr:`tensors` must be 2D, except for the first and last which
+may be 1D. If the first tensor is a 1D vector of size `n` it is treated as a row vector
+of size `(1, n)`. Similarly, if the last tensor is a 1D vector of size `n` it is treated
+as a column vector of size `(n, 1)`.
+
+If the first tensor has size `(a, b)` and the last tensor has size `(c, d)` the
+output will have size `(a, d)`. However, if either tensor is 1D then the implied
+dimension of size `1` as described above is squeezed from the output. For example, for
+tensors of size `(b,)` and `(c, d)` the output will have size `(d,)`.
+
+.. warning:: This function does not broadcast.
+
+.. note:: This function is implemented by chaining :func:`torch.mm` calls after
+ computing the optimal matrix multiplication order.
+
+.. note:: This function is similar to NumPy's `multi_dot` except that the first and last
+ tensors must be either 1D or 2D whereas NumPy allows them to be nD.
+
+.. note:: The cost of multiplying two matrices with shapes `(a, b)` and `(b, c)` is
+ `a * b * c`. Given matrices `A`, `B` and `C` with shapes `(10, 100)`,
+ `(100, 5)` and `(5, 50)` respectively, we can calculate the cost of different
+ multiplication orders as follows:
+
+ .. math::
+
+ \text{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 5000 + 2500 = 7500 \\
+ \text{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 50000 + 25000 = 75000
+
+ In this case, multiplying A and B first followed by C is 10 times faster.
+
+Args:
+ tensors (sequence of Tensors): two or more tensors to multiply. The first and last
+ tensors may be 1D or 2D. Every other tensor must be 2D.
+
+Keyword args:
+ out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None``
+
+Examples::
+
+ >>> from torch.linalg import multi_dot
+
+ >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])])
+ tensor(8)
+ >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])])
+ tensor([8])
+ >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])])
+ tensor([[8]])
+
+ >>> a = torch.arange(2 * 3).view(2, 3)
+ >>> b = torch.arange(3 * 2).view(3, 2)
+ >>> c = torch.arange(2 * 2).view(2, 2)
+ >>> multi_dot((a, b, c))
+ tensor([[ 26, 49],
+ [ 80, 148]])
+
+ >>> multi_dot((a.to(torch.float), torch.empty(3, 0), torch.empty(0, 2)))
+ tensor([[0., 0.],
+ [0., 0.]])
+""")
+
norm = _add_docstr(_linalg.linalg_norm, r"""
linalg.norm(input, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index cda37e7..17004f5 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -535,6 +535,7 @@
torch.linalg.matrix_power: lambda input, n, out=None: -1,
torch.matrix_rank: lambda input, tol=None, symmetric=False: -1,
torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
+ torch.linalg.multi_dot: lambda tensors, out=None: -1,
torch.matrix_exp: lambda input: -1,
torch.max: lambda input, out=None: -1,
torch.maximum: lambda input, other, out=None: -1,
diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py
index f48de4f..1c84a3d 100644
--- a/torch/testing/_internal/autocast_test_lists.py
+++ b/torch/testing/_internal/autocast_test_lists.py
@@ -221,6 +221,9 @@
("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
]
+ self.linalg_fp16 = [
+ ("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
+ ]
self.methods_fp16 = [
("__matmul__", mat0_fp32 + mat1_fp32)
]
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 5803a17..fadb42f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -405,6 +405,29 @@
return inputs
+def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad):
+ # Each test case consists of the sizes in the chain of multiplications
+ # e.g. [2, 3, 4, 5] generates matrices (2, 3) @ (3, 4) @ (4, 5)
+ test_cases = [
+ [1, 2, 1],
+ [2, 0, 2],
+ [0, 2, 2],
+ [2, 2, 2, 2],
+ [2, 3, 4, 5],
+ [5, 4, 0, 2],
+ [2, 4, 3, 5, 3, 2]
+ ]
+
+ result = []
+ for sizes in test_cases:
+ tensors = []
+ for size in zip(sizes[:-1], sizes[1:]):
+ t = make_tensor(size, device, dtype, requires_grad=requires_grad)
+ tensors.append(t)
+ result.append(SampleInput(tensors))
+
+ return result
+
def sample_inputs_linalg_norm(op_info, device, dtype, requires_grad):
test_sizes = [
(S,),
@@ -2822,6 +2845,19 @@
supports_inplace_autograd=False,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, skipCUDAIfRocm],
sample_inputs_func=sample_inputs_linalg_matrix_power,),
+ OpInfo('linalg.multi_dot',
+ # Need this lambda because gradcheck does not work with TensorList inputs
+ aten_name='linalg_multi_dot',
+ dtypes=floating_and_complex_types_and(torch.half),
+ dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
+ dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16] if CUDA11OrLater else []),
+ supports_inplace_autograd=False,
+ # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407)
+ check_batched_grad=False,
+ check_batched_gradgrad=False,
+ sample_inputs_func=sample_inputs_linalg_multi_dot,
+ # test_variant_consistency_jit does not work with TensorList inputs
+ skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit'),)),
OpInfo('linalg.norm',
op=torch.linalg.norm,
dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),