[Relanding] Implemented torch.linalg.multi_dot (#52859)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52859

This reverts commit 92a4ee1cf6092dd941591f80885eb7fef5b2c0d8.

Added bfloat16 support for CUDA 11 and removed the fast path for empty input tensors that was affecting the autograd graph.

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D27402390

Pulled By: heitorschueroff

fbshipit-source-id: 73c5ccf54f3da3d29eb63c9ed3601e2fe6951034
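
Illustrative usage of the new API (a sketch only, not part of this patch; assumes a build that includes this change). The shapes follow the cost example in the torch.linalg.multi_dot docstring added below:

    import torch

    A = torch.randn(10, 100)
    B = torch.randn(100, 5)
    C = torch.randn(5, 50)

    # multi_dot picks the cheapest multiplication order:
    # (A @ B) @ C costs 10*100*5 + 10*5*50 = 7500 multiply-adds,
    # A @ (B @ C) costs 10*100*50 + 100*5*50 = 75000.
    out = torch.linalg.multi_dot([A, B, C])
    assert out.shape == (10, 50)

    # The first and last operands may be 1D; the implied size-1 dimensions are
    # squeezed from the output, so a vector-...-vector chain yields a 0-dim scalar.
    v = torch.randn(10)
    w = torch.randn(50)
    assert torch.linalg.multi_dot([v, A, B, C, w]).dim() == 0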
diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index b204e7d..1a36659 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -282,6 +282,7 @@
   KERNEL(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16)
   KERNEL(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), fp16)
   KERNEL(ADD_NS(chain_matmul), "chain_matmul", Tensor (TensorList), fp16)
+  KERNEL(ADD_NS(linalg_multi_dot), "linalg_multi_dot", Tensor (TensorList), fp16)
   // The macro doesn't like these (I think it chokes on commas inside <>) so write them manually
   m.impl(TORCH_SELECTIVE_NAME("aten::_thnn_fused_lstm_cell"),
          TORCH_FN((&WrapFunction<CastPolicy::fp16,
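The KERNEL line above registers aten::linalg_multi_dot under the fp16 cast policy, alongside chain_matmul. A sketch of the user-visible effect (illustrative only, not part of this patch; assumes a CUDA device is available):

    import torch

    a = torch.randn(8, 16, device='cuda')   # fp32 inputs
    b = torch.randn(16, 4, device='cuda')
    c = torch.randn(4, 8, device='cuda')

    # Under autocast, the fp16 cast policy casts the inputs before the matmuls,
    # so the result comes back in half precision.
    with torch.cuda.amp.autocast():
        out = torch.linalg.multi_dot([a, b, c])

    assert out.dtype == torch.float16
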
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 449dd57..461afa1 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1,32 +1,27 @@
 #include <ATen/ATen.h>
-#include <ATen/core/grad_mode.h>
 #include <ATen/Dispatch.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/LegacyTHFunctionsCPU.h>
 #include <ATen/NamedTensorUtils.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Parallel.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/Utils.h>
+#include <ATen/core/grad_mode.h>
 #include <ATen/native/CPUBlas.h>
 #include <ATen/native/IndexingUtils.h>
 #include <ATen/native/LinearAlgebra.h>
 #include <ATen/native/LinearAlgebraUtils.h>
+#include <ATen/native/ReduceOps.h>
 #include <ATen/native/ReduceOpsUtils.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/TensorIterator.h>
-#include <ATen/NativeFunctions.h>
-#include <ATen/native/LinearAlgebra.h>
-#include <ATen/native/IndexingUtils.h>
-#include <ATen/native/ReduceOps.h>
-#include <ATen/TensorUtils.h>
-#include <ATen/Parallel.h>
-#include <ATen/TensorUtils.h>
-#include <ATen/Utils.h>
 #include <c10/util/accumulate.h>
+#include <c10/util/irange.h>
 #include <c10/util/variant.h>
-
 #include <functional>
 #include <limits>
 #include <numeric>
-#include <vector>
-
 
 namespace at {
 namespace native {
@@ -359,6 +354,269 @@
   return at::linalg_matrix_rank(self, c10::nullopt, symmetric);
 }
 
+// multi_dot helper functions
+namespace {
+
+/**
+ * @brief Computes the optimal matrix chain multiplication order
+ *
+ * Follows the dynamic programming algorithm from Cormen et al.,
+ * "Introduction to Algorithms, Third Edition", Chapter 15.2,
+ * pp. 370-378. Note that the book uses 1-based indexing.
+ *
+ * The cost of multiplying two matrices with sizes p x q and q x r
+ * is defined here as p * q * r. The optimal multiplication order
+ * is the one that minimizes the total cost.
+ *
+ * @param tensors list of 2D tensors
+ * @return a 2D vector s used by #matrix_chain_multiplication to construct
+ *         the optimal matrix multiplication order. The optimal multiplication
+ *         order for multiplying tensors i...j is to multiply tensors i...s[i, j]
+ *         and tensors (s[i, j] + 1)...j first and then the result of that.
+ */
+std::vector<std::vector<int64_t>> matrix_chain_order(TensorList tensors) {
+  const size_t n = tensors.size();
+
+  // Tensor i has dimensions p[i] x p[i + 1]
+  std::vector<int64_t> p(n + 1);
+  for (const auto i : c10::irange(n)) {
+    p[i] = tensors[i].size(0);
+  }
+  p[n] = tensors[n - 1].size(1);
+
+  // m[i, j] = k where k is the minimum cost for multiplying tensors i...j
+  std::vector<std::vector<int64_t>> m(n, std::vector<int64_t>(n, 0));
+
+  // s[i, j] = k where k is the index at which to split the list such that
+  // optimally multiplying matrices i...k and (k + 1)...j first and then multiplying
+  // the two results is the optimal order for multiplying matrices i...j.
+  std::vector<std::vector<int64_t>> s(n, std::vector<int64_t>(n));
+
+  // Compute the optimal multiplication order
+  for (const auto l : c10::irange(1, n)) {
+    for (const auto i : c10::irange(n - l)) {
+      const auto j = i + l;
+      m[i][j] = std::numeric_limits<int64_t>::max();
+      for (const auto k : c10::irange(i, j)) {
+        const auto q = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1];
+        if (q < m[i][j]) {
+          m[i][j] = q;
+          s[i][j] = k;
+        }
+      }
+    }
+  }
+
+  return s;
+}
+
+/**
+ * @brief Recursively multiplies the tensors i...j using the given order
+ *
+ * @param tensors matrices to multiply together
+ * @param order optimal chain multiplication order from #matrix_chain_order
+ * @param i index of first tensor to be multiplied
+ * @param j index of last tensor to be multiplied
+ * @return Tensor result of multiplying tensors[i...j] together.
+ */
+Tensor matrix_chain_multiplication(
+    TensorList tensors,
+    const std::vector<std::vector<int64_t>>& order,
+    int64_t i,
+    int64_t j) {
+  if (i == j) {
+    return tensors[i];
+  }
+  return at::mm(
+      matrix_chain_multiplication(tensors, order, i, order[i][j]),
+      matrix_chain_multiplication(tensors, order, order[i][j] + 1, j));
+}
+
+// Implements torch.linalg.multi_dot
+Tensor multi_dot_impl(TensorList _tensors, c10::optional<Tensor> _out) {
+  const size_t n = _tensors.size();
+  TORCH_CHECK(n >= 2, "multi_dot(): expected at least 2 tensors but got ", n);
+
+  std::vector<int64_t> out_shape;
+  std::vector<Tensor> tensors(n);
+
+  // If the first tensor is 1D of size n view it as a row vector (1, n)
+  if (_tensors[0].dim() == 1) {
+    tensors[0] = _tensors[0].unsqueeze(0);
+  } else if (_tensors[0].dim() == 2) {
+    tensors[0] = _tensors[0];
+    out_shape.emplace_back(tensors[0].size(0));
+  } else {
+    TORCH_CHECK(
+        false,
+        "multi_dot(): the first tensor must be 1D or 2D but got ",
+        _tensors[0].dim(),
+        "D");
+  }
+
+  // If the last tensor is 1D of size n view it as a column vector (n, 1)
+  if (_tensors[n - 1].dim() == 1) {
+    tensors[n - 1] = _tensors[n - 1].unsqueeze(-1);
+  } else if (_tensors[n - 1].dim() == 2) {
+    tensors[n - 1] = _tensors[n - 1];
+    out_shape.emplace_back(tensors[n - 1].size(1));
+  } else {
+    TORCH_CHECK(
+        false,
+        "multi_dot(): the last tensor must be 1D or 2D but got ",
+        _tensors[n - 1].dim(),
+        "D");
+  }
+
+  // Ensure middle tensors are 2D
+  for (const auto i : c10::irange(1, n - 1)) {
+    TORCH_CHECK(
+        _tensors[i].dim() == 2,
+        "multi_dot(): tensor ",
+        i,
+        " must be 2D but got ",
+        _tensors[i].dim(),
+        "D");
+    tensors[i] = _tensors[i];
+  }
+
+  // Ensure all tensors have the same device and dtype and check
+  // that the shapes can be multiplied
+  const auto dtype = tensors[0].dtype();
+  const auto device = tensors[0].device();
+  for (const auto i : c10::irange(1, n)) {
+    TORCH_CHECK(
+        tensors[i].dtype() == dtype,
+        "multi_dot(): all tensors must have the same dtype but tensor 0 is ",
+        dtype,
+        " and tensor ",
+        i,
+        " is ",
+        tensors[i].dtype());
+    TORCH_CHECK(
+        tensors[i].device() == device,
+        "multi_dot(): all tensors must be on the same device but tensor 0 is on ",
+        device,
+        " and tensor ",
+        i,
+        " on ",
+        tensors[i].device());
+    TORCH_CHECK(
+        tensors[i - 1].size(-1) == tensors[i].size(0),
+        "multi_dot(): tensors ",
+        i - 1,
+        " and ",
+        i,
+        " with shapes ",
+        _tensors[i - 1].sizes(),
+        " and ",
+        _tensors[i].sizes(),
+        " cannot be multiplied")
+  }
+
+  Tensor result;
+
+  if (_out.has_value()) {
+    auto out = *_out;
+    TORCH_CHECK(
+        dtype == out.dtype(),
+        "multi_dot(): expected out tensor to have dtype ",
+        dtype,
+        " but got ",
+        out.dtype());
+    TORCH_CHECK(
+        device == out.device(),
+        "multi_dot(): expected out tensor to be on device ",
+        device,
+        " but got ",
+        out.device());
+
+    // If the first and last tensors have shapes (a, b) and (b, c), the
+    // output has shape (a, c). If either the first or last tensor is 1D,
+    // the a and/or c dimension is implicitly of size 1 and is omitted
+    // from the output, e.g. for inputs (a, b) x (b) the output has shape (a,).
+    at::native::resize_output(out, out_shape);
+
+    // View output as 2D for simplicity of computation.
+    result = out.view({tensors[0].size(0), tensors.back().size(-1)});
+  }
+
+  // The resize_ and view calls below are to ensure the
+  // output shape respects the original dimensionality of
+  // the first and last tensors, which are now viewed as 2D
+
+  if (tensors.size() == 2) {
+    return _out.has_value() ? at::mm_out(result, tensors[0], tensors[1])
+                         : at::mm(tensors[0], tensors[1]).view(out_shape);
+  }
+
+  // Why the separate implementation for 3 matrices?
+  // The logic for three matrices is much faster when done directly
+  // It requires 1 comparison instead of 4 and fewer arithmetic operations.
+  if (tensors.size() == 3) {
+    const auto a = tensors[0].size(0);
+    const auto b = tensors[1].size(0);
+    const auto c = tensors[2].size(0);
+    const auto d = tensors[2].size(1);
+
+    // The matrices are of size (a x b), (b x c), (c x d).
+    // cost_1 is the cost of parenthesizing (a x b) and (b x c) first and then
+    // combining with (c x d). cost_2 is the cost of parenthesizing (b x c) and
+    // (c x d) first and then combining with (a x b).
+    const auto cost_1 = (a * c) * (b + d);
+    const auto cost_2 = (b * d) * (a + c);
+
+    if (cost_1 > cost_2) {
+      return _out.has_value()
+          ? at::mm_out(result, tensors[0], at::mm(tensors[1], tensors[2]))
+          : at::mm(tensors[0], at::mm(tensors[1], tensors[2])).view(out_shape);
+    } else {
+      return _out.has_value()
+          ? at::mm_out(result, at::mm(tensors[0], tensors[1]), tensors[2])
+          : at::mm(at::mm(tensors[0], tensors[1]), tensors[2]).view(out_shape);
+    }
+  }
+
+  // Algorithm for multiplying 4 or more matrices
+  const auto order = matrix_chain_order(tensors);
+  const int64_t i = 0;
+  const int64_t j = n - 1;
+
+  if (_out.has_value()) {
+    // We manually implement the first recursive layer here so we can use mm_out
+    // for the final multiplication
+    return at::mm_out(
+        result,
+        matrix_chain_multiplication(tensors, order, i, order[i][j]),
+        matrix_chain_multiplication(tensors, order, order[i][j] + 1, j));
+  }
+  return matrix_chain_multiplication(tensors, order, i, j).view(out_shape);
+}
+
+} // namespace
+
+Tensor linalg_multi_dot(TensorList tensors) {
+  return multi_dot_impl(tensors, c10::nullopt);
+}
+
+Tensor& linalg_multi_dot_out(TensorList tensors, Tensor& result) {
+  multi_dot_impl(tensors, result);
+  return result;
+}
+
+Tensor chain_matmul(TensorList matrices) {
+  checkAllSameDim(matrices, 2);
+
+  TORCH_CHECK(
+      matrices.size() > 0, "chain_matmul(): Expected one or more matrices");
+
+  if (matrices.size() == 1) {
+    return matrices[0].clone();
+  }
+
+  return at::native::linalg_multi_dot(matrices);
+}
+
 static void check_1d(const Tensor& t, const char* arg, const char* fn) {
  TORCH_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D");
 }
@@ -2236,93 +2494,6 @@
   return result;
 }
 
-static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
-  if (i == j)
-    return matrices[i];
-  else
-    return at::mm(_chain_matmul_general(matrices, order, i, order[i][j]), _chain_matmul_general(matrices, order, order[i][j] + 1, j));
-}
-
-// Why the separate implementation for 3 matrices?
-// The logic for three matrices is much faster when done directly
-// Requires 1 comparison to 4 comparisons and lesser arithmetic operations
-static inline Tensor _chain_matmul_three_matrices(TensorList matrices) {
-  int64_t a = matrices[0].size(0);  // This is the first dimension
-  int64_t b = matrices[1].size(0);  // This is the common dimension between the first two matrices
-  int64_t c = matrices[2].size(0);  // This is the common dimension between the last two matrices
-  int64_t d = matrices[2].size(1);  // This is the last dimension
-
-  // The matrices are of size (a x b), (b x c), (c x d)
-  // cost_1 is the cost of parenthesizing (a x b) and (b x c) and then combining (c x d)
-  // cost_2 is the cost of parenthesizing (b x c) and (c x d) and then combining (a x b)
-  int64_t cost_1 = (a * c) * (b + d);
-  int64_t cost_2 = (b * d) * (a + c);
-
-  if (cost_1 > cost_2) {
-    return at::mm(matrices[0], at::mm(matrices[1], matrices[2]));
-  } else {
-    return at::mm(at::mm(matrices[0], matrices[1]), matrices[2]);
-  }
-}
-
-Tensor chain_matmul(TensorList matrices) {
-  checkAllSameDim(matrices, 2);
-
-  TORCH_CHECK(matrices.size() > 0, "chain_matmul: Expected one or more matrices");
-  if (matrices.size() == 1) {
-    return matrices[0];
-  } else if (matrices.size() == 2) {
-    return at::mm(matrices[0], matrices[1]);
-  } else if (matrices.size() == 3) {
-    return _chain_matmul_three_matrices(matrices);
-  } else {
-
-    // Following the algorithm in Chapter 15.2 : Introduction to Algorithms, Cormen et al.
-    // Minor modifications have been made to accommodate zero-indexing
-    auto n = matrices.size();
-
-    // Dim vector - the length of which is n + 1. Note that for matrix multiplication, there
-    // needs to a common dimension between the multiplicands, hence for n matrices, there are
-    // n + 1 values. The values p_{i} and p_{i + 1} correspond to the dimensions of matrix i in
-    // the chain (zero-indexed)
-    std::vector<int64_t> p;
-    p.push_back(matrices[0].size(0));
-    for (size_t i = 0; i < n; i++) {
-      p.push_back(matrices[i].size(1));
-    }
-
-    // Cost matrix - an element m[i, j] of this matrix corresponds to the minimum cost of
-    // parenthesizing matrices A_{i} to A_{j}. By this definition m[i, i] = 0 for all i
-    // m[i, j] is filled using the substructure property of the algorithm, meaning:
-    // m[i, j] = min_{i <= k < j} m[i, k] + m[k, j] + p_{i-1}p_{k}p_{j}
-    std::vector<std::vector<int64_t>> m(n, std::vector<int64_t>(n, 0));
-
-    // Auxiliary table for constructing the order
-    // s[i, j] stores the index k at which the optimal split is obtained
-    std::vector<std::vector<int64_t>> s(n, std::vector<int64_t>(n));
-
-    // j and q are used repetitively in the algorithm below
-    int64_t j, q;
-
-    for (int64_t l = 1; l < n; l++) {
-      for (int64_t i = 0; i < n - l; i++) {
-        j = i + l;
-        m[i][j] = std::numeric_limits<int64_t>::max();
-        for (int64_t k = i; k < j; k++) {
-          q = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1];
-          if (q < m[i][j]) {
-            m[i][j] = q;
-            s[i][j] = k;
-          }
-        }
-      }
-    }
-
-    // We use the result from the algorithm to compute the matrix chain product via recursion
-    return _chain_matmul_general(matrices, s, 0, n - 1);
-  }
-}
-
 /*
 Calculates the Kronecker product between two Tensors.
 */
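Illustrative only, not part of this patch: the matrix_chain_order and matrix_chain_multiplication helpers added above implement the matrix chain ordering dynamic program from CLRS 15.2. The Python sketch below mirrors them for reference (names match the C++ helpers but nothing here is public API):

    import torch

    def matrix_chain_order(tensors):
        # Tensor i has shape p[i] x p[i + 1]
        n = len(tensors)
        p = [t.size(0) for t in tensors] + [tensors[-1].size(1)]
        # m[i][j]: minimal cost of multiplying tensors i..j
        # s[i][j]: split index k achieving that minimal cost
        m = [[0] * n for _ in range(n)]
        s = [[0] * n for _ in range(n)]
        for length in range(1, n):
            for i in range(n - length):
                j = i + length
                m[i][j] = float('inf')
                for k in range(i, j):
                    q = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1]
                    if q < m[i][j]:
                        m[i][j] = q
                        s[i][j] = k
        return s

    def matrix_chain_multiplication(tensors, s, i, j):
        # Recursively multiply tensors i..j following the split table s
        if i == j:
            return tensors[i]
        return torch.mm(matrix_chain_multiplication(tensors, s, i, s[i][j]),
                        matrix_chain_multiplication(tensors, s, s[i][j] + 1, j))

    # Consistency check against the new op (assumes this change is built in).
    ts = [torch.randn(10, 100), torch.randn(100, 5), torch.randn(5, 50), torch.randn(50, 2)]
    order = matrix_chain_order(ts)
    ref = matrix_chain_multiplication(ts, order, 0, len(ts) - 1)
    assert torch.allclose(ref, torch.linalg.multi_dot(ts), atol=1e-4)

The 3-matrix fast path in multi_dot_impl short-circuits this table: it only compares cost_1 = (a*c)*(b+d) against cost_2 = (b*d)*(a+c), which is the same decision the DP would reach for n = 3, using one comparison instead of four.
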
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 5d7a3b1..d945a8d 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8759,6 +8759,12 @@
   python_module: linalg
   variants: function
 
+- func: linalg_multi_dot(Tensor[] tensors) -> Tensor
+  python_module: linalg
+
+- func: linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
 ## Functions that are only for testing
 # It is undocumented and should not be used outside of tests.
 - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
diff --git a/docs/source/amp.rst b/docs/source/amp.rst
index 865bfe8..bf93d95 100644
--- a/docs/source/amp.rst
+++ b/docs/source/amp.rst
@@ -104,6 +104,7 @@
 ``baddbmm``,
 ``bmm``,
 ``chain_matmul``,
+``multi_dot``,
 ``conv1d``,
 ``conv2d``,
 ``conv3d``,
diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst
index e723b2f..3558441 100644
--- a/docs/source/linalg.rst
+++ b/docs/source/linalg.rst
@@ -24,6 +24,7 @@
 .. autofunction:: eigvalsh
 .. autofunction:: matrix_power
 .. autofunction:: matrix_rank
+.. autofunction:: multi_dot
 .. autofunction:: norm
 .. autofunction:: vector_norm
 .. autofunction:: pinv
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 0d397f5..7328097 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -2680,6 +2680,12 @@
             self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn)
 
     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
+    def test_autocast_linalg_fp16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op, args in self.autocast_lists.linalg_fp16:
+                self._run_autocast_outofplace(op, args, torch.float16, module=torch._C._linalg)
+
+    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
     def test_autocast_methods_fp16(self):
         with torch.backends.cudnn.flags(enabled=True, deterministic=True):
             for op, args in self.autocast_lists.methods_fp16:
diff --git a/test/test_fx.py b/test/test_fx.py
index d978ae1..771d857 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -2177,7 +2177,7 @@
     @onlyCPU
     @ops(op_db, allowed_dtypes=(torch.float,))
     def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
-        known_no_schema = {'stack', 'hstack', 'vstack', 'dstack', 'repeat', '__getitem__'}
+        known_no_schema = {'stack', 'hstack', 'vstack', 'dstack', 'repeat', '__getitem__', 'linalg.multi_dot'}
 
         try:
             sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 0b82703..439a454 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -3659,6 +3659,71 @@
             self.assertEqual(torch.matrix_rank(aaT, True), np.linalg.matrix_rank(aaT.cpu().numpy(), True))
             self.assertEqual(torch.matrix_rank(aaT, 0.01, True), np.linalg.matrix_rank(aaT.cpu().numpy(), 0.01, True))
 
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.double, torch.cdouble)
+    def test_multi_dot(self, device, dtype):
+        def check(*shapes, discontiguous=False):
+            tensors = [make_tensor(shape, device, dtype, discontiguous=discontiguous) for shape in shapes]
+            np_arrays = [tensor.cpu().numpy() for tensor in tensors]
+            res = torch.linalg.multi_dot(tensors).cpu()
+            ref = torch.from_numpy(np.array(np.linalg.multi_dot(np_arrays)))
+            self.assertEqual(res, ref)
+
+        # test for inputs with empty dimensions
+        check([0], [0])
+        check([2], [2, 0])
+        check([1, 0], [0])
+        check([0, 2], [2, 1])
+        check([2, 2], [2, 0])
+        check([2, 0], [0, 3])
+        check([0, 0], [0, 1])
+        check([4, 2], [2, 0], [0, 3], [3, 2])
+
+        # test variable output shapes
+        check([2], [2])
+        check([1, 2], [2])
+        check([2], [2, 1])
+        check([1, 2], [2, 1])
+        check([3, 2], [2, 4])
+
+        # test multiple input tensors
+        check([3], [3, 4], [4, 2], [2, 5], [5])
+        check([1, 2], [2, 2], [2, 3], [3, 1])
+
+        # test large tensors
+        check([10, 100], [100, 5], [5, 50])
+
+        # test discontiguous input
+        check([3, 2], [2, 2], [2, 3], [3, 4], discontiguous=True)
+
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.float)
+    def test_multi_dot_errors(self, device, dtype):
+        def check(tensors, out, msg):
+            with self.assertRaisesRegex(RuntimeError, msg):
+                torch.linalg.multi_dot(tensors, out=out)
+
+        a = make_tensor(2, device, dtype)
+
+        check([], None, "expected at least 2 tensors")
+        check([a], None, "expected at least 2 tensors")
+
+        check([torch.tensor(1, device=device, dtype=dtype), a], None, "the first tensor must be 1D or 2D")
+        check([a, torch.tensor(1, device=device, dtype=dtype)], None, "the last tensor must be 1D or 2D")
+
+        check([a, a, a], None, "tensor 1 must be 2D")
+        check([a, make_tensor((2, 2, 2), device, dtype), a], None, "tensor 1 must be 2D")
+
+        check([a, make_tensor(2, device, torch.double)], None, "all tensors must have the same dtype")
+        check([a, a], torch.empty(0, device=device, dtype=torch.double), "expected out tensor to have dtype")
+
+        if self.device_type == 'cuda':
+            check([a, make_tensor(2, 'cpu', dtype)], None, "all tensors must be on the same device")
+            check([a, a], torch.empty(0, dtype=dtype), "expected out tensor to be on device")
+
+        check([a, make_tensor(3, device, dtype)], None, "cannot be multiplied")
+        check([a, make_tensor((3, 2), device, dtype), a], None, "cannot be multiplied")
+
     @precisionOverride({torch.float32: 5e-6, torch.complex64: 5e-6})
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
@@ -6139,7 +6204,7 @@
         run_test([10, 20, 30, 5])
         run_test([15, 5, 10, 20, 25])
 
-        with self.assertRaisesRegex(RuntimeError, "chain_matmul: Expected one or more matrices"):
+        with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
             torch.chain_matmul()
 
     @skipCUDAIfNoMagma
diff --git a/tools/autograd/templates/python_linalg_functions.cpp b/tools/autograd/templates/python_linalg_functions.cpp
index d361e74..52eeca2 100644
--- a/tools/autograd/templates/python_linalg_functions.cpp
+++ b/tools/autograd/templates/python_linalg_functions.cpp
@@ -17,6 +17,7 @@
 using at::MemoryFormat;
 using at::Generator;
 using at::IntArrayRef;
+using at::TensorList;
 
 using namespace torch::autograd::utils;
 
diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h
index 4f0962c..73e7237 100644
--- a/torch/csrc/api/include/torch/linalg.h
+++ b/torch/csrc/api/include/torch/linalg.h
@@ -96,6 +96,14 @@
   return torch::linalg_matrix_rank_out(result, input, tol, hermitian);
 }
 
+inline Tensor multi_dot(TensorList tensors) {
+  return torch::linalg_multi_dot(tensors);
+}
+
+inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) {
+  return torch::linalg_multi_dot_out(result, tensors);
+}
+
 inline Tensor pinv(const Tensor& input, double rcond, bool hermitian) {
   return torch::linalg_pinv(input, rcond, hermitian);
 }
@@ -245,6 +253,15 @@
   return detail::matrix_rank_out(result, input, tol, hermitian);
 }
 
+/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.multi_dot
+inline Tensor multi_dot(TensorList tensors) {
+  return torch::linalg_multi_dot(tensors);
+}
+
+inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) {
+  return torch::linalg_multi_dot_out(result, tensors);
+}
+
 /// Computes pseudo-inverse
 ///
 /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.pinv
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 383d538..47855fa 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -696,6 +696,72 @@
     tensor(5.4345)
 """)
 
+multi_dot = _add_docstr(_linalg.linalg_multi_dot, r"""
+linalg.multi_dot(tensors, *, out=None) -> Tensor
+
+Efficiently multiplies two or more matrices given by :attr:`tensors` by ordering the
+multiplications so that the fewest arithmetic operations are performed.
+
+Every tensor in :attr:`tensors` must be 2D, except for the first and last which
+may be 1D. If the first tensor is a 1D vector of size `n`, it is treated as a row vector
+of size `(1, n)`. Similarly, if the last tensor is a 1D vector of size `n`, it is treated
+as a column vector of size `(n, 1)`.
+
+If the first tensor has size `(a, b)` and the last tensor has size `(c, d)`, the
+output will have size `(a, d)`. However, if either tensor is 1D then the implied
+dimension of size `1` as described above is squeezed from the output. For example,
+for tensors of size `(b,)` and `(c, d)` the output will have size `(d,)`.
+
+.. warning:: This function does not broadcast.
+
+.. note:: This function is implemented by chaining :func:`torch.mm` calls after
+          computing the optimal matrix multiplication order.
+
+.. note:: This function is similar to NumPy's `multi_dot` except that the first and last
+          tensors must be either 1D or 2D whereas NumPy allows them to be nD.
+
+.. note:: The cost of multiplying two matrices with shapes `(a, b)` and `(b, c)` is
+          `a * b * c`. Given matrices `A`, `B` and `C` with shapes `(10, 100)`,
+          `(100, 5)` and `(5, 50)` respectively, we can calculate the cost of different
+          multiplication orders as follows:
+
+          .. math::
+
+            \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 5000 + 2500 = 7500 \\
+            \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 50000 + 25000 = 75000
+
+          In this case, multiplying A and B first followed by C is 10 times faster.
+
+Args:
+    tensors (sequence of Tensors): two or more tensors to multiply. The first and last
+        tensors may be 1D or 2D. Every other tensor must be 2D.
+
+Keyword args:
+    out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None``
+
+Examples::
+
+    >>> from torch.linalg import multi_dot
+
+    >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])])
+    tensor(8)
+    >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])])
+    tensor([8])
+    >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])])
+    tensor([[8]])
+
+    >>> a = torch.arange(2 * 3).view(2, 3)
+    >>> b = torch.arange(3 * 2).view(3, 2)
+    >>> c = torch.arange(2 * 2).view(2, 2)
+    >>> multi_dot((a, b, c))
+    tensor([[ 26,  49],
+            [ 80, 148]])
+
+    >>> multi_dot((a.to(torch.float), torch.empty(3, 0), torch.empty(0, 2)))
+    tensor([[0., 0.],
+            [0., 0.]])
+""")
+
 norm = _add_docstr(_linalg.linalg_norm, r"""
 linalg.norm(input, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor
 
diff --git a/torch/overrides.py b/torch/overrides.py
index cda37e7..17004f5 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -535,6 +535,7 @@
         torch.linalg.matrix_power: lambda input, n, out=None: -1,
         torch.matrix_rank: lambda input, tol=None, symmetric=False: -1,
         torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
+        torch.linalg.multi_dot: lambda tensors, out=None: -1,
         torch.matrix_exp: lambda input: -1,
         torch.max: lambda input, out=None: -1,
         torch.maximum: lambda input, other, out=None: -1,
diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py
index f48de4f..1c84a3d 100644
--- a/torch/testing/_internal/autocast_test_lists.py
+++ b/torch/testing/_internal/autocast_test_lists.py
@@ -221,6 +221,9 @@
             ("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
             ("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
         ]
+        self.linalg_fp16 = [
+            ("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
+        ]
         self.methods_fp16 = [
             ("__matmul__", mat0_fp32 + mat1_fp32)
         ]
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 5803a17..fadb42f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -405,6 +405,29 @@
 
     return inputs
 
+def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad):
+    # Each test case consists of the sizes in the chain of multiplications
+    # e.g. [2, 3, 4, 5] generates matrices (2, 3) @ (3, 4) @ (4, 5)
+    test_cases = [
+        [1, 2, 1],
+        [2, 0, 2],
+        [0, 2, 2],
+        [2, 2, 2, 2],
+        [2, 3, 4, 5],
+        [5, 4, 0, 2],
+        [2, 4, 3, 5, 3, 2]
+    ]
+
+    result = []
+    for sizes in test_cases:
+        tensors = []
+        for size in zip(sizes[:-1], sizes[1:]):
+            t = make_tensor(size, device, dtype, requires_grad=requires_grad)
+            tensors.append(t)
+        result.append(SampleInput(tensors))
+
+    return result
+
 def sample_inputs_linalg_norm(op_info, device, dtype, requires_grad):
     test_sizes = [
         (S,),
@@ -2822,6 +2845,19 @@
            supports_inplace_autograd=False,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, skipCUDAIfRocm],
            sample_inputs_func=sample_inputs_linalg_matrix_power,),
+    OpInfo('linalg.multi_dot',
+           # Need this lambda because gradcheck does not work with TensorList inputs
+           aten_name='linalg_multi_dot',
+           dtypes=floating_and_complex_types_and(torch.half),
+           dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16] if CUDA11OrLater else []),
+           supports_inplace_autograd=False,
+           # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407)
+           check_batched_grad=False,
+           check_batched_gradgrad=False,
+           sample_inputs_func=sample_inputs_linalg_multi_dot,
+           # test_variant_consistency_jit does not work with TensorList inputs
+           skips=(SkipInfo('TestCommon', 'test_variant_consistency_jit'),)),
     OpInfo('linalg.norm',
            op=torch.linalg.norm,
            dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),