Parallelize tensor contraction over the inner dimension in cases where where one or both of the outer dimensions (m and n) are small but k is large. This speeds up individual matmul microbenchmarks by up to 85%.

Naming below is BM_Matmul_M_K_N_THREADS, measured on a 2-socket Intel Broadwell-based server.

Benchmark                          Base (ns)  New (ns) Improvement
------------------------------------------------------------------
BM_Matmul_1_80_13522_1                  387457    396013     -2.2%
BM_Matmul_1_80_13522_2                  406487    230789    +43.2%
BM_Matmul_1_80_13522_4                  395821    123211    +68.9%
BM_Matmul_1_80_13522_6                  391625     97002    +75.2%
BM_Matmul_1_80_13522_8                  408986    113828    +72.2%
BM_Matmul_1_80_13522_16                 399988     67600    +83.1%
BM_Matmul_1_80_13522_22                 411546     60044    +85.4%
BM_Matmul_1_80_13522_32                 393528     57312    +85.4%
BM_Matmul_1_80_13522_44                 390047     63525    +83.7%
BM_Matmul_1_80_13522_88                 387876     63592    +83.6%
BM_Matmul_1_1500_500_1                  245359    248119     -1.1%
BM_Matmul_1_1500_500_2                  401833    143271    +64.3%
BM_Matmul_1_1500_500_4                  210519    100231    +52.4%
BM_Matmul_1_1500_500_6                  251582     86575    +65.6%
BM_Matmul_1_1500_500_8                  211499     80444    +62.0%
BM_Matmul_3_250_512_1                    70297     68551     +2.5%
BM_Matmul_3_250_512_2                    70141     52450    +25.2%
BM_Matmul_3_250_512_4                    67872     58204    +14.2%
BM_Matmul_3_250_512_6                    71378     63340    +11.3%
BM_Matmul_3_250_512_8                    69595     41652    +40.2%
BM_Matmul_3_250_512_16                   72055     42549    +40.9%
BM_Matmul_3_250_512_22                   70158     54023    +23.0%
BM_Matmul_3_250_512_32                   71541     56042    +21.7%
BM_Matmul_3_250_512_44                   71843     57019    +20.6%
BM_Matmul_3_250_512_88                   69951     54045    +22.7%
BM_Matmul_3_1500_512_1                  369328    374284     -1.4%
BM_Matmul_3_1500_512_2                  428656    223603    +47.8%
BM_Matmul_3_1500_512_4                  205599    139508    +32.1%
BM_Matmul_3_1500_512_6                  214278    139071    +35.1%
BM_Matmul_3_1500_512_8                  184149    142338    +22.7%
BM_Matmul_3_1500_512_16                 156462    156983     -0.3%
BM_Matmul_3_1500_512_22                 163905    158259     +3.4%
BM_Matmul_3_1500_512_32                 155314    157662     -1.5%
BM_Matmul_3_1500_512_44                 235434    158657    +32.6%
BM_Matmul_3_1500_512_88                 156779    160275     -2.2%
BM_Matmul_1500_4_512_1                  363358    349528     +3.8%
BM_Matmul_1500_4_512_2                  303134    263319    +13.1%
BM_Matmul_1500_4_512_4                  176208    130086    +26.2%
BM_Matmul_1500_4_512_6                  148026    115449    +22.0%
BM_Matmul_1500_4_512_8                  131656     98421    +25.2%
BM_Matmul_1500_4_512_16                 134011     82861    +38.2%
BM_Matmul_1500_4_512_22                 134950     85685    +36.5%
BM_Matmul_1500_4_512_32                 133165     90081    +32.4%
BM_Matmul_1500_4_512_44                 133203     90644    +32.0%
BM_Matmul_1500_4_512_88                 134106    100566    +25.0%
BM_Matmul_4_1500_512_1                  439243    435058     +1.0%
BM_Matmul_4_1500_512_2                  451830    257032    +43.1%
BM_Matmul_4_1500_512_4                  276434    164513    +40.5%
BM_Matmul_4_1500_512_6                  182542    144827    +20.7%
BM_Matmul_4_1500_512_8                  179411    166256     +7.3%
BM_Matmul_4_1500_512_16                 158101    155560     +1.6%
BM_Matmul_4_1500_512_22                 152435    155448     -1.9%
BM_Matmul_4_1500_512_32                 155150    149538     +3.6%
BM_Matmul_4_1500_512_44                 193842    149777    +22.7%
BM_Matmul_4_1500_512_88                 149544    154468     -3.3%
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 3b22e43..ea17a89 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -590,6 +590,25 @@
 
     // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
     this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
+    this->template evalGemmPartial<lhs_inner_dim_contiguous,
+                                   rhs_inner_dim_contiguous,
+                                   rhs_inner_dim_reordered, Alignment>(buffer,
+                                                                       0, k, 1);
+  }
+
+  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
+  EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int num_threads) const {
+    // columns in left side, rows in right side
+    const Index k = this->m_k_size;
+
+    eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= k);
+    const Index k_slice = k_end - k_start;
+
+    // rows in left side
+    const Index m = this->m_i_size;
+
+    // columns in right side
+    const Index n = this->m_j_size;
 
     // define mr, nr, and all of my data mapper types
     typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
@@ -620,7 +639,7 @@
     typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
 
     // Declare GEBP packing and kernel structs
-    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor> pack_lhs;
+    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress,  typename Traits::LhsPacket4Packing, ColMajor> pack_lhs;
     internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;
 
     internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;
@@ -635,7 +654,7 @@
     OutputMapper output(buffer, m);
 
     // Sizes of the blocks to load in cache. See the Goto paper for details.
-    internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 1);
+    internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k_slice, m, n, num_threads);
     const Index kc = blocking.kc();
     const Index mc = numext::mini(m, blocking.mc());
     const Index nc = numext::mini(n, blocking.nc());
@@ -648,7 +667,7 @@
     for(Index i2=0; i2<m; i2+=mc)
     {
       const Index actual_mc = numext::mini(i2+mc,m)-i2;
-      for (Index k2 = 0; k2 < k; k2 += kc) {
+      for (Index k2 = k_start; k2 < k_end; k2 += kc) {
         // make sure we don't overshoot right edge of left matrix, then pack vertical panel
         const Index actual_kc = numext::mini(k2 + kc, k) - k2;
         pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
index 0980854..57fe7cf 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -147,6 +147,14 @@
         contractionCost(m, n, bm, bn, bk, shard_by_col, false);
     int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
         static_cast<double>(n) * m, cost, this->m_device.numThreads());
+    int num_threads_by_k = numThreadsInnerDim(m, n, k);
+    if (false && shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
+      // We are in the scenario where it is more effective to shard by the
+      // inner dimension.
+      this->template evalShardedByInnerDim<Alignment>(num_threads_by_k,
+                                                      buffer);
+      return;
+    }
 
     // TODO(dvyukov): this is a stop-gap to prevent regressions while the cost
     // model is not tuned. Remove this when the cost model is tuned.
@@ -242,9 +250,9 @@
         contract_t, internal::packet_traits<RhsScalar>::size,
         rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
         RhsMapper;
-    typedef internal::gemm_pack_lhs<LhsScalar, Index,
-                                    typename LhsMapper::SubMapper, Traits::mr,
-                                    Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>
+    typedef internal::gemm_pack_lhs<
+        LhsScalar, Index, typename LhsMapper::SubMapper, Traits::mr,
+        Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>
         LhsPacker;
     typedef internal::gemm_pack_rhs<
         RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
@@ -709,20 +717,9 @@
                                           PacketType<RhsScalar, Device>::size);
     const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
     const double kd = static_cast<double>(bk);
-    // Peak VFMA bandwidth is 0.5. However if we have not enough data for
-    // vectorization bandwidth drops. The 4.0 and 2.0 bandwidth is determined
-    // experimentally.
-    double computeBandwidth = bk == 1 ? 4.0 :
-          (shard_by_col ? bn : bm) < Traits::nr ||
-          (shard_by_col ? bm : bn) < Traits::mr ? 2.0 : 0.5;
-#ifndef EIGEN_VECTORIZE_FMA
-    // Bandwidth of all of VFMA/MULPS/ADDPS is 0.5 on latest Intel processors.
-    // However for MULPS/ADDPS we have dependent sequence of 2 such instructions,
-    // so overall bandwidth is 1.0.
-    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
-#endif
+    double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
     // Computations.
-    TensorOpCost cost = TensorOpCost(0, 0, kd * computeBandwidth, true, packed_size);
+    TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
     // Output stores.
     cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
     if (prepacked) {
@@ -743,6 +740,162 @@
     return cost + lhsCost + rhsCost;
   }
 
+  template <int Alignment>
+  EIGEN_STRONG_INLINE void addToBuffer(size_t n, const Scalar* src_buf,
+                                       Scalar* tgt_buf) const {
+    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
+    size_t i = 0;
+    const size_t num_packets = n / output_packet_size;
+    for (; i < output_packet_size * num_packets; i += output_packet_size) {
+      const PacketReturnType src_val =
+          internal::pload<PacketReturnType>(src_buf + i);
+      const PacketReturnType tgt_val =
+          internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
+      const PacketReturnType sum = internal::padd(src_val, tgt_val);
+      internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i, sum);
+    }
+    for (; i < n; ++i) {
+      tgt_buf[i] += src_buf[i];
+    }
+  }
+
+  // Decide whether we want to shard m x k x n contraction over the inner
+  // (contraction) dimension (k).
+  static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
+                              int num_threads_by_k) {
+    size_t bufsize = m * n * sizeof(Scalar);
+    bool shard_by_k = false;
+    if (n == 1 ||                // If mat*vec or...
+        num_threads_by_k < 2 ||  // running single threaded or...
+        num_threads_by_k <
+            num_threads ||  // sharding by k gives less parallelism or...
+        bufsize > l3CacheSize() / num_threads_by_k ||  // need more buffer space
+        // than L3 cache or...
+        k / num_threads_by_k < 2 * Traits::nr) {  // k per thread is tiny.
+      shard_by_k = false;
+    } else if (numext::maxi(m, n) / num_threads <
+                   Traits::nr ||  // both other dimensions are tiny or...
+               // k per thread is not small and...
+               (k / num_threads_by_k > 8 * Traits::nr &&
+                // one of the outer dimensions is tiny or sharding by k offers
+                // more parallelism.
+                (numext::mini(m, n) < 2 * Traits::nr ||
+                 num_threads_by_k > num_threads))) {
+      shard_by_k = true;
+    }
+    return shard_by_k;
+  }
+
+  template <int Alignment>
+  void evalShardedByInnerDim(int num_threads, Scalar* result) const {
+    const Index m = this->m_i_size;
+    const Index n = this->m_j_size;
+    const Index k = this->m_k_size;
+    // The underlying GEMM kernel assumes that k is a multiple of 8 and
+    // subtle breakage occurs if this is violated.
+    Index block_size = 8 * divup<Index>(k, 8 * num_threads);
+    int num_blocks = divup<Index>(k, block_size);
+    // we use 'result' for the first block's partial result.
+    MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
+    Barrier barrier(num_blocks);
+    auto process_block = [=, &barrier](Scalar* buf, Index first, Index last) {
+      ::memset(buf, 0, m * n * sizeof(Scalar));
+      TENSOR_CONTRACTION_DISPATCH(
+          this->template evalGemmPartial, Alignment,
+          (buf, first, last, this->m_device.numThreads()));
+      barrier.Notify();
+    };
+    Index start = 0;
+    for (int blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
+      // The underlying GEMM kernel assumes that k is a multiple of 8 and
+      // subtle breakage occurs if this is violated.
+      block_size = 8 * divup<Index>(k - start, 8 * blocks_left);
+      Scalar* buf;
+      if (start == 0) {
+        buf = result;
+      } else {
+        buf = static_cast<Scalar*>(
+            this->m_device.allocate(m * n * sizeof(Scalar)));
+        block_buffers.push_back(buf);
+      }
+      Index end = start + block_size;
+      if (end > k) {
+        end = k;
+      }
+      this->m_device.enqueueNoNotification(
+          [=, &process_block]() { process_block(buf, start, end); });
+      start = end;
+    }
+    barrier.Wait();
+
+    // Add other partial results into first partial result.
+    for (const auto& buf : block_buffers) {
+      addToBuffer<Alignment>(m * n, buf, result);
+      this->m_device.deallocate(buf);
+    }
+  }
+
+  TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
+    // Compute cost.
+    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
+    TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n);
+    // Output stores.
+    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
+    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
+    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
+    // Since the inner gemm kernel is always sharded by column, the lhs
+    // load cost is negligible.
+    lhsCost.dropMemoryCost();
+    return cost + lhsCost + rhsCost;
+  }
+
+  int numThreadsInnerDim(Index m, Index n, Index k) const {
+    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
+    TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
+    double total_parallel_cost =
+        TensorCostModel<ThreadPoolDevice>::totalCost(k, cost);
+    // Cost of reduction step accumulating the m*n per-thread buffers into the
+    // result.
+    double reduction_cost = TensorCostModel<ThreadPoolDevice>::totalCost(
+        m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
+    Index num_threads = 1;
+    double min_cost = total_parallel_cost;
+    double kPerThreadOverHead = 4000;
+    double kFixedOverHead = 100000;
+    for (int nt = 2; nt <= this->m_device.numThreads(); nt++) {
+      double sequential_cost =
+          kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
+      double parallel_cost = total_parallel_cost / nt + sequential_cost;
+      if (parallel_cost < min_cost) {
+        num_threads = nt;
+        min_cost = parallel_cost;
+      }
+    }
+    return num_threads;
+  }
+
+
+  double computeBandwidth(bool shard_by_col, Index bm, Index bn,
+                          Index bk) const {
+    // Peak VFMA bandwidth is 0.5. However if we have not enough data for
+    // vectorization bandwidth drops. The 4.0 and 2.0 bandwidth is determined
+    // experimentally.
+    double computeBandwidth =
+        bk == 1 ? 4.0
+                : (shard_by_col ? bn : bm) < Traits::nr ||
+                          (shard_by_col ? bm : bn) < Traits::mr
+                      ? 2.0
+                      : 0.5;
+#ifndef EIGEN_VECTORIZE_FMA
+    // Bandwidth of all of VFMA/MULPS/ADDPS is 0.5 on latest Intel processors.
+    // However for MULPS/ADDPS we have dependent sequence of 2 such
+    // instructions,
+    // so overall bandwidth is 1.0.
+    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
+#endif
+    return computeBandwidth;
+  }
+
 #if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
   // TODO(ezhulenev): Add support for output kernels and LIBXSMM.
   static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h b/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h
index bb63bae..7f79ac3 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h
@@ -188,7 +188,6 @@
     return totalCost(output_size, cost_per_coeff) / kTaskSize;
   }
 
- private:
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double totalCost(
       double output_size, const TensorOpCost& cost_per_coeff) {
     // Cost of memory fetches from L2 cache. 64 is typical cache line size.