Add optimal triton kernel parameters to bsr_dense_mm and scatter_mm for bfloat16 and float32 dtypes (#113553)

As in the title.

This PR is a follow-up to PR https://github.com/pytorch/pytorch/pull/112737 that addresses the bfloat16 and float32 dtype cases. The performance increase, measured on `NVIDIA A100-SXM4-80GB`, is as follows:

- bsr_scatter_mm and bfloat16
  - for blocksize 16x16, the average/maximum speed up is about 29/75 %.
  - for blocksize 32x32, the average/maximum speed up is about 23/58 %.
  - for blocksize 64x64, the average/maximum speed up is about 27/66 %.
  - for blocksize 128x128, the average/maximum speed up is about 33/72 %.
- bsr_dense_mm and bfloat16
  - for blocksize 16x16, the average/maximum speed up is about 47/61 %.
  - for blocksize 32x32, the average/maximum speed up is about 29/43 %.
  - for blocksize 64x64, the average/maximum speed up is about 21/41 %.
  - for blocksize 128x128, the average/maximum speed up is about 12/29 %.
- bsr_dense_mm and float32
  - for blocksize 16x16, the average/maximum speed up is about 35/49 %.
  - for blocksize 32x32, the average/maximum speed up is about 2/5 %.
  - for blocksize 64x64, the average/maximum speed up is about 2/21 %.
  - for blocksize 128x128, the average/maximum speed up is about 79/84 %.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113553
Approved by: https://github.com/cpuhrsch
diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py
index 93b1a27..6ff12bb 100644
--- a/torch/sparse/_triton_ops.py
+++ b/torch/sparse/_triton_ops.py
@@ -1202,7 +1202,8 @@
         out: Optional[torch.Tensor] = None,
         skip_checks: bool = False,
         max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
-        meta: Optional[dict] = None
+        meta: Optional[dict] = None,
+        enable_bsr_scatter_mm: bool = True
     ):
         f_name = "bsr_dense_mm"
         m, kl = bsr.shape[-2:]
@@ -1249,13 +1250,19 @@
 
         blocksize = bsr.values().shape[-2:]
 
-        if max(blocksize) == 16 and bsr.dense_dim() == 0 and bsr.ndim == 2:
+        if enable_bsr_scatter_mm and max(blocksize) == 16 and bsr.dense_dim() == 0 and bsr.ndim == 2:
+            dtype = bsr.dtype
             # bsr_scatter_mm is more performant than bsr_dense_mm for
             # 16x16 blocksizes and large enough input shapes:
             if (
-                    (m >= 4096 and n >= 8192)
-                    or (m == 2048 and n >= 32768)
-                    or (n >= 131072)
+                    (dtype in {torch.float16, torch.bfloat16}
+                     and ((m >= 4096 and n >= 8192)
+                          or (m == 2048 and n >= 32768)
+                          or (n >= 131072))) or
+                    (dtype == torch.float32
+                     and (m >= 1024
+                          or (m == 512 and n >= 512)
+                          or (m == 256 and n >= 2048)))
             ):
                 return bsr_scatter_mm(bsr, dense, out=out)
 
diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py
index 80d642c..a014d1c 100644
--- a/torch/sparse/_triton_ops_meta.py
+++ b/torch/sparse/_triton_ops_meta.py
@@ -89,10 +89,18 @@
             " BEGIN/END GENERATED DATA comment blocks appear to be corrupted"
         )
         return
+
+    def sort_key(key):
+        op, device_name, version = key
+        version = tuple(
+            (str(item) if isinstance(item, torch.dtype) else item) for item in version
+        )
+        return (op, device_name, version)
+
     part1 = current_content[: begin_data_index + len(begin_data_str)]
     part2 = current_content[end_data_index:]
     data_part = []
-    for op_key in sorted(_operation_device_version_data):
+    for op_key in sorted(_operation_device_version_data, key=sort_key):
         data_part.append("    " + repr(op_key).replace("'", '"') + ": {")
         op_data = _operation_device_version_data[op_key]
         for key in sorted(op_data):
@@ -105,7 +113,9 @@
         f.close()
 
 
-def minimize(target_func, initial_parameters, step_func):
+def minimize(
+    target_func, initial_parameters, reference_parameters, step_func, max_step=2
+):
     """Find a dict of parameters that minimizes the target function using
     the initial dict of parameters and a step function that progresses
     a specified parameter in a dict of parameters.
@@ -116,6 +126,8 @@
       ``target_func(parameters: dict) -> float``
     initial_parameters (dict): a set of parameters used as an initial
       value to the minimization process.
+    reference_parameters (dict): a set of parameters used as an
+      reference value with respect to which the speed up is computed.
     step_func (callable): a functional with the signature
       ``step_func(parameter_name:str, parameter_value:int, direction:int, parameters:dict) -> int``
       that increments or decrements (when ``direction`` is positive or
@@ -128,6 +140,7 @@
     parameters (dict): a set of parameters that minimizes the target
       function.
     speedup_incr (float): a speedup change given in percentage
+
     """
 
     def to_key(parameters):
@@ -137,13 +150,34 @@
         return dict(zip(sorted(parameters), key))
 
     all_values = dict()
+
+    try:
+        reference_target = target_func(reference_parameters)
+    except Exception as msg:
+        print(f"{reference_parameters=} lead to failure: {msg}.")
+        reference_target = None
+    if reference_target is not None:
+        all_values[to_key(reference_parameters)] = reference_target
+
     parameters = initial_parameters
     try:
         initial_target = target_func(parameters)
     except Exception as msg:
-        print(f"{parameters=} lead to failure: {msg}. Skipping.")
-        return parameters, -1
-    all_values[to_key(parameters)] = initial_target
+        if reference_target is None:
+            print(f"{initial_parameters=} lead to failure: {msg}. Optimization failed!")
+            return {}, -1, None
+        print(
+            f"{initial_parameters=} lead to failure: {msg}. Using reference parameters instead of initial parameters."
+        )
+        parameters = reference_parameters
+        initial_target = reference_target
+
+    if reference_target is None:
+        print("Using initial parameters instead of reference parameters.")
+        reference_target = initial_target
+
+    initial_key = to_key(parameters)
+    all_values[initial_key] = initial_target
 
     while True:
         current_key = to_key(parameters)
@@ -152,7 +186,9 @@
         new_minimizer = False
         for name in parameters:
             value = parameters[name]
-            for direction in [1, -1]:
+            for direction in range(-max_step, max_step + 1):
+                if direction == 0:
+                    continue
                 next_value = step_func(name, value, direction, parameters)
                 if next_value == value:
                     continue
@@ -175,8 +211,30 @@
         if new_minimizer:
             parameters = from_key(minimizer_key, parameters)
         else:
-            speedup_incr = (1 - minimizer_target / initial_target) * 100
-            return parameters, speedup_incr
+            # ensure stable minimizer:
+            minimizer_keys = {
+                k
+                for k, v in all_values.items()
+                if isinstance(v, float) and abs(1 - v / minimizer_target) < 0.001
+            }
+            minimizer_key = (
+                initial_key if initial_key in minimizer_keys else min(minimizer_keys)
+            )
+            minimizer_target = all_values[minimizer_key]
+            parameters = from_key(minimizer_key, parameters)
+            speedup_incr = (1 - minimizer_target / reference_target) * 100
+            if speedup_incr < 0:
+                print(
+                    f"{speedup_incr=} is negative. Rerunning minimize with reference parameters as initial parameters."
+                )
+                return minimize(
+                    target_func,
+                    reference_parameters,
+                    reference_parameters,
+                    step_func,
+                    max_step=max_step,
+                )
+            return parameters, speedup_incr, minimizer_target
 
 
 def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
@@ -212,20 +270,37 @@
     from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
 
     key = (m, k, n, bm, bk)
-    version = (0, dtype, sparsity)
 
-    initial_meta = get_meta("scatter_mm", key, version=version, exact=True)
+    version = (0, dtype, sparsity)
+    device_name = torch.cuda.get_device_name()
+
+    reference_meta = dict(
+        GROUP_SIZE=1,
+        TILE_M=16,
+        TILE_N=16,
+        SPLIT_N=n // 16,
+        num_stages=1,
+        num_warps=1,
+    )
+
+    initial_meta = get_meta(
+        "scatter_mm", key, device_name=device_name, version=version, exact=True
+    )
 
     if initial_meta is None:
-        initial_meta = get_meta("scatter_mm", key, version=(0, dtype, 0.5), exact=True)
+        initial_meta = get_meta(
+            "bsr_dense_mm",
+            key,
+            device_name=device_name,
+            version=(0, dtype, 0.5),
+            exact=True,
+        )
         if initial_meta is None:
-            initial_meta = dict(
-                GROUP_SIZE=1, TILE_M=16, TILE_N=16, SPLIT_N=1, num_warps=1, num_stages=1
-            )
+            initial_meta = reference_meta
     elif not force:
         return
 
-    print(f"{m, k, n, bm, bk=}")
+    print(f"{m, k, n, bm, bk, initial_meta, reference_meta=}")
     torch.manual_seed(0)
     bsr = create_blocked_tensor(
         0, m, k, (bm, bk), sparsity, dtype, device
@@ -250,6 +325,7 @@
         # return next value in positive or negative direction, or
         # input value if the step will result an invalid
         # value. The input value is assumed to be valid.
+
         is_log = name in {"SPLIT_N", "TILE_M", "TILE_N", "num_warps"}
         min_value = dict(
             SPLIT_N=1, TILE_M=16, TILE_N=16, num_warps=1, num_stages=1, GROUP_SIZE=1
@@ -261,22 +337,39 @@
             SPLIT_N=2, TILE_M=2, TILE_N=2, num_warps=2, num_stages=1, GROUP_SIZE=1
         )[name]
         if is_log:
-            next_value = value * value_step if direction > 0 else value // value_step
+            next_value = (
+                value * value_step**direction
+                if direction > 0
+                else value // (value_step ** abs(direction))
+            )
         else:
-            next_value = value + value_step if direction > 0 else value - value_step
+            next_value = value + value_step * direction
         if min_value is not None:
             next_value = max(next_value, min_value)
         if max_value is not None:
             next_value = min(next_value, max_value)
         if name == "SPLIT_N" and n % next_value != 0:
             return value
+        # Hard-skip parameter combinations that break CUDA state for pytorch:
+        if (dtype, name, next_value, m, n, k, bm, bk) in {
+            (torch.float32, "num_warps", 32, 256, 256, 256, 16, 16),
+            (torch.float32, "num_warps", 16, 256, 256, 256, 32, 32),
+            (torch.float32, "num_warps", 16, 256, 256, 256, 64, 64),
+            (torch.float32, "num_warps", 16, 256, 256, 256, 128, 128),
+            (torch.float32, "num_warps", 16, 512, 512, 256, 128, 128),
+        } and re.match(r"NVIDIA A100[^\d]", device_name) is not None:
+            return value
         return next_value
 
-    meta, speedup = minimize(bench, initial_meta, step_meta_parameter)
-    if speedup < 3 and 0:
-        # don't bother updating parameters when the speed up change is less than 3 %
+    meta, speedup, timing = minimize(
+        bench, initial_meta, reference_meta, step_meta_parameter
+    )
+    print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms")
+    if initial_meta is not reference_meta and initial_meta == meta:
         return
-    print(f"{meta=} {speedup=:.1f} %")
+    print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms")
+    if speedup < 0:
+        return
     device_name = torch.cuda.get_device_name()
 
     update(
@@ -294,6 +387,8 @@
     key = (m, k, n, bm, bk)
     version = (0, dtype, sparsity)
 
+    reference_meta = dict(GROUP_SIZE_ROW=1, num_stages=1, num_warps=4)
+
     initial_meta = get_meta("bsr_dense_mm", key, version=version, exact=True)
 
     if initial_meta is None:
@@ -301,11 +396,11 @@
             "bsr_dense_mm", key, version=(0, dtype, 0.5), exact=True
         )
         if initial_meta is None:
-            initial_meta = dict(GROUP_SIZE_ROW=1, num_stages=1, num_warps=1)
+            initial_meta = reference_meta
     elif not force:
         return
 
-    print(f"{m, k, n, bm, bk=}")
+    print(f"{m, k, n, bm, bk, initial_meta=}")
     torch.manual_seed(0)
     bsr = create_blocked_tensor(
         0, m, k, (bm, bk), sparsity, dtype, device
@@ -331,20 +426,27 @@
         max_value = dict().get(name)
         value_step = dict(num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name]
         if is_log:
-            next_value = value * value_step if direction > 0 else value // value_step
+            next_value = (
+                value * value_step**direction
+                if direction > 0
+                else value // (value_step ** abs(direction))
+            )
         else:
-            next_value = value + value_step if direction > 0 else value - value_step
+            next_value = value + value_step * direction
         if min_value is not None:
             next_value = max(next_value, min_value)
         if max_value is not None:
             next_value = min(next_value, max_value)
         return next_value
 
-    meta, speedup = minimize(bench, initial_meta, step_meta_parameter)
-    if speedup < 3 and 0:
-        # don't bother updating parameters when the speed up change is less than 3 %
+    meta, speedup, timing = minimize(
+        bench, initial_meta, reference_meta, step_meta_parameter, max_step=2
+    )
+    if initial_meta is not reference_meta and initial_meta == meta:
         return
-    print(f"{meta=} {speedup=:.1f} %")
+    print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms")
+    if speedup < 0:
+        return
     device_name = torch.cuda.get_device_name()
 
     update(
@@ -352,10 +454,9 @@
     )
 
 
-def main(op="scatter_mm", force=False):
+def main(op="scatter_mm", force=False, dtype=torch.float16):
     import itertools
 
-    dtype = torch.float16
     sizes_lst = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
     shapes_lst = [(sz, sz) for sz in sizes_lst[:-3]]
     blocksize_lst = [(16, 16), (32, 32), (64, 64), (128, 128)]
@@ -367,10 +468,12 @@
                 shapes_lst, sizes_lst, blocksize_lst
             ):
                 if op == "scatter_mm":
-                    optimize_scatter_mm(M, K, N, BM, BK, force=force, sparsity=sparsity)
+                    optimize_scatter_mm(
+                        M, K, N, BM, BK, force=force, sparsity=sparsity, dtype=dtype
+                    )
                 elif op == "bsr_dense_mm":
                     optimize_bsr_dense_mm(
-                        M, K, N, BM, BK, force=force, sparsity=sparsity
+                        M, K, N, BM, BK, force=force, sparsity=sparsity, dtype=dtype
                     )
                 else:
                     raise NotImplementedError(op)
@@ -378,7 +481,7 @@
             break
         except Exception as msg:
             dump()
-            print(msg)
+            raise
     dump()
 
     if 0:
@@ -486,6 +589,288 @@
     #   bsr_dense_mm : M, K, N, Ms, Ks -> GROUP_SIZE_ROW, num_stages, num_warps
     #
     # BEGIN GENERATED DATA
+    ("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): {
+        (256, 256, 256, 16, 16): (3, 3, 1),
+        (256, 256, 256, 32, 32): (1, 5, 1),
+        (256, 256, 256, 64, 64): (2, 3, 2),
+        (256, 256, 256, 128, 128): (2, 2, 8),
+        (256, 256, 512, 16, 16): (2, 5, 1),
+        (256, 256, 512, 32, 32): (1, 4, 1),
+        (256, 256, 512, 64, 64): (1, 3, 2),
+        (256, 256, 512, 128, 128): (2, 2, 8),
+        (256, 256, 1024, 16, 16): (2, 3, 1),
+        (256, 256, 1024, 32, 32): (1, 4, 1),
+        (256, 256, 1024, 64, 64): (3, 3, 2),
+        (256, 256, 1024, 128, 128): (1, 2, 8),
+        (256, 256, 2048, 16, 16): (2, 3, 1),
+        (256, 256, 2048, 32, 32): (1, 3, 1),
+        (256, 256, 2048, 64, 64): (2, 3, 4),
+        (256, 256, 2048, 128, 128): (2, 2, 8),
+        (256, 256, 4096, 16, 16): (3, 3, 1),
+        (256, 256, 4096, 32, 32): (4, 3, 1),
+        (256, 256, 4096, 64, 64): (2, 3, 2),
+        (256, 256, 4096, 128, 128): (2, 2, 8),
+        (256, 256, 8192, 16, 16): (1, 3, 1),
+        (256, 256, 8192, 32, 32): (1, 1, 1),
+        (256, 256, 8192, 64, 64): (1, 1, 4),
+        (256, 256, 8192, 128, 128): (1, 1, 4),
+        (256, 256, 16384, 16, 16): (1, 3, 1),
+        (256, 256, 16384, 32, 32): (3, 3, 1),
+        (256, 256, 16384, 64, 64): (1, 1, 4),
+        (256, 256, 16384, 128, 128): (2, 1, 4),
+        (256, 256, 32768, 16, 16): (1, 3, 1),
+        (256, 256, 32768, 32, 32): (1, 3, 1),
+        (256, 256, 32768, 64, 64): (1, 2, 4),
+        (256, 256, 32768, 128, 128): (2, 1, 4),
+        (256, 256, 65536, 16, 16): (1, 3, 1),
+        (256, 256, 65536, 32, 32): (1, 3, 1),
+        (256, 256, 65536, 64, 64): (1, 2, 4),
+        (256, 256, 65536, 128, 128): (1, 1, 4),
+        (256, 256, 131072, 16, 16): (1, 1, 2),
+        (256, 256, 131072, 32, 32): (1, 1, 1),
+        (256, 256, 131072, 64, 64): (1, 2, 4),
+        (256, 256, 131072, 128, 128): (1, 1, 4),
+        (512, 512, 256, 16, 16): (2, 5, 1),
+        (512, 512, 256, 32, 32): (1, 5, 1),
+        (512, 512, 256, 64, 64): (3, 3, 4),
+        (512, 512, 256, 128, 128): (1, 2, 8),
+        (512, 512, 512, 16, 16): (2, 4, 1),
+        (512, 512, 512, 32, 32): (1, 5, 2),
+        (512, 512, 512, 64, 64): (1, 5, 4),
+        (512, 512, 512, 128, 128): (2, 2, 8),
+        (512, 512, 1024, 16, 16): (1, 3, 1),
+        (512, 512, 1024, 32, 32): (1, 4, 1),
+        (512, 512, 1024, 64, 64): (2, 4, 4),
+        (512, 512, 1024, 128, 128): (1, 2, 8),
+        (512, 512, 2048, 16, 16): (3, 3, 1),
+        (512, 512, 2048, 32, 32): (3, 3, 1),
+        (512, 512, 2048, 64, 64): (3, 3, 4),
+        (512, 512, 2048, 128, 128): (2, 2, 8),
+        (512, 512, 4096, 16, 16): (2, 3, 1),
+        (512, 512, 4096, 32, 32): (1, 3, 1),
+        (512, 512, 4096, 64, 64): (1, 3, 4),
+        (512, 512, 4096, 128, 128): (2, 1, 4),
+        (512, 512, 8192, 16, 16): (2, 3, 1),
+        (512, 512, 8192, 32, 32): (1, 3, 1),
+        (512, 512, 8192, 64, 64): (1, 3, 2),
+        (512, 512, 8192, 128, 128): (6, 2, 8),
+        (512, 512, 16384, 16, 16): (1, 3, 1),
+        (512, 512, 16384, 32, 32): (1, 3, 1),
+        (512, 512, 16384, 64, 64): (3, 3, 2),
+        (512, 512, 16384, 128, 128): (2, 1, 4),
+        (512, 512, 32768, 16, 16): (1, 2, 1),
+        (512, 512, 32768, 32, 32): (1, 3, 1),
+        (512, 512, 32768, 64, 64): (1, 3, 2),
+        (512, 512, 32768, 128, 128): (2, 1, 4),
+        (512, 512, 65536, 16, 16): (4, 3, 1),
+        (512, 512, 65536, 32, 32): (1, 3, 1),
+        (512, 512, 65536, 64, 64): (1, 3, 2),
+        (512, 512, 65536, 128, 128): (1, 1, 4),
+        (512, 512, 131072, 16, 16): (1, 2, 2),
+        (512, 512, 131072, 32, 32): (1, 1, 1),
+        (512, 512, 131072, 64, 64): (2, 3, 2),
+        (512, 512, 131072, 128, 128): (1, 1, 4),
+        (1024, 1024, 256, 16, 16): (1, 4, 1),
+        (1024, 1024, 256, 32, 32): (1, 5, 2),
+        (1024, 1024, 256, 64, 64): (3, 3, 2),
+        (1024, 1024, 256, 128, 128): (1, 2, 8),
+        (1024, 1024, 512, 16, 16): (2, 3, 1),
+        (1024, 1024, 512, 32, 32): (2, 4, 1),
+        (1024, 1024, 512, 64, 64): (2, 3, 4),
+        (1024, 1024, 512, 128, 128): (1, 2, 8),
+        (1024, 1024, 1024, 16, 16): (1, 3, 1),
+        (1024, 1024, 1024, 32, 32): (1, 3, 1),
+        (1024, 1024, 1024, 64, 64): (3, 3, 4),
+        (1024, 1024, 1024, 128, 128): (2, 2, 8),
+        (1024, 1024, 2048, 16, 16): (2, 3, 1),
+        (1024, 1024, 2048, 32, 32): (1, 4, 1),
+        (1024, 1024, 2048, 64, 64): (1, 3, 2),
+        (1024, 1024, 2048, 128, 128): (4, 2, 8),
+        (1024, 1024, 4096, 16, 16): (2, 3, 1),
+        (1024, 1024, 4096, 32, 32): (1, 4, 1),
+        (1024, 1024, 4096, 64, 64): (1, 3, 4),
+        (1024, 1024, 4096, 128, 128): (4, 2, 8),
+        (1024, 1024, 8192, 16, 16): (2, 3, 1),
+        (1024, 1024, 8192, 32, 32): (2, 3, 1),
+        (1024, 1024, 8192, 64, 64): (4, 3, 2),
+        (1024, 1024, 8192, 128, 128): (4, 1, 4),
+        (1024, 1024, 16384, 16, 16): (1, 2, 1),
+        (1024, 1024, 16384, 32, 32): (4, 3, 1),
+        (1024, 1024, 16384, 64, 64): (1, 3, 2),
+        (1024, 1024, 16384, 128, 128): (4, 1, 4),
+        (1024, 1024, 32768, 16, 16): (1, 2, 1),
+        (1024, 1024, 32768, 32, 32): (1, 3, 1),
+        (1024, 1024, 32768, 64, 64): (1, 3, 2),
+        (1024, 1024, 32768, 128, 128): (2, 1, 4),
+        (1024, 1024, 65536, 16, 16): (2, 2, 1),
+        (1024, 1024, 65536, 32, 32): (9, 3, 1),
+        (1024, 1024, 65536, 64, 64): (7, 3, 2),
+        (1024, 1024, 65536, 128, 128): (4, 1, 4),
+        (1024, 1024, 131072, 16, 16): (1, 1, 4),
+        (1024, 1024, 131072, 32, 32): (1, 1, 1),
+        (1024, 1024, 131072, 64, 64): (2, 3, 2),
+        (1024, 1024, 131072, 128, 128): (4, 1, 4),
+        (2048, 2048, 256, 16, 16): (4, 5, 1),
+        (2048, 2048, 256, 32, 32): (1, 4, 1),
+        (2048, 2048, 256, 64, 64): (1, 3, 4),
+        (2048, 2048, 256, 128, 128): (3, 2, 8),
+        (2048, 2048, 512, 16, 16): (2, 4, 1),
+        (2048, 2048, 512, 32, 32): (1, 3, 1),
+        (2048, 2048, 512, 64, 64): (1, 3, 2),
+        (2048, 2048, 512, 128, 128): (1, 2, 8),
+        (2048, 2048, 1024, 16, 16): (1, 3, 1),
+        (2048, 2048, 1024, 32, 32): (1, 4, 1),
+        (2048, 2048, 1024, 64, 64): (1, 3, 4),
+        (2048, 2048, 1024, 128, 128): (6, 2, 8),
+        (2048, 2048, 2048, 16, 16): (2, 3, 1),
+        (2048, 2048, 2048, 32, 32): (1, 4, 1),
+        (2048, 2048, 2048, 64, 64): (1, 3, 2),
+        (2048, 2048, 2048, 128, 128): (6, 2, 8),
+        (2048, 2048, 4096, 16, 16): (4, 3, 1),
+        (2048, 2048, 4096, 32, 32): (2, 4, 2),
+        (2048, 2048, 4096, 64, 64): (1, 3, 2),
+        (2048, 2048, 4096, 128, 128): (2, 1, 4),
+        (2048, 2048, 8192, 16, 16): (8, 2, 1),
+        (2048, 2048, 8192, 32, 32): (4, 3, 2),
+        (2048, 2048, 8192, 64, 64): (4, 3, 2),
+        (2048, 2048, 8192, 128, 128): (2, 1, 4),
+        (2048, 2048, 16384, 16, 16): (4, 2, 1),
+        (2048, 2048, 16384, 32, 32): (4, 4, 1),
+        (2048, 2048, 16384, 64, 64): (4, 3, 2),
+        (2048, 2048, 16384, 128, 128): (2, 1, 4),
+        (2048, 2048, 32768, 16, 16): (2, 2, 1),
+        (2048, 2048, 32768, 32, 32): (11, 4, 1),
+        (2048, 2048, 32768, 64, 64): (1, 1, 1),
+        (2048, 2048, 32768, 128, 128): (2, 1, 4),
+        (2048, 2048, 65536, 16, 16): (8, 2, 1),
+        (2048, 2048, 65536, 32, 32): (9, 3, 1),
+        (2048, 2048, 65536, 64, 64): (1, 1, 1),
+        (2048, 2048, 65536, 128, 128): (2, 1, 4),
+        (2048, 2048, 131072, 16, 16): (1, 1, 4),
+        (2048, 2048, 131072, 32, 32): (1, 1, 1),
+        (2048, 2048, 131072, 64, 64): (2, 3, 2),
+        (2048, 2048, 131072, 128, 128): (2, 1, 4),
+        (4096, 4096, 256, 16, 16): (4, 4, 1),
+        (4096, 4096, 256, 32, 32): (1, 3, 1),
+        (4096, 4096, 256, 64, 64): (1, 3, 4),
+        (4096, 4096, 256, 128, 128): (1, 2, 8),
+        (4096, 4096, 512, 16, 16): (1, 4, 1),
+        (4096, 4096, 512, 32, 32): (1, 5, 2),
+        (4096, 4096, 512, 64, 64): (1, 3, 4),
+        (4096, 4096, 512, 128, 128): (2, 2, 8),
+        (4096, 4096, 1024, 16, 16): (1, 3, 1),
+        (4096, 4096, 1024, 32, 32): (1, 4, 2),
+        (4096, 4096, 1024, 64, 64): (1, 4, 4),
+        (4096, 4096, 1024, 128, 128): (2, 2, 8),
+        (4096, 4096, 2048, 16, 16): (1, 3, 1),
+        (4096, 4096, 2048, 32, 32): (3, 4, 2),
+        (4096, 4096, 2048, 64, 64): (1, 3, 2),
+        (4096, 4096, 2048, 128, 128): (2, 2, 8),
+        (4096, 4096, 4096, 16, 16): (2, 3, 1),
+        (4096, 4096, 4096, 32, 32): (2, 4, 2),
+        (4096, 4096, 4096, 64, 64): (1, 3, 2),
+        (4096, 4096, 4096, 128, 128): (4, 1, 4),
+        (4096, 4096, 8192, 16, 16): (4, 2, 1),
+        (4096, 4096, 8192, 32, 32): (2, 4, 2),
+        (4096, 4096, 8192, 64, 64): (4, 3, 2),
+        (4096, 4096, 8192, 128, 128): (4, 1, 4),
+        (4096, 4096, 16384, 16, 16): (1, 1, 1),
+        (4096, 4096, 16384, 32, 32): (4, 4, 1),
+        (4096, 4096, 16384, 64, 64): (4, 3, 2),
+        (4096, 4096, 16384, 128, 128): (4, 1, 4),
+        (4096, 4096, 32768, 16, 16): (1, 1, 1),
+        (4096, 4096, 32768, 32, 32): (3, 4, 1),
+        (4096, 4096, 32768, 64, 64): (3, 3, 2),
+        (4096, 4096, 32768, 128, 128): (4, 1, 4),
+        (4096, 4096, 65536, 16, 16): (1, 1, 1),
+        (4096, 4096, 65536, 32, 32): (3, 4, 1),
+        (4096, 4096, 65536, 64, 64): (2, 3, 2),
+        (4096, 4096, 65536, 128, 128): (4, 1, 4),
+        (4096, 4096, 131072, 16, 16): (1, 1, 4),
+        (4096, 4096, 131072, 32, 32): (1, 1, 1),
+        (4096, 4096, 131072, 64, 64): (3, 3, 2),
+        (4096, 4096, 131072, 128, 128): (4, 1, 4),
+        (8192, 8192, 256, 16, 16): (4, 4, 1),
+        (8192, 8192, 256, 32, 32): (2, 1, 1),
+        (8192, 8192, 256, 64, 64): (4, 3, 8),
+        (8192, 8192, 256, 128, 128): (1, 1, 4),
+        (8192, 8192, 512, 16, 16): (2, 2, 1),
+        (8192, 8192, 512, 32, 32): (4, 4, 2),
+        (8192, 8192, 512, 64, 64): (4, 4, 4),
+        (8192, 8192, 512, 128, 128): (4, 2, 8),
+        (8192, 8192, 1024, 16, 16): (4, 5, 1),
+        (8192, 8192, 1024, 32, 32): (4, 4, 2),
+        (8192, 8192, 1024, 64, 64): (4, 3, 4),
+        (8192, 8192, 1024, 128, 128): (4, 2, 8),
+        (8192, 8192, 2048, 16, 16): (4, 5, 1),
+        (8192, 8192, 2048, 32, 32): (4, 4, 2),
+        (8192, 8192, 2048, 64, 64): (1, 1, 1),
+        (8192, 8192, 2048, 128, 128): (4, 1, 4),
+        (8192, 8192, 4096, 16, 16): (4, 3, 1),
+        (8192, 8192, 4096, 32, 32): (4, 4, 2),
+        (8192, 8192, 4096, 64, 64): (4, 3, 2),
+        (8192, 8192, 4096, 128, 128): (4, 1, 4),
+        (8192, 8192, 8192, 16, 16): (4, 2, 1),
+        (8192, 8192, 8192, 32, 32): (4, 4, 1),
+        (8192, 8192, 8192, 64, 64): (4, 3, 2),
+        (8192, 8192, 8192, 128, 128): (4, 1, 4),
+        (8192, 8192, 16384, 16, 16): (4, 2, 1),
+        (8192, 8192, 16384, 32, 32): (4, 4, 1),
+        (8192, 8192, 16384, 64, 64): (4, 3, 2),
+        (8192, 8192, 16384, 128, 128): (4, 1, 4),
+        (8192, 8192, 32768, 16, 16): (4, 2, 1),
+        (8192, 8192, 32768, 32, 32): (4, 4, 1),
+        (8192, 8192, 32768, 64, 64): (4, 3, 2),
+        (8192, 8192, 32768, 128, 128): (4, 1, 4),
+        (8192, 8192, 65536, 16, 16): (4, 2, 1),
+        (8192, 8192, 65536, 32, 32): (4, 3, 1),
+        (8192, 8192, 65536, 64, 64): (4, 4, 2),
+        (8192, 8192, 65536, 128, 128): (4, 1, 4),
+        (8192, 8192, 131072, 16, 16): (4, 1, 4),
+        (8192, 8192, 131072, 32, 32): (4, 2, 1),
+        (8192, 8192, 131072, 64, 64): (4, 3, 2),
+        (8192, 8192, 131072, 128, 128): (4, 1, 4),
+        (16384, 16384, 256, 16, 16): (4, 8, 1),
+        (16384, 16384, 256, 32, 32): (4, 4, 2),
+        (16384, 16384, 256, 64, 64): (4, 4, 4),
+        (16384, 16384, 256, 128, 128): (6, 2, 8),
+        (16384, 16384, 512, 16, 16): (4, 7, 1),
+        (16384, 16384, 512, 32, 32): (4, 5, 2),
+        (16384, 16384, 512, 64, 64): (4, 3, 2),
+        (16384, 16384, 512, 128, 128): (4, 2, 8),
+        (16384, 16384, 1024, 16, 16): (4, 9, 1),
+        (16384, 16384, 1024, 32, 32): (4, 4, 1),
+        (16384, 16384, 1024, 64, 64): (4, 3, 2),
+        (16384, 16384, 1024, 128, 128): (4, 1, 4),
+        (16384, 16384, 2048, 16, 16): (4, 9, 1),
+        (16384, 16384, 2048, 32, 32): (4, 4, 1),
+        (16384, 16384, 2048, 64, 64): (4, 3, 2),
+        (16384, 16384, 2048, 128, 128): (4, 1, 4),
+        (16384, 16384, 4096, 16, 16): (4, 2, 1),
+        (16384, 16384, 4096, 32, 32): (4, 5, 1),
+        (16384, 16384, 4096, 64, 64): (4, 3, 2),
+        (16384, 16384, 4096, 128, 128): (4, 1, 4),
+        (16384, 16384, 8192, 16, 16): (4, 2, 1),
+        (16384, 16384, 8192, 32, 32): (4, 5, 1),
+        (16384, 16384, 8192, 64, 64): (4, 3, 2),
+        (16384, 16384, 8192, 128, 128): (4, 1, 4),
+        (16384, 16384, 16384, 16, 16): (4, 2, 1),
+        (16384, 16384, 16384, 32, 32): (4, 4, 1),
+        (16384, 16384, 16384, 64, 64): (4, 3, 2),
+        (16384, 16384, 16384, 128, 128): (4, 1, 4),
+        (16384, 16384, 32768, 16, 16): (4, 2, 1),
+        (16384, 16384, 32768, 32, 32): (4, 5, 1),
+        (16384, 16384, 32768, 64, 64): (4, 3, 2),
+        (16384, 16384, 32768, 128, 128): (4, 1, 4),
+        (16384, 16384, 65536, 16, 16): (4, 2, 1),
+        (16384, 16384, 65536, 32, 32): (4, 6, 1),
+        (16384, 16384, 65536, 64, 64): (4, 3, 2),
+        (16384, 16384, 65536, 128, 128): (4, 1, 4),
+        (16384, 16384, 131072, 16, 16): (4, 2, 2),
+        (16384, 16384, 131072, 32, 32): (4, 2, 1),
+        (16384, 16384, 131072, 64, 64): (4, 3, 2),
+        (16384, 16384, 131072, 128, 128): (4, 1, 4),
+    },
     ("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.3)): {
         (256, 256, 256, 16, 16): (5, 1, 2),
         (256, 256, 256, 32, 32): (4, 3, 4),
@@ -769,80 +1154,80 @@
         (16384, 16384, 131072, 128, 128): (4, 1, 4),
     },
     ("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): {
-        (256, 256, 256, 16, 16): (8, 1, 4),
-        (256, 256, 256, 32, 32): (4, 3, 4),
+        (256, 256, 256, 16, 16): (8, 6, 1),
+        (256, 256, 256, 32, 32): (4, 5, 2),
         (256, 256, 256, 64, 64): (3, 3, 4),
         (256, 256, 256, 128, 128): (4, 2, 8),
-        (256, 256, 512, 16, 16): (3, 1, 4),
-        (256, 256, 512, 32, 32): (4, 1, 4),
-        (256, 256, 512, 64, 64): (4, 3, 4),
-        (256, 256, 512, 128, 128): (4, 2, 4),
+        (256, 256, 512, 16, 16): (4, 1, 4),
+        (256, 256, 512, 32, 32): (4, 3, 4),
+        (256, 256, 512, 64, 64): (6, 3, 4),
+        (256, 256, 512, 128, 128): (2, 2, 8),
         (256, 256, 1024, 16, 16): (4, 1, 2),
-        (256, 256, 1024, 32, 32): (4, 3, 4),
+        (256, 256, 1024, 32, 32): (1, 3, 2),
         (256, 256, 1024, 64, 64): (4, 3, 4),
         (256, 256, 1024, 128, 128): (4, 2, 8),
-        (256, 256, 2048, 16, 16): (4, 1, 1),
-        (256, 256, 2048, 32, 32): (5, 1, 4),
-        (256, 256, 2048, 64, 64): (4, 3, 4),
+        (256, 256, 2048, 16, 16): (2, 3, 1),
+        (256, 256, 2048, 32, 32): (5, 4, 1),
+        (256, 256, 2048, 64, 64): (3, 3, 4),
         (256, 256, 2048, 128, 128): (4, 2, 8),
-        (256, 256, 4096, 16, 16): (4, 1, 1),
+        (256, 256, 4096, 16, 16): (4, 3, 1),
         (256, 256, 4096, 32, 32): (4, 1, 4),
         (256, 256, 4096, 64, 64): (4, 3, 4),
         (256, 256, 4096, 128, 128): (4, 2, 8),
         (256, 256, 8192, 16, 16): (5, 3, 1),
-        (256, 256, 8192, 32, 32): (4, 3, 2),
+        (256, 256, 8192, 32, 32): (4, 3, 1),
         (256, 256, 8192, 64, 64): (4, 2, 4),
-        (256, 256, 8192, 128, 128): (4, 1, 4),
-        (256, 256, 16384, 16, 16): (5, 2, 1),
-        (256, 256, 16384, 32, 32): (3, 2, 2),
-        (256, 256, 16384, 64, 64): (4, 1, 4),
+        (256, 256, 8192, 128, 128): (2, 1, 4),
+        (256, 256, 16384, 16, 16): (7, 3, 1),
+        (256, 256, 16384, 32, 32): (3, 3, 1),
+        (256, 256, 16384, 64, 64): (4, 2, 4),
         (256, 256, 16384, 128, 128): (4, 1, 4),
-        (256, 256, 32768, 16, 16): (5, 3, 1),
-        (256, 256, 32768, 32, 32): (4, 3, 2),
+        (256, 256, 32768, 16, 16): (7, 3, 1),
+        (256, 256, 32768, 32, 32): (1, 3, 1),
         (256, 256, 32768, 64, 64): (4, 2, 4),
         (256, 256, 32768, 128, 128): (4, 1, 4),
         (256, 256, 65536, 16, 16): (5, 3, 1),
         (256, 256, 65536, 32, 32): (4, 3, 1),
-        (256, 256, 65536, 64, 64): (5, 2, 4),
-        (256, 256, 65536, 128, 128): (4, 1, 4),
+        (256, 256, 65536, 64, 64): (6, 2, 4),
+        (256, 256, 65536, 128, 128): (1, 1, 4),
         (256, 256, 131072, 16, 16): (4, 1, 2),
         (256, 256, 131072, 32, 32): (4, 2, 2),
         (256, 256, 131072, 64, 64): (4, 2, 4),
         (256, 256, 131072, 128, 128): (4, 1, 4),
-        (512, 512, 256, 16, 16): (4, 1, 4),
-        (512, 512, 256, 32, 32): (4, 1, 4),
+        (512, 512, 256, 16, 16): (4, 5, 1),
+        (512, 512, 256, 32, 32): (4, 5, 2),
         (512, 512, 256, 64, 64): (4, 3, 4),
         (512, 512, 256, 128, 128): (4, 2, 8),
-        (512, 512, 512, 16, 16): (4, 1, 1),
-        (512, 512, 512, 32, 32): (4, 1, 4),
-        (512, 512, 512, 64, 64): (4, 3, 4),
+        (512, 512, 512, 16, 16): (2, 4, 1),
+        (512, 512, 512, 32, 32): (4, 4, 4),
+        (512, 512, 512, 64, 64): (4, 5, 4),
         (512, 512, 512, 128, 128): (4, 2, 8),
-        (512, 512, 1024, 16, 16): (5, 4, 1),
-        (512, 512, 1024, 32, 32): (4, 1, 4),
+        (512, 512, 1024, 16, 16): (1, 4, 1),
+        (512, 512, 1024, 32, 32): (2, 3, 1),
         (512, 512, 1024, 64, 64): (4, 4, 4),
         (512, 512, 1024, 128, 128): (4, 2, 8),
-        (512, 512, 2048, 16, 16): (4, 3, 1),
-        (512, 512, 2048, 32, 32): (4, 3, 2),
-        (512, 512, 2048, 64, 64): (3, 3, 4),
+        (512, 512, 2048, 16, 16): (5, 3, 1),
+        (512, 512, 2048, 32, 32): (2, 3, 2),
+        (512, 512, 2048, 64, 64): (1, 3, 2),
         (512, 512, 2048, 128, 128): (4, 2, 8),
         (512, 512, 4096, 16, 16): (4, 3, 1),
         (512, 512, 4096, 32, 32): (4, 4, 2),
-        (512, 512, 4096, 64, 64): (4, 3, 4),
-        (512, 512, 4096, 128, 128): (4, 2, 8),
-        (512, 512, 8192, 16, 16): (4, 3, 1),
-        (512, 512, 8192, 32, 32): (5, 3, 2),
-        (512, 512, 8192, 64, 64): (5, 2, 4),
-        (512, 512, 8192, 128, 128): (4, 1, 4),
+        (512, 512, 4096, 64, 64): (5, 3, 4),
+        (512, 512, 4096, 128, 128): (2, 2, 8),
+        (512, 512, 8192, 16, 16): (2, 3, 1),
+        (512, 512, 8192, 32, 32): (1, 3, 1),
+        (512, 512, 8192, 64, 64): (5, 3, 2),
+        (512, 512, 8192, 128, 128): (4, 1, 16),
         (512, 512, 16384, 16, 16): (4, 3, 1),
         (512, 512, 16384, 32, 32): (4, 3, 1),
         (512, 512, 16384, 64, 64): (4, 3, 4),
         (512, 512, 16384, 128, 128): (4, 1, 4),
-        (512, 512, 32768, 16, 16): (4, 2, 1),
+        (512, 512, 32768, 16, 16): (1, 2, 1),
         (512, 512, 32768, 32, 32): (5, 3, 1),
         (512, 512, 32768, 64, 64): (4, 3, 2),
         (512, 512, 32768, 128, 128): (5, 1, 4),
         (512, 512, 65536, 16, 16): (4, 3, 1),
-        (512, 512, 65536, 32, 32): (4, 3, 1),
+        (512, 512, 65536, 32, 32): (1, 3, 1),
         (512, 512, 65536, 64, 64): (4, 3, 2),
         (512, 512, 65536, 128, 128): (5, 1, 4),
         (512, 512, 131072, 16, 16): (4, 1, 4),
@@ -850,40 +1235,40 @@
         (512, 512, 131072, 64, 64): (4, 3, 2),
         (512, 512, 131072, 128, 128): (4, 1, 4),
         (1024, 1024, 256, 16, 16): (4, 4, 1),
-        (1024, 1024, 256, 32, 32): (4, 1, 4),
+        (1024, 1024, 256, 32, 32): (2, 4, 2),
         (1024, 1024, 256, 64, 64): (4, 4, 4),
         (1024, 1024, 256, 128, 128): (4, 2, 8),
-        (1024, 1024, 512, 16, 16): (4, 4, 1),
+        (1024, 1024, 512, 16, 16): (3, 3, 1),
         (1024, 1024, 512, 32, 32): (5, 4, 2),
         (1024, 1024, 512, 64, 64): (4, 3, 4),
         (1024, 1024, 512, 128, 128): (3, 2, 8),
-        (1024, 1024, 1024, 16, 16): (5, 3, 1),
-        (1024, 1024, 1024, 32, 32): (1, 3, 1),
+        (1024, 1024, 1024, 16, 16): (7, 3, 1),
+        (1024, 1024, 1024, 32, 32): (2, 3, 1),
         (1024, 1024, 1024, 64, 64): (1, 3, 2),
         (1024, 1024, 1024, 128, 128): (1, 2, 8),
-        (1024, 1024, 2048, 16, 16): (4, 3, 1),
-        (1024, 1024, 2048, 32, 32): (4, 4, 1),
+        (1024, 1024, 2048, 16, 16): (2, 3, 1),
+        (1024, 1024, 2048, 32, 32): (1, 4, 1),
         (1024, 1024, 2048, 64, 64): (3, 3, 4),
         (1024, 1024, 2048, 128, 128): (4, 2, 8),
         (1024, 1024, 4096, 16, 16): (4, 3, 1),
-        (1024, 1024, 4096, 32, 32): (4, 3, 1),
+        (1024, 1024, 4096, 32, 32): (4, 4, 1),
         (1024, 1024, 4096, 64, 64): (4, 3, 4),
         (1024, 1024, 4096, 128, 128): (4, 2, 8),
-        (1024, 1024, 8192, 16, 16): (4, 3, 1),
+        (1024, 1024, 8192, 16, 16): (2, 3, 1),
         (1024, 1024, 8192, 32, 32): (4, 3, 1),
-        (1024, 1024, 8192, 64, 64): (4, 3, 4),
+        (1024, 1024, 8192, 64, 64): (4, 3, 2),
         (1024, 1024, 8192, 128, 128): (4, 1, 4),
         (1024, 1024, 16384, 16, 16): (4, 2, 1),
         (1024, 1024, 16384, 32, 32): (4, 3, 1),
         (1024, 1024, 16384, 64, 64): (4, 3, 2),
         (1024, 1024, 16384, 128, 128): (4, 1, 4),
         (1024, 1024, 32768, 16, 16): (4, 2, 1),
-        (1024, 1024, 32768, 32, 32): (5, 3, 2),
-        (1024, 1024, 32768, 64, 64): (5, 3, 2),
+        (1024, 1024, 32768, 32, 32): (8, 3, 1),
+        (1024, 1024, 32768, 64, 64): (8, 3, 2),
         (1024, 1024, 32768, 128, 128): (4, 1, 4),
-        (1024, 1024, 65536, 16, 16): (4, 2, 1),
-        (1024, 1024, 65536, 32, 32): (5, 3, 1),
-        (1024, 1024, 65536, 64, 64): (5, 3, 2),
+        (1024, 1024, 65536, 16, 16): (4, 4, 1),
+        (1024, 1024, 65536, 32, 32): (7, 3, 1),
+        (1024, 1024, 65536, 64, 64): (7, 3, 2),
         (1024, 1024, 65536, 128, 128): (4, 1, 4),
         (1024, 1024, 131072, 16, 16): (4, 1, 4),
         (1024, 1024, 131072, 32, 32): (4, 2, 1),
@@ -891,66 +1276,66 @@
         (1024, 1024, 131072, 128, 128): (4, 1, 4),
         (2048, 2048, 256, 16, 16): (5, 4, 1),
         (2048, 2048, 256, 32, 32): (4, 4, 2),
-        (2048, 2048, 256, 64, 64): (5, 3, 4),
+        (2048, 2048, 256, 64, 64): (2, 3, 4),
         (2048, 2048, 256, 128, 128): (4, 2, 8),
         (2048, 2048, 512, 16, 16): (4, 4, 1),
         (2048, 2048, 512, 32, 32): (5, 3, 2),
-        (2048, 2048, 512, 64, 64): (6, 3, 4),
+        (2048, 2048, 512, 64, 64): (8, 3, 4),
         (2048, 2048, 512, 128, 128): (4, 2, 8),
         (2048, 2048, 1024, 16, 16): (4, 3, 1),
-        (2048, 2048, 1024, 32, 32): (4, 3, 4),
+        (2048, 2048, 1024, 32, 32): (2, 4, 1),
         (2048, 2048, 1024, 64, 64): (5, 3, 4),
-        (2048, 2048, 1024, 128, 128): (4, 2, 8),
+        (2048, 2048, 1024, 128, 128): (6, 2, 8),
         (2048, 2048, 2048, 16, 16): (3, 3, 1),
-        (2048, 2048, 2048, 32, 32): (1, 4, 1),
+        (2048, 2048, 2048, 32, 32): (2, 4, 1),
         (2048, 2048, 2048, 64, 64): (3, 3, 2),
         (2048, 2048, 2048, 128, 128): (4, 2, 8),
         (2048, 2048, 4096, 16, 16): (4, 3, 1),
         (2048, 2048, 4096, 32, 32): (4, 4, 2),
-        (2048, 2048, 4096, 64, 64): (4, 3, 2),
-        (2048, 2048, 4096, 128, 128): (4, 1, 4),
-        (2048, 2048, 8192, 16, 16): (4, 2, 1),
-        (2048, 2048, 8192, 32, 32): (4, 4, 2),
-        (2048, 2048, 8192, 64, 64): (4, 3, 2),
-        (2048, 2048, 8192, 128, 128): (4, 1, 4),
+        (2048, 2048, 4096, 64, 64): (6, 3, 2),
+        (2048, 2048, 4096, 128, 128): (2, 1, 4),
+        (2048, 2048, 8192, 16, 16): (6, 2, 1),
+        (2048, 2048, 8192, 32, 32): (6, 4, 2),
+        (2048, 2048, 8192, 64, 64): (6, 3, 2),
+        (2048, 2048, 8192, 128, 128): (2, 1, 4),
         (2048, 2048, 16384, 16, 16): (4, 2, 1),
-        (2048, 2048, 16384, 32, 32): (4, 4, 2),
+        (2048, 2048, 16384, 32, 32): (4, 4, 1),
         (2048, 2048, 16384, 64, 64): (4, 3, 2),
-        (2048, 2048, 16384, 128, 128): (4, 1, 4),
-        (2048, 2048, 32768, 16, 16): (6, 2, 1),
-        (2048, 2048, 32768, 32, 32): (5, 4, 1),
-        (2048, 2048, 32768, 64, 64): (5, 3, 2),
+        (2048, 2048, 16384, 128, 128): (2, 1, 4),
+        (2048, 2048, 32768, 16, 16): (8, 2, 1),
+        (2048, 2048, 32768, 32, 32): (7, 4, 1),
+        (2048, 2048, 32768, 64, 64): (9, 3, 2),
         (2048, 2048, 32768, 128, 128): (4, 1, 4),
         (2048, 2048, 65536, 16, 16): (3, 2, 1),
-        (2048, 2048, 65536, 32, 32): (5, 4, 1),
+        (2048, 2048, 65536, 32, 32): (9, 3, 1),
         (2048, 2048, 65536, 64, 64): (4, 3, 2),
-        (2048, 2048, 65536, 128, 128): (4, 1, 4),
+        (2048, 2048, 65536, 128, 128): (2, 1, 4),
         (2048, 2048, 131072, 16, 16): (4, 1, 4),
         (2048, 2048, 131072, 32, 32): (4, 1, 1),
-        (2048, 2048, 131072, 64, 64): (5, 3, 2),
+        (2048, 2048, 131072, 64, 64): (4, 3, 2),
         (2048, 2048, 131072, 128, 128): (4, 1, 4),
         (4096, 4096, 256, 16, 16): (4, 4, 1),
-        (4096, 4096, 256, 32, 32): (4, 3, 2),
+        (4096, 4096, 256, 32, 32): (1, 3, 2),
         (4096, 4096, 256, 64, 64): (3, 3, 4),
         (4096, 4096, 256, 128, 128): (3, 2, 8),
         (4096, 4096, 512, 16, 16): (1, 3, 1),
-        (4096, 4096, 512, 32, 32): (4, 3, 4),
+        (4096, 4096, 512, 32, 32): (1, 3, 4),
         (4096, 4096, 512, 64, 64): (6, 3, 4),
-        (4096, 4096, 512, 128, 128): (4, 2, 8),
-        (4096, 4096, 1024, 16, 16): (3, 3, 1),
+        (4096, 4096, 512, 128, 128): (2, 2, 8),
+        (4096, 4096, 1024, 16, 16): (1, 3, 1),
         (4096, 4096, 1024, 32, 32): (4, 4, 2),
         (4096, 4096, 1024, 64, 64): (4, 4, 4),
-        (4096, 4096, 1024, 128, 128): (4, 2, 8),
+        (4096, 4096, 1024, 128, 128): (2, 2, 8),
         (4096, 4096, 2048, 16, 16): (1, 3, 1),
         (4096, 4096, 2048, 32, 32): (3, 4, 2),
-        (4096, 4096, 2048, 64, 64): (4, 3, 4),
+        (4096, 4096, 2048, 64, 64): (4, 3, 2),
         (4096, 4096, 2048, 128, 128): (4, 1, 4),
         (4096, 4096, 4096, 16, 16): (2, 3, 1),
-        (4096, 4096, 4096, 32, 32): (4, 4, 2),
+        (4096, 4096, 4096, 32, 32): (2, 4, 2),
         (4096, 4096, 4096, 64, 64): (1, 3, 2),
         (4096, 4096, 4096, 128, 128): (4, 1, 4),
-        (4096, 4096, 8192, 16, 16): (4, 2, 1),
-        (4096, 4096, 8192, 32, 32): (4, 3, 2),
+        (4096, 4096, 8192, 16, 16): (8, 2, 1),
+        (4096, 4096, 8192, 32, 32): (2, 4, 2),
         (4096, 4096, 8192, 64, 64): (4, 3, 2),
         (4096, 4096, 8192, 128, 128): (4, 1, 4),
         (4096, 4096, 16384, 16, 16): (1, 1, 1),
@@ -959,18 +1344,18 @@
         (4096, 4096, 16384, 128, 128): (4, 1, 4),
         (4096, 4096, 32768, 16, 16): (4, 2, 1),
         (4096, 4096, 32768, 32, 32): (5, 3, 1),
-        (4096, 4096, 32768, 64, 64): (5, 3, 2),
+        (4096, 4096, 32768, 64, 64): (3, 3, 2),
         (4096, 4096, 32768, 128, 128): (4, 1, 4),
         (4096, 4096, 65536, 16, 16): (4, 2, 1),
-        (4096, 4096, 65536, 32, 32): (2, 3, 1),
-        (4096, 4096, 65536, 64, 64): (2, 3, 2),
+        (4096, 4096, 65536, 32, 32): (2, 4, 1),
+        (4096, 4096, 65536, 64, 64): (3, 3, 2),
         (4096, 4096, 65536, 128, 128): (4, 1, 4),
         (4096, 4096, 131072, 16, 16): (1, 1, 4),
         (4096, 4096, 131072, 32, 32): (4, 2, 1),
-        (4096, 4096, 131072, 64, 64): (5, 3, 2),
+        (4096, 4096, 131072, 64, 64): (7, 3, 2),
         (4096, 4096, 131072, 128, 128): (4, 1, 4),
         (8192, 8192, 256, 16, 16): (4, 4, 1),
-        (8192, 8192, 256, 32, 32): (4, 3, 4),
+        (8192, 8192, 256, 32, 32): (4, 5, 2),
         (8192, 8192, 256, 64, 64): (4, 3, 4),
         (8192, 8192, 256, 128, 128): (5, 1, 4),
         (8192, 8192, 512, 16, 16): (4, 5, 1),
@@ -981,9 +1366,9 @@
         (8192, 8192, 1024, 32, 32): (4, 4, 2),
         (8192, 8192, 1024, 64, 64): (4, 4, 4),
         (8192, 8192, 1024, 128, 128): (4, 1, 4),
-        (8192, 8192, 2048, 16, 16): (4, 3, 1),
-        (8192, 8192, 2048, 32, 32): (4, 4, 1),
-        (8192, 8192, 2048, 64, 64): (4, 3, 4),
+        (8192, 8192, 2048, 16, 16): (4, 5, 1),
+        (8192, 8192, 2048, 32, 32): (4, 4, 2),
+        (8192, 8192, 2048, 64, 64): (4, 3, 2),
         (8192, 8192, 2048, 128, 128): (4, 1, 4),
         (8192, 8192, 4096, 16, 16): (4, 3, 1),
         (8192, 8192, 4096, 32, 32): (4, 4, 1),
@@ -999,29 +1384,29 @@
         (8192, 8192, 16384, 128, 128): (4, 1, 4),
         (8192, 8192, 32768, 16, 16): (4, 2, 1),
         (8192, 8192, 32768, 32, 32): (4, 4, 1),
-        (8192, 8192, 32768, 64, 64): (4, 4, 2),
+        (8192, 8192, 32768, 64, 64): (4, 3, 2),
         (8192, 8192, 32768, 128, 128): (4, 1, 4),
         (8192, 8192, 65536, 16, 16): (4, 2, 1),
-        (8192, 8192, 65536, 32, 32): (4, 3, 1),
-        (8192, 8192, 65536, 64, 64): (4, 3, 2),
+        (8192, 8192, 65536, 32, 32): (4, 4, 1),
+        (8192, 8192, 65536, 64, 64): (4, 4, 2),
         (8192, 8192, 65536, 128, 128): (4, 1, 4),
         (8192, 8192, 131072, 16, 16): (4, 1, 4),
         (8192, 8192, 131072, 32, 32): (4, 2, 1),
         (8192, 8192, 131072, 64, 64): (4, 3, 2),
         (8192, 8192, 131072, 128, 128): (4, 1, 4),
-        (16384, 16384, 256, 16, 16): (4, 4, 1),
+        (16384, 16384, 256, 16, 16): (4, 7, 1),
         (16384, 16384, 256, 32, 32): (4, 4, 2),
         (16384, 16384, 256, 64, 64): (4, 4, 4),
-        (16384, 16384, 256, 128, 128): (4, 2, 8),
-        (16384, 16384, 512, 16, 16): (4, 3, 1),
+        (16384, 16384, 256, 128, 128): (6, 2, 8),
+        (16384, 16384, 512, 16, 16): (4, 7, 1),
         (16384, 16384, 512, 32, 32): (4, 5, 2),
-        (16384, 16384, 512, 64, 64): (4, 4, 4),
+        (16384, 16384, 512, 64, 64): (4, 3, 2),
         (16384, 16384, 512, 128, 128): (4, 2, 8),
-        (16384, 16384, 1024, 16, 16): (4, 3, 1),
+        (16384, 16384, 1024, 16, 16): (4, 9, 1),
         (16384, 16384, 1024, 32, 32): (4, 4, 1),
         (16384, 16384, 1024, 64, 64): (4, 4, 4),
         (16384, 16384, 1024, 128, 128): (4, 1, 4),
-        (16384, 16384, 2048, 16, 16): (4, 3, 1),
+        (16384, 16384, 2048, 16, 16): (4, 9, 1),
         (16384, 16384, 2048, 32, 32): (4, 4, 1),
         (16384, 16384, 2048, 64, 64): (4, 4, 4),
         (16384, 16384, 2048, 128, 128): (4, 1, 4),
@@ -1038,11 +1423,11 @@
         (16384, 16384, 16384, 64, 64): (4, 3, 2),
         (16384, 16384, 16384, 128, 128): (4, 1, 4),
         (16384, 16384, 32768, 16, 16): (4, 2, 1),
-        (16384, 16384, 32768, 32, 32): (4, 6, 1),
+        (16384, 16384, 32768, 32, 32): (4, 5, 1),
         (16384, 16384, 32768, 64, 64): (4, 3, 2),
         (16384, 16384, 32768, 128, 128): (4, 1, 4),
         (16384, 16384, 65536, 16, 16): (4, 2, 1),
-        (16384, 16384, 65536, 32, 32): (4, 2, 1),
+        (16384, 16384, 65536, 32, 32): (4, 5, 1),
         (16384, 16384, 65536, 64, 64): (4, 3, 2),
         (16384, 16384, 65536, 128, 128): (4, 1, 4),
         (16384, 16384, 131072, 16, 16): (4, 1, 4),
@@ -1332,12 +1717,576 @@
         (16384, 16384, 131072, 64, 64): (4, 3, 2),
         (16384, 16384, 131072, 128, 128): (4, 1, 4),
     },
+    ("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): {
+        (256, 256, 256, 16, 16): (1, 1, 8),
+        (256, 256, 256, 32, 32): (1, 3, 4),
+        (256, 256, 256, 64, 64): (1, 1, 8),
+        (256, 256, 256, 128, 128): (1, 1, 16),
+        (256, 256, 512, 16, 16): (2, 1, 4),
+        (256, 256, 512, 32, 32): (3, 1, 4),
+        (256, 256, 512, 64, 64): (2, 1, 8),
+        (256, 256, 512, 128, 128): (1, 1, 32),
+        (256, 256, 1024, 16, 16): (2, 1, 1),
+        (256, 256, 1024, 32, 32): (1, 1, 4),
+        (256, 256, 1024, 64, 64): (4, 1, 8),
+        (256, 256, 1024, 128, 128): (1, 1, 32),
+        (256, 256, 2048, 16, 16): (1, 1, 1),
+        (256, 256, 2048, 32, 32): (1, 1, 4),
+        (256, 256, 2048, 64, 64): (2, 1, 8),
+        (256, 256, 2048, 128, 128): (1, 1, 32),
+        (256, 256, 4096, 16, 16): (1, 1, 1),
+        (256, 256, 4096, 32, 32): (1, 1, 4),
+        (256, 256, 4096, 64, 64): (4, 1, 8),
+        (256, 256, 4096, 128, 128): (1, 1, 32),
+        (256, 256, 8192, 16, 16): (1, 1, 1),
+        (256, 256, 8192, 32, 32): (1, 1, 4),
+        (256, 256, 8192, 64, 64): (1, 1, 4),
+        (256, 256, 8192, 128, 128): (1, 1, 32),
+        (256, 256, 16384, 16, 16): (2, 1, 1),
+        (256, 256, 16384, 32, 32): (2, 1, 4),
+        (256, 256, 16384, 64, 64): (1, 1, 4),
+        (256, 256, 16384, 128, 128): (1, 1, 32),
+        (256, 256, 32768, 16, 16): (1, 1, 1),
+        (256, 256, 32768, 32, 32): (1, 1, 4),
+        (256, 256, 32768, 64, 64): (1, 1, 4),
+        (256, 256, 32768, 128, 128): (1, 1, 32),
+        (256, 256, 65536, 16, 16): (4, 1, 1),
+        (256, 256, 65536, 32, 32): (1, 1, 4),
+        (256, 256, 65536, 64, 64): (1, 1, 4),
+        (256, 256, 65536, 128, 128): (1, 1, 32),
+        (256, 256, 131072, 16, 16): (2, 1, 1),
+        (256, 256, 131072, 32, 32): (2, 1, 4),
+        (256, 256, 131072, 64, 64): (1, 1, 4),
+        (256, 256, 131072, 128, 128): (1, 1, 32),
+        (512, 512, 256, 16, 16): (1, 1, 4),
+        (512, 512, 256, 32, 32): (1, 1, 4),
+        (512, 512, 256, 64, 64): (1, 1, 8),
+        (512, 512, 256, 128, 128): (1, 1, 32),
+        (512, 512, 512, 16, 16): (1, 1, 1),
+        (512, 512, 512, 32, 32): (1, 1, 4),
+        (512, 512, 512, 64, 64): (1, 1, 8),
+        (512, 512, 512, 128, 128): (1, 1, 32),
+        (512, 512, 1024, 16, 16): (1, 1, 1),
+        (512, 512, 1024, 32, 32): (3, 1, 4),
+        (512, 512, 1024, 64, 64): (4, 1, 8),
+        (512, 512, 1024, 128, 128): (1, 1, 32),
+        (512, 512, 2048, 16, 16): (2, 1, 1),
+        (512, 512, 2048, 32, 32): (3, 1, 4),
+        (512, 512, 2048, 64, 64): (5, 1, 8),
+        (512, 512, 2048, 128, 128): (1, 1, 32),
+        (512, 512, 4096, 16, 16): (1, 1, 1),
+        (512, 512, 4096, 32, 32): (2, 1, 4),
+        (512, 512, 4096, 64, 64): (1, 1, 8),
+        (512, 512, 4096, 128, 128): (1, 1, 32),
+        (512, 512, 8192, 16, 16): (1, 1, 1),
+        (512, 512, 8192, 32, 32): (2, 1, 4),
+        (512, 512, 8192, 64, 64): (1, 1, 4),
+        (512, 512, 8192, 128, 128): (2, 1, 32),
+        (512, 512, 16384, 16, 16): (1, 1, 1),
+        (512, 512, 16384, 32, 32): (2, 1, 2),
+        (512, 512, 16384, 64, 64): (1, 1, 4),
+        (512, 512, 16384, 128, 128): (1, 1, 32),
+        (512, 512, 32768, 16, 16): (4, 1, 1),
+        (512, 512, 32768, 32, 32): (4, 1, 2),
+        (512, 512, 32768, 64, 64): (1, 1, 4),
+        (512, 512, 32768, 128, 128): (1, 1, 32),
+        (512, 512, 65536, 16, 16): (4, 1, 1),
+        (512, 512, 65536, 32, 32): (4, 1, 2),
+        (512, 512, 65536, 64, 64): (1, 1, 4),
+        (512, 512, 65536, 128, 128): (1, 1, 32),
+        (512, 512, 131072, 16, 16): (4, 1, 1),
+        (512, 512, 131072, 32, 32): (4, 1, 2),
+        (512, 512, 131072, 64, 64): (1, 1, 4),
+        (512, 512, 131072, 128, 128): (1, 1, 32),
+        (1024, 1024, 256, 16, 16): (1, 1, 1),
+        (1024, 1024, 256, 32, 32): (1, 1, 4),
+        (1024, 1024, 256, 64, 64): (1, 1, 8),
+        (1024, 1024, 256, 128, 128): (1, 1, 32),
+        (1024, 1024, 512, 16, 16): (1, 1, 1),
+        (1024, 1024, 512, 32, 32): (1, 1, 4),
+        (1024, 1024, 512, 64, 64): (4, 1, 8),
+        (1024, 1024, 512, 128, 128): (1, 1, 32),
+        (1024, 1024, 1024, 16, 16): (1, 1, 1),
+        (1024, 1024, 1024, 32, 32): (1, 1, 4),
+        (1024, 1024, 1024, 64, 64): (3, 1, 8),
+        (1024, 1024, 1024, 128, 128): (1, 1, 32),
+        (1024, 1024, 2048, 16, 16): (1, 1, 1),
+        (1024, 1024, 2048, 32, 32): (3, 1, 4),
+        (1024, 1024, 2048, 64, 64): (1, 1, 8),
+        (1024, 1024, 2048, 128, 128): (4, 1, 32),
+        (1024, 1024, 4096, 16, 16): (1, 1, 1),
+        (1024, 1024, 4096, 32, 32): (4, 1, 2),
+        (1024, 1024, 4096, 64, 64): (1, 1, 8),
+        (1024, 1024, 4096, 128, 128): (4, 1, 32),
+        (1024, 1024, 8192, 16, 16): (1, 1, 1),
+        (1024, 1024, 8192, 32, 32): (4, 1, 2),
+        (1024, 1024, 8192, 64, 64): (1, 1, 4),
+        (1024, 1024, 8192, 128, 128): (1, 1, 32),
+        (1024, 1024, 16384, 16, 16): (1, 1, 1),
+        (1024, 1024, 16384, 32, 32): (2, 1, 2),
+        (1024, 1024, 16384, 64, 64): (1, 1, 4),
+        (1024, 1024, 16384, 128, 128): (1, 1, 32),
+        (1024, 1024, 32768, 16, 16): (4, 1, 1),
+        (1024, 1024, 32768, 32, 32): (4, 1, 1),
+        (1024, 1024, 32768, 64, 64): (1, 1, 4),
+        (1024, 1024, 32768, 128, 128): (1, 1, 32),
+        (1024, 1024, 65536, 16, 16): (4, 1, 1),
+        (1024, 1024, 65536, 32, 32): (4, 1, 1),
+        (1024, 1024, 65536, 64, 64): (1, 1, 4),
+        (1024, 1024, 65536, 128, 128): (1, 1, 32),
+        (1024, 1024, 131072, 16, 16): (4, 1, 1),
+        (1024, 1024, 131072, 32, 32): (4, 1, 1),
+        (1024, 1024, 131072, 64, 64): (1, 1, 4),
+        (1024, 1024, 131072, 128, 128): (1, 1, 32),
+        (2048, 2048, 256, 16, 16): (1, 1, 1),
+        (2048, 2048, 256, 32, 32): (3, 1, 4),
+        (2048, 2048, 256, 64, 64): (2, 1, 8),
+        (2048, 2048, 256, 128, 128): (1, 1, 32),
+        (2048, 2048, 512, 16, 16): (1, 1, 1),
+        (2048, 2048, 512, 32, 32): (1, 1, 4),
+        (2048, 2048, 512, 64, 64): (4, 1, 8),
+        (2048, 2048, 512, 128, 128): (1, 1, 32),
+        (2048, 2048, 1024, 16, 16): (1, 1, 1),
+        (2048, 2048, 1024, 32, 32): (1, 1, 4),
+        (2048, 2048, 1024, 64, 64): (2, 1, 8),
+        (2048, 2048, 1024, 128, 128): (4, 1, 32),
+        (2048, 2048, 2048, 16, 16): (1, 1, 1),
+        (2048, 2048, 2048, 32, 32): (4, 1, 2),
+        (2048, 2048, 2048, 64, 64): (2, 1, 4),
+        (2048, 2048, 2048, 128, 128): (4, 1, 32),
+        (2048, 2048, 4096, 16, 16): (1, 1, 1),
+        (2048, 2048, 4096, 32, 32): (4, 1, 1),
+        (2048, 2048, 4096, 64, 64): (2, 1, 4),
+        (2048, 2048, 4096, 128, 128): (1, 1, 32),
+        (2048, 2048, 8192, 16, 16): (1, 1, 1),
+        (2048, 2048, 8192, 32, 32): (4, 1, 1),
+        (2048, 2048, 8192, 64, 64): (2, 1, 4),
+        (2048, 2048, 8192, 128, 128): (1, 1, 32),
+        (2048, 2048, 16384, 16, 16): (2, 1, 1),
+        (2048, 2048, 16384, 32, 32): (4, 1, 1),
+        (2048, 2048, 16384, 64, 64): (2, 1, 4),
+        (2048, 2048, 16384, 128, 128): (1, 1, 32),
+        (2048, 2048, 32768, 16, 16): (2, 1, 1),
+        (2048, 2048, 32768, 32, 32): (4, 1, 1),
+        (2048, 2048, 32768, 64, 64): (3, 1, 4),
+        (2048, 2048, 32768, 128, 128): (1, 1, 32),
+        (2048, 2048, 65536, 16, 16): (2, 1, 1),
+        (2048, 2048, 65536, 32, 32): (4, 1, 1),
+        (2048, 2048, 65536, 64, 64): (11, 1, 4),
+        (2048, 2048, 65536, 128, 128): (4, 1, 32),
+        (2048, 2048, 131072, 16, 16): (4, 1, 1),
+        (2048, 2048, 131072, 32, 32): (4, 1, 1),
+        (2048, 2048, 131072, 64, 64): (3, 1, 4),
+        (2048, 2048, 131072, 128, 128): (1, 1, 32),
+        (4096, 4096, 256, 16, 16): (1, 1, 1),
+        (4096, 4096, 256, 32, 32): (1, 1, 4),
+        (4096, 4096, 256, 64, 64): (4, 1, 8),
+        (4096, 4096, 256, 128, 128): (1, 1, 32),
+        (4096, 4096, 512, 16, 16): (1, 1, 1),
+        (4096, 4096, 512, 32, 32): (1, 1, 4),
+        (4096, 4096, 512, 64, 64): (4, 1, 4),
+        (4096, 4096, 512, 128, 128): (4, 1, 32),
+        (4096, 4096, 1024, 16, 16): (1, 1, 1),
+        (4096, 4096, 1024, 32, 32): (1, 1, 2),
+        (4096, 4096, 1024, 64, 64): (4, 1, 4),
+        (4096, 4096, 1024, 128, 128): (4, 1, 32),
+        (4096, 4096, 2048, 16, 16): (1, 1, 1),
+        (4096, 4096, 2048, 32, 32): (1, 1, 2),
+        (4096, 4096, 2048, 64, 64): (1, 1, 4),
+        (4096, 4096, 2048, 128, 128): (4, 1, 32),
+        (4096, 4096, 4096, 16, 16): (2, 1, 1),
+        (4096, 4096, 4096, 32, 32): (4, 1, 2),
+        (4096, 4096, 4096, 64, 64): (1, 1, 4),
+        (4096, 4096, 4096, 128, 128): (4, 1, 32),
+        (4096, 4096, 8192, 16, 16): (2, 1, 1),
+        (4096, 4096, 8192, 32, 32): (4, 1, 2),
+        (4096, 4096, 8192, 64, 64): (1, 1, 4),
+        (4096, 4096, 8192, 128, 128): (1, 1, 32),
+        (4096, 4096, 16384, 16, 16): (2, 1, 1),
+        (4096, 4096, 16384, 32, 32): (4, 1, 2),
+        (4096, 4096, 16384, 64, 64): (1, 1, 4),
+        (4096, 4096, 16384, 128, 128): (1, 1, 32),
+        (4096, 4096, 32768, 16, 16): (4, 1, 1),
+        (4096, 4096, 32768, 32, 32): (4, 1, 2),
+        (4096, 4096, 32768, 64, 64): (1, 1, 4),
+        (4096, 4096, 32768, 128, 128): (1, 1, 32),
+        (4096, 4096, 65536, 16, 16): (4, 1, 1),
+        (4096, 4096, 65536, 32, 32): (4, 1, 2),
+        (4096, 4096, 65536, 64, 64): (1, 1, 4),
+        (4096, 4096, 65536, 128, 128): (1, 1, 32),
+        (4096, 4096, 131072, 16, 16): (4, 1, 1),
+        (4096, 4096, 131072, 32, 32): (4, 1, 2),
+        (4096, 4096, 131072, 64, 64): (1, 1, 4),
+        (4096, 4096, 131072, 128, 128): (1, 1, 32),
+        (8192, 8192, 256, 16, 16): (4, 1, 1),
+        (8192, 8192, 256, 32, 32): (4, 1, 4),
+        (8192, 8192, 256, 64, 64): (4, 1, 4),
+        (8192, 8192, 256, 128, 128): (1, 1, 32),
+        (8192, 8192, 512, 16, 16): (4, 1, 1),
+        (8192, 8192, 512, 32, 32): (4, 1, 2),
+        (8192, 8192, 512, 64, 64): (4, 1, 4),
+        (8192, 8192, 512, 128, 128): (4, 1, 32),
+        (8192, 8192, 1024, 16, 16): (4, 1, 1),
+        (8192, 8192, 1024, 32, 32): (4, 1, 2),
+        (8192, 8192, 1024, 64, 64): (1, 1, 4),
+        (8192, 8192, 1024, 128, 128): (4, 1, 32),
+        (8192, 8192, 2048, 16, 16): (4, 1, 1),
+        (8192, 8192, 2048, 32, 32): (4, 1, 2),
+        (8192, 8192, 2048, 64, 64): (1, 1, 4),
+        (8192, 8192, 2048, 128, 128): (4, 1, 32),
+        (8192, 8192, 4096, 16, 16): (4, 1, 1),
+        (8192, 8192, 4096, 32, 32): (4, 1, 2),
+        (8192, 8192, 4096, 64, 64): (1, 1, 4),
+        (8192, 8192, 4096, 128, 128): (4, 1, 32),
+        (8192, 8192, 8192, 16, 16): (4, 1, 1),
+        (8192, 8192, 8192, 32, 32): (4, 1, 2),
+        (8192, 8192, 8192, 64, 64): (1, 1, 4),
+        (8192, 8192, 8192, 128, 128): (4, 1, 32),
+        (8192, 8192, 16384, 16, 16): (4, 1, 1),
+        (8192, 8192, 16384, 32, 32): (4, 1, 2),
+        (8192, 8192, 16384, 64, 64): (1, 1, 4),
+        (8192, 8192, 16384, 128, 128): (1, 1, 32),
+        (8192, 8192, 32768, 16, 16): (4, 1, 1),
+        (8192, 8192, 32768, 32, 32): (4, 1, 2),
+        (8192, 8192, 32768, 64, 64): (1, 1, 4),
+        (8192, 8192, 32768, 128, 128): (1, 1, 32),
+        (8192, 8192, 65536, 16, 16): (4, 1, 1),
+        (8192, 8192, 65536, 32, 32): (4, 1, 2),
+        (8192, 8192, 65536, 64, 64): (1, 1, 4),
+        (8192, 8192, 65536, 128, 128): (1, 1, 32),
+        (8192, 8192, 131072, 16, 16): (4, 1, 1),
+        (8192, 8192, 131072, 32, 32): (4, 1, 2),
+        (8192, 8192, 131072, 64, 64): (1, 1, 4),
+        (8192, 8192, 131072, 128, 128): (1, 1, 32),
+        (16384, 16384, 256, 16, 16): (4, 1, 1),
+        (16384, 16384, 256, 32, 32): (4, 1, 4),
+        (16384, 16384, 256, 64, 64): (2, 1, 4),
+        (16384, 16384, 256, 128, 128): (1, 1, 32),
+        (16384, 16384, 512, 16, 16): (4, 1, 1),
+        (16384, 16384, 512, 32, 32): (4, 1, 2),
+        (16384, 16384, 512, 64, 64): (4, 1, 4),
+        (16384, 16384, 512, 128, 128): (4, 1, 32),
+        (16384, 16384, 1024, 16, 16): (4, 1, 1),
+        (16384, 16384, 1024, 32, 32): (4, 1, 2),
+        (16384, 16384, 1024, 64, 64): (1, 1, 4),
+        (16384, 16384, 1024, 128, 128): (4, 1, 32),
+        (16384, 16384, 2048, 16, 16): (4, 1, 1),
+        (16384, 16384, 2048, 32, 32): (4, 1, 2),
+        (16384, 16384, 2048, 64, 64): (1, 1, 4),
+        (16384, 16384, 2048, 128, 128): (4, 1, 32),
+        (16384, 16384, 4096, 16, 16): (4, 1, 1),
+        (16384, 16384, 4096, 32, 32): (4, 1, 1),
+        (16384, 16384, 4096, 64, 64): (1, 1, 4),
+        (16384, 16384, 4096, 128, 128): (1, 1, 32),
+        (16384, 16384, 8192, 16, 16): (4, 1, 1),
+        (16384, 16384, 8192, 32, 32): (4, 1, 1),
+        (16384, 16384, 8192, 64, 64): (1, 1, 4),
+        (16384, 16384, 8192, 128, 128): (1, 1, 32),
+        (16384, 16384, 16384, 16, 16): (4, 1, 1),
+        (16384, 16384, 16384, 32, 32): (4, 1, 1),
+        (16384, 16384, 16384, 64, 64): (1, 1, 4),
+        (16384, 16384, 16384, 128, 128): (1, 1, 32),
+        (16384, 16384, 32768, 16, 16): (4, 1, 1),
+        (16384, 16384, 32768, 32, 32): (4, 1, 1),
+        (16384, 16384, 32768, 64, 64): (1, 1, 4),
+        (16384, 16384, 32768, 128, 128): (1, 1, 32),
+        (16384, 16384, 65536, 16, 16): (4, 1, 1),
+        (16384, 16384, 65536, 32, 32): (4, 1, 1),
+        (16384, 16384, 65536, 64, 64): (1, 1, 4),
+        (16384, 16384, 65536, 128, 128): (1, 1, 32),
+        (16384, 16384, 131072, 16, 16): (2, 1, 1),
+        (16384, 16384, 131072, 32, 32): (4, 1, 1),
+        (16384, 16384, 131072, 64, 64): (1, 1, 4),
+        (16384, 16384, 131072, 128, 128): (1, 1, 32),
+    },
+    ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): {
+        (256, 256, 256, 16, 16): (1, 1, 16, 16, 1, 2),
+        (256, 256, 256, 32, 32): (1, 1, 16, 16, 1, 4),
+        (256, 256, 256, 64, 64): (1, 1, 16, 16, 1, 1),
+        (256, 256, 256, 128, 128): (2, 4, 16, 64, 1, 4),
+        (256, 256, 512, 16, 16): (1, 1, 16, 16, 1, 4),
+        (256, 256, 512, 32, 32): (1, 1, 16, 32, 1, 4),
+        (256, 256, 512, 64, 64): (1, 1, 16, 32, 1, 1),
+        (256, 256, 512, 128, 128): (1, 1, 32, 32, 1, 4),
+        (256, 256, 1024, 16, 16): (1, 1, 16, 16, 1, 4),
+        (256, 256, 1024, 32, 32): (1, 2, 16, 32, 1, 1),
+        (256, 256, 1024, 64, 64): (1, 1, 32, 32, 1, 2),
+        (256, 256, 1024, 128, 128): (1, 1, 32, 64, 1, 4),
+        (256, 256, 2048, 16, 16): (1, 1, 16, 64, 1, 8),
+        (256, 256, 2048, 32, 32): (2, 1, 32, 64, 1, 2),
+        (256, 256, 2048, 64, 64): (1, 1, 32, 32, 1, 1),
+        (256, 256, 2048, 128, 128): (1, 1, 64, 64, 1, 4),
+        (256, 256, 4096, 16, 16): (1, 1, 16, 64, 1, 1),
+        (256, 256, 4096, 32, 32): (2, 2, 32, 64, 1, 2),
+        (256, 256, 4096, 64, 64): (1, 1, 32, 128, 1, 4),
+        (256, 256, 4096, 128, 128): (1, 1, 64, 64, 1, 4),
+        (256, 256, 8192, 16, 16): (1, 2, 16, 64, 1, 2),
+        (256, 256, 8192, 32, 32): (1, 1, 32, 64, 1, 2),
+        (256, 256, 8192, 64, 64): (1, 1, 32, 64, 1, 2),
+        (256, 256, 8192, 128, 128): (1, 1, 64, 64, 1, 4),
+        (256, 256, 16384, 16, 16): (1, 1, 16, 64, 1, 2),
+        (256, 256, 16384, 32, 32): (1, 1, 32, 64, 1, 2),
+        (256, 256, 16384, 64, 64): (1, 1, 64, 64, 1, 2),
+        (256, 256, 16384, 128, 128): (2, 16, 64, 64, 1, 4),
+        (256, 256, 32768, 16, 16): (1, 1, 16, 128, 1, 2),
+        (256, 256, 32768, 32, 32): (1, 1, 32, 64, 1, 2),
+        (256, 256, 32768, 64, 64): (1, 1, 64, 64, 1, 2),
+        (256, 256, 32768, 128, 128): (2, 32, 64, 64, 1, 4),
+        (256, 256, 65536, 16, 16): (1, 1, 16, 64, 1, 1),
+        (256, 256, 65536, 32, 32): (1, 1, 32, 64, 1, 2),
+        (256, 256, 65536, 64, 64): (1, 1, 64, 32, 1, 1),
+        (256, 256, 65536, 128, 128): (2, 32, 64, 64, 1, 4),
+        (256, 256, 131072, 16, 16): (1, 1, 16, 64, 1, 1),
+        (256, 256, 131072, 32, 32): (1, 1, 32, 64, 1, 2),
+        (256, 256, 131072, 64, 64): (4, 1, 64, 32, 1, 1),
+        (256, 256, 131072, 128, 128): (2, 64, 64, 64, 1, 4),
+        (512, 512, 256, 16, 16): (1, 1, 16, 16, 1, 2),
+        (512, 512, 256, 32, 32): (1, 1, 16, 32, 1, 1),
+        (512, 512, 256, 64, 64): (1, 2, 16, 32, 1, 1),
+        (512, 512, 256, 128, 128): (2, 16, 64, 16, 2, 4),
+        (512, 512, 512, 16, 16): (1, 1, 16, 16, 1, 4),
+        (512, 512, 512, 32, 32): (1, 1, 16, 32, 1, 1),
+        (512, 512, 512, 64, 64): (1, 1, 32, 32, 1, 2),
+        (512, 512, 512, 128, 128): (2, 8, 32, 64, 1, 4),
+        (512, 512, 1024, 16, 16): (1, 1, 16, 64, 1, 8),
+        (512, 512, 1024, 32, 32): (1, 1, 32, 32, 3, 1),
+        (512, 512, 1024, 64, 64): (1, 4, 32, 64, 1, 2),
+        (512, 512, 1024, 128, 128): (1, 4, 64, 64, 1, 4),
+        (512, 512, 2048, 16, 16): (1, 1, 16, 64, 1, 2),
+        (512, 512, 2048, 32, 32): (1, 1, 32, 64, 1, 2),
+        (512, 512, 2048, 64, 64): (1, 1, 64, 64, 3, 4),
+        (512, 512, 2048, 128, 128): (1, 1, 64, 64, 1, 4),
+        (512, 512, 4096, 16, 16): (1, 1, 16, 64, 1, 2),
+        (512, 512, 4096, 32, 32): (2, 64, 32, 64, 1, 2),
+        (512, 512, 4096, 64, 64): (1, 1, 64, 64, 3, 4),
+        (512, 512, 4096, 128, 128): (1, 1, 64, 64, 1, 4),
+        (512, 512, 8192, 16, 16): (1, 2, 16, 128, 1, 2),
+        (512, 512, 8192, 32, 32): (1, 1, 32, 64, 1, 2),
+        (512, 512, 8192, 64, 64): (1, 1, 64, 64, 1, 2),
+        (512, 512, 8192, 128, 128): (1, 1, 64, 64, 1, 4),
+        (512, 512, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+        (512, 512, 16384, 32, 32): (1, 1, 32, 64, 1, 2),
+        (512, 512, 16384, 64, 64): (1, 1, 64, 64, 3, 2),
+        (512, 512, 16384, 128, 128): (2, 1, 64, 64, 1, 4),
+        (512, 512, 32768, 16, 16): (1, 2, 16, 128, 1, 2),
+        (512, 512, 32768, 32, 32): (1, 1, 32, 64, 1, 2),
+        (512, 512, 32768, 64, 64): (1, 1, 64, 64, 3, 4),
+        (512, 512, 32768, 128, 128): (2, 1, 64, 64, 1, 4),
+        (512, 512, 65536, 16, 16): (1, 2, 16, 128, 1, 2),
+        (512, 512, 65536, 32, 32): (1, 1, 32, 64, 1, 2),
+        (512, 512, 65536, 64, 64): (1, 1, 64, 64, 3, 4),
+        (512, 512, 65536, 128, 128): (2, 1, 64, 64, 1, 4),
+        (512, 512, 131072, 16, 16): (1, 1, 16, 64, 1, 1),
+        (512, 512, 131072, 32, 32): (1, 1, 32, 64, 1, 2),
+        (512, 512, 131072, 64, 64): (1, 1, 64, 64, 3, 4),
+        (512, 512, 131072, 128, 128): (2, 4, 64, 64, 1, 4),
+        (1024, 1024, 256, 16, 16): (1, 1, 16, 16, 1, 4),
+        (1024, 1024, 256, 32, 32): (2, 16, 32, 16, 3, 4),
+        (1024, 1024, 256, 64, 64): (1, 4, 32, 32, 1, 2),
+        (1024, 1024, 256, 128, 128): (1, 4, 128, 16, 3, 16),
+        (1024, 1024, 512, 16, 16): (1, 1, 16, 64, 1, 2),
+        (1024, 1024, 512, 32, 32): (2, 2, 32, 64, 1, 2),
+        (1024, 1024, 512, 64, 64): (2, 8, 64, 64, 3, 4),
+        (1024, 1024, 512, 128, 128): (1, 4, 64, 64, 1, 8),
+        (1024, 1024, 1024, 16, 16): (1, 1, 16, 64, 1, 2),
+        (1024, 1024, 1024, 32, 32): (1, 1, 32, 64, 1, 2),
+        (1024, 1024, 1024, 64, 64): (1, 8, 64, 64, 3, 4),
+        (1024, 1024, 1024, 128, 128): (1, 8, 64, 64, 1, 4),
+        (1024, 1024, 2048, 16, 16): (1, 2, 16, 64, 1, 2),
+        (1024, 1024, 2048, 32, 32): (1, 1, 32, 64, 1, 2),
+        (1024, 1024, 2048, 64, 64): (2, 16, 64, 64, 2, 2),
+        (1024, 1024, 2048, 128, 128): (2, 32, 64, 64, 1, 4),
+        (1024, 1024, 4096, 16, 16): (2, 16, 16, 128, 1, 2),
+        (1024, 1024, 4096, 32, 32): (1, 16, 32, 64, 3, 2),
+        (1024, 1024, 4096, 64, 64): (1, 1, 64, 64, 3, 4),
+        (1024, 1024, 4096, 128, 128): (2, 64, 128, 64, 1, 4),
+        (1024, 1024, 8192, 16, 16): (2, 16, 16, 128, 1, 2),
+        (1024, 1024, 8192, 32, 32): (1, 16, 32, 64, 3, 2),
+        (1024, 1024, 8192, 64, 64): (1, 1, 64, 64, 3, 4),
+        (1024, 1024, 8192, 128, 128): (2, 1, 64, 64, 1, 4),
+        (1024, 1024, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+        (1024, 1024, 16384, 32, 32): (1, 16, 32, 64, 3, 2),
+        (1024, 1024, 16384, 64, 64): (1, 1, 64, 64, 3, 4),
+        (1024, 1024, 16384, 128, 128): (2, 16, 128, 64, 1, 4),
+        (1024, 1024, 32768, 16, 16): (1, 1, 16, 128, 1, 2),
+        (1024, 1024, 32768, 32, 32): (1, 1, 32, 128, 1, 2),
+        (1024, 1024, 32768, 64, 64): (1, 32, 64, 32, 2, 1),
+        (1024, 1024, 32768, 128, 128): (2, 8, 128, 64, 1, 4),
+        (1024, 1024, 65536, 16, 16): (3, 2, 16, 128, 1, 2),
+        (1024, 1024, 65536, 32, 32): (1, 1, 32, 128, 1, 2),
+        (1024, 1024, 65536, 64, 64): (2, 4, 64, 32, 2, 1),
+        (1024, 1024, 65536, 128, 128): (2, 8, 128, 64, 1, 4),
+        (1024, 1024, 131072, 16, 16): (2, 1, 16, 128, 1, 2),
+        (1024, 1024, 131072, 32, 32): (1, 1, 32, 128, 1, 2),
+        (1024, 1024, 131072, 64, 64): (1, 4, 64, 32, 2, 1),
+        (1024, 1024, 131072, 128, 128): (4, 1, 128, 64, 1, 4),
+        (2048, 2048, 256, 16, 16): (1, 1, 16, 64, 1, 8),
+        (2048, 2048, 256, 32, 32): (1, 1, 32, 32, 3, 1),
+        (2048, 2048, 256, 64, 64): (1, 1, 32, 32, 2, 1),
+        (2048, 2048, 256, 128, 128): (1, 4, 64, 64, 1, 8),
+        (2048, 2048, 512, 16, 16): (1, 2, 16, 64, 1, 2),
+        (2048, 2048, 512, 32, 32): (1, 2, 32, 64, 1, 4),
+        (2048, 2048, 512, 64, 64): (1, 4, 64, 64, 1, 8),
+        (2048, 2048, 512, 128, 128): (1, 4, 64, 64, 1, 4),
+        (2048, 2048, 1024, 16, 16): (1, 2, 16, 128, 1, 2),
+        (2048, 2048, 1024, 32, 32): (1, 1, 32, 64, 1, 2),
+        (2048, 2048, 1024, 64, 64): (1, 8, 64, 64, 1, 4),
+        (2048, 2048, 1024, 128, 128): (1, 8, 128, 64, 1, 4),
+        (2048, 2048, 2048, 16, 16): (3, 4, 16, 128, 1, 2),
+        (2048, 2048, 2048, 32, 32): (1, 16, 32, 64, 5, 2),
+        (2048, 2048, 2048, 64, 64): (1, 1, 64, 64, 3, 4),
+        (2048, 2048, 2048, 128, 128): (1, 8, 128, 64, 1, 4),
+        (2048, 2048, 4096, 16, 16): (1, 2, 16, 128, 1, 2),
+        (2048, 2048, 4096, 32, 32): (1, 8, 32, 64, 3, 2),
+        (2048, 2048, 4096, 64, 64): (1, 1, 64, 64, 3, 4),
+        (2048, 2048, 4096, 128, 128): (1, 8, 128, 64, 1, 4),
+        (2048, 2048, 8192, 16, 16): (2, 4, 16, 128, 1, 2),
+        (2048, 2048, 8192, 32, 32): (1, 4, 32, 128, 3, 2),
+        (2048, 2048, 8192, 64, 64): (1, 8, 64, 64, 3, 2),
+        (2048, 2048, 8192, 128, 128): (1, 8, 128, 64, 1, 4),
+        (2048, 2048, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+        (2048, 2048, 16384, 32, 32): (1, 4, 32, 128, 3, 2),
+        (2048, 2048, 16384, 64, 64): (1, 8, 64, 64, 3, 2),
+        (2048, 2048, 16384, 128, 128): (1, 4, 128, 64, 1, 4),
+        (2048, 2048, 32768, 16, 16): (3, 2, 16, 128, 1, 2),
+        (2048, 2048, 32768, 32, 32): (1, 1, 32, 128, 3, 2),
+        (2048, 2048, 32768, 64, 64): (1, 1, 64, 64, 3, 2),
+        (2048, 2048, 32768, 128, 128): (1, 4, 128, 64, 1, 4),
+        (2048, 2048, 65536, 16, 16): (1, 2, 16, 128, 1, 2),
+        (2048, 2048, 65536, 32, 32): (1, 4, 32, 128, 1, 2),
+        (2048, 2048, 65536, 64, 64): (1, 1, 64, 64, 3, 2),
+        (2048, 2048, 65536, 128, 128): (1, 2, 128, 64, 1, 4),
+        (2048, 2048, 131072, 16, 16): (4, 2, 16, 128, 1, 2),
+        (2048, 2048, 131072, 32, 32): (1, 1, 32, 128, 3, 2),
+        (2048, 2048, 131072, 64, 64): (1, 1, 64, 64, 3, 2),
+        (2048, 2048, 131072, 128, 128): (1, 2, 128, 64, 1, 4),
+        (4096, 4096, 256, 16, 16): (1, 1, 16, 64, 1, 2),
+        (4096, 4096, 256, 32, 32): (1, 1, 32, 64, 3, 4),
+        (4096, 4096, 256, 64, 64): (1, 1, 64, 64, 3, 4),
+        (4096, 4096, 256, 128, 128): (3, 4, 128, 32, 1, 4),
+        (4096, 4096, 512, 16, 16): (1, 2, 16, 128, 1, 2),
+        (4096, 4096, 512, 32, 32): (1, 2, 32, 64, 3, 2),
+        (4096, 4096, 512, 64, 64): (1, 4, 64, 64, 1, 4),
+        (4096, 4096, 512, 128, 128): (1, 4, 128, 64, 1, 4),
+        (4096, 4096, 1024, 16, 16): (1, 2, 16, 128, 1, 2),
+        (4096, 4096, 1024, 32, 32): (1, 8, 32, 64, 3, 2),
+        (4096, 4096, 1024, 64, 64): (1, 4, 64, 64, 1, 4),
+        (4096, 4096, 1024, 128, 128): (2, 4, 128, 64, 1, 4),
+        (4096, 4096, 2048, 16, 16): (1, 1, 16, 128, 1, 2),
+        (4096, 4096, 2048, 32, 32): (1, 4, 32, 128, 1, 4),
+        (4096, 4096, 2048, 64, 64): (1, 1, 64, 64, 3, 4),
+        (4096, 4096, 2048, 128, 128): (1, 16, 128, 64, 1, 4),
+        (4096, 4096, 4096, 16, 16): (1, 1, 16, 64, 3, 1),
+        (4096, 4096, 4096, 32, 32): (1, 4, 32, 64, 3, 2),
+        (4096, 4096, 4096, 64, 64): (1, 1, 64, 64, 3, 4),
+        (4096, 4096, 4096, 128, 128): (5, 1, 128, 64, 1, 4),
+        (4096, 4096, 8192, 16, 16): (1, 1, 16, 128, 1, 2),
+        (4096, 4096, 8192, 32, 32): (1, 1, 32, 128, 3, 2),
+        (4096, 4096, 8192, 64, 64): (1, 1, 64, 64, 3, 4),
+        (4096, 4096, 8192, 128, 128): (2, 1, 128, 64, 1, 4),
+        (4096, 4096, 16384, 16, 16): (1, 1, 16, 128, 1, 2),
+        (4096, 4096, 16384, 32, 32): (1, 1, 32, 128, 3, 2),
+        (4096, 4096, 16384, 64, 64): (1, 1, 64, 64, 4, 4),
+        (4096, 4096, 16384, 128, 128): (2, 1, 128, 64, 1, 4),
+        (4096, 4096, 32768, 16, 16): (3, 1, 16, 128, 1, 2),
+        (4096, 4096, 32768, 32, 32): (1, 1, 32, 128, 3, 2),
+        (4096, 4096, 32768, 64, 64): (1, 1, 64, 64, 3, 4),
+        (4096, 4096, 32768, 128, 128): (2, 1, 128, 64, 1, 4),
+        (4096, 4096, 65536, 16, 16): (2, 2, 16, 128, 1, 2),
+        (4096, 4096, 65536, 32, 32): (1, 1, 32, 128, 4, 2),
+        (4096, 4096, 65536, 64, 64): (1, 1, 64, 64, 4, 4),
+        (4096, 4096, 65536, 128, 128): (2, 1, 128, 64, 1, 4),
+        (4096, 4096, 131072, 16, 16): (2, 1, 16, 128, 1, 2),
+        (4096, 4096, 131072, 32, 32): (1, 1, 32, 128, 3, 2),
+        (4096, 4096, 131072, 64, 64): (1, 1, 64, 64, 3, 4),
+        (4096, 4096, 131072, 128, 128): (2, 1, 128, 64, 1, 4),
+        (8192, 8192, 256, 16, 16): (1, 2, 16, 64, 1, 2),
+        (8192, 8192, 256, 32, 32): (1, 1, 32, 64, 1, 2),
+        (8192, 8192, 256, 64, 64): (1, 2, 64, 64, 1, 4),
+        (8192, 8192, 256, 128, 128): (3, 16, 128, 16, 1, 2),
+        (8192, 8192, 512, 16, 16): (1, 2, 16, 128, 1, 2),
+        (8192, 8192, 512, 32, 32): (1, 4, 32, 64, 3, 2),
+        (8192, 8192, 512, 64, 64): (2, 8, 64, 64, 4, 4),
+        (8192, 8192, 512, 128, 128): (1, 8, 128, 64, 1, 4),
+        (8192, 8192, 1024, 16, 16): (4, 2, 16, 128, 1, 2),
+        (8192, 8192, 1024, 32, 32): (1, 8, 32, 128, 1, 2),
+        (8192, 8192, 1024, 64, 64): (1, 16, 64, 64, 3, 2),
+        (8192, 8192, 1024, 128, 128): (2, 16, 128, 64, 2, 4),
+        (8192, 8192, 2048, 16, 16): (2, 1, 16, 64, 4, 1),
+        (8192, 8192, 2048, 32, 32): (1, 16, 32, 64, 5, 2),
+        (8192, 8192, 2048, 64, 64): (1, 16, 64, 64, 3, 2),
+        (8192, 8192, 2048, 128, 128): (2, 16, 128, 64, 2, 4),
+        (8192, 8192, 4096, 16, 16): (1, 1, 16, 64, 4, 1),
+        (8192, 8192, 4096, 32, 32): (1, 16, 32, 64, 5, 2),
+        (8192, 8192, 4096, 64, 64): (1, 16, 64, 64, 3, 2),
+        (8192, 8192, 4096, 128, 128): (2, 64, 128, 64, 2, 4),
+        (8192, 8192, 8192, 16, 16): (1, 1, 16, 64, 4, 1),
+        (8192, 8192, 8192, 32, 32): (1, 8, 32, 128, 5, 4),
+        (8192, 8192, 8192, 64, 64): (1, 8, 64, 64, 3, 2),
+        (8192, 8192, 8192, 128, 128): (2, 8, 128, 64, 1, 4),
+        (8192, 8192, 16384, 16, 16): (1, 1, 16, 64, 4, 1),
+        (8192, 8192, 16384, 32, 32): (1, 8, 32, 64, 5, 2),
+        (8192, 8192, 16384, 64, 64): (1, 8, 64, 64, 3, 2),
+        (8192, 8192, 16384, 128, 128): (1, 8, 128, 64, 1, 4),
+        (8192, 8192, 32768, 16, 16): (1, 1, 16, 64, 4, 1),
+        (8192, 8192, 32768, 32, 32): (1, 8, 32, 64, 5, 2),
+        (8192, 8192, 32768, 64, 64): (3, 8, 64, 64, 3, 2),
+        (8192, 8192, 32768, 128, 128): (2, 8, 128, 64, 1, 4),
+        (8192, 8192, 65536, 16, 16): (1, 1, 16, 64, 4, 1),
+        (8192, 8192, 65536, 32, 32): (5, 4, 32, 64, 3, 2),
+        (8192, 8192, 65536, 64, 64): (1, 8, 64, 64, 3, 2),
+        (8192, 8192, 65536, 128, 128): (2, 8, 128, 64, 1, 4),
+        (8192, 8192, 131072, 16, 16): (2, 1, 16, 64, 4, 1),
+        (8192, 8192, 131072, 32, 32): (1, 4, 32, 64, 5, 2),
+        (8192, 8192, 131072, 64, 64): (1, 4, 64, 128, 3, 4),
+        (8192, 8192, 131072, 128, 128): (2, 8, 128, 64, 1, 4),
+        (16384, 16384, 256, 16, 16): (1, 2, 16, 128, 1, 2),
+        (16384, 16384, 256, 32, 32): (1, 4, 32, 64, 3, 2),
+        (16384, 16384, 256, 64, 64): (2, 4, 64, 64, 4, 4),
+        (16384, 16384, 256, 128, 128): (1, 4, 128, 64, 1, 16),
+        (16384, 16384, 512, 16, 16): (1, 2, 16, 128, 3, 2),
+        (16384, 16384, 512, 32, 32): (1, 4, 32, 128, 5, 4),
+        (16384, 16384, 512, 64, 64): (1, 8, 64, 64, 3, 2),
+        (16384, 16384, 512, 128, 128): (2, 8, 128, 64, 1, 4),
+        (16384, 16384, 1024, 16, 16): (1, 2, 16, 128, 1, 2),
+        (16384, 16384, 1024, 32, 32): (1, 8, 32, 64, 5, 2),
+        (16384, 16384, 1024, 64, 64): (1, 16, 64, 64, 3, 2),
+        (16384, 16384, 1024, 128, 128): (5, 16, 128, 64, 2, 4),
+        (16384, 16384, 2048, 16, 16): (1, 2, 16, 128, 1, 2),
+        (16384, 16384, 2048, 32, 32): (1, 8, 32, 64, 5, 2),
+        (16384, 16384, 2048, 64, 64): (1, 16, 64, 64, 3, 2),
+        (16384, 16384, 2048, 128, 128): (4, 32, 128, 64, 2, 4),
+        (16384, 16384, 4096, 16, 16): (3, 2, 16, 128, 1, 2),
+        (16384, 16384, 4096, 32, 32): (1, 4, 32, 64, 5, 2),
+        (16384, 16384, 4096, 64, 64): (2, 16, 64, 64, 3, 2),
+        (16384, 16384, 4096, 128, 128): (3, 32, 128, 64, 2, 4),
+        (16384, 16384, 8192, 16, 16): (1, 2, 16, 128, 1, 2),
+        (16384, 16384, 8192, 32, 32): (1, 4, 32, 64, 5, 2),
+        (16384, 16384, 8192, 64, 64): (4, 8, 64, 64, 3, 2),
+        (16384, 16384, 8192, 128, 128): (5, 8, 128, 64, 1, 4),
+        (16384, 16384, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+        (16384, 16384, 16384, 32, 32): (1, 4, 32, 64, 5, 2),
+        (16384, 16384, 16384, 64, 64): (2, 4, 64, 128, 3, 4),
+        (16384, 16384, 16384, 128, 128): (4, 8, 128, 64, 1, 4),
+        (16384, 16384, 32768, 16, 16): (4, 2, 16, 128, 1, 2),
+        (16384, 16384, 32768, 32, 32): (1, 4, 32, 64, 5, 2),
+        (16384, 16384, 32768, 64, 64): (1, 8, 64, 64, 3, 2),
+        (16384, 16384, 32768, 128, 128): (2, 512, 128, 64, 2, 4),
+        (16384, 16384, 65536, 16, 16): (3, 2, 16, 128, 1, 2),
+        (16384, 16384, 65536, 32, 32): (1, 4, 32, 64, 5, 2),
+        (16384, 16384, 65536, 64, 64): (1, 4, 64, 128, 3, 4),
+        (16384, 16384, 65536, 128, 128): (2, 1024, 128, 64, 2, 4),
+        (16384, 16384, 131072, 16, 16): (1, 2, 16, 128, 1, 2),
+        (16384, 16384, 131072, 32, 32): (1, 4, 32, 64, 5, 2),
+        (16384, 16384, 131072, 64, 64): (3, 4, 64, 128, 3, 4),
+        (16384, 16384, 131072, 128, 128): (4, 2048, 128, 64, 2, 4),
+    },
     ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): {
         (256, 256, 256, 16, 16): (5, 4, 16, 16, 1, 4),
         (256, 256, 256, 32, 32): (5, 2, 32, 16, 1, 4),
         (256, 256, 256, 64, 64): (4, 1, 32, 32, 1, 8),
         (256, 256, 256, 128, 128): (2, 1, 32, 32, 1, 4),
-        (256, 256, 512, 16, 16): (4, 8, 16, 32, 1, 4),
+        (256, 256, 512, 16, 16): (2, 2, 16, 32, 1, 4),
         (256, 256, 512, 32, 32): (4, 8, 32, 32, 1, 8),
         (256, 256, 512, 64, 64): (4, 8, 32, 64, 1, 4),
         (256, 256, 512, 128, 128): (4, 8, 32, 64, 1, 4),
@@ -1345,12 +2294,12 @@
         (256, 256, 1024, 32, 32): (4, 16, 32, 64, 1, 2),
         (256, 256, 1024, 64, 64): (4, 16, 32, 64, 1, 4),
         (256, 256, 1024, 128, 128): (4, 16, 64, 64, 1, 8),
-        (256, 256, 2048, 16, 16): (4, 16, 16, 64, 1, 1),
+        (256, 256, 2048, 16, 16): (2, 16, 16, 64, 1, 8),
         (256, 256, 2048, 32, 32): (4, 16, 32, 64, 1, 2),
         (256, 256, 2048, 64, 64): (4, 16, 32, 64, 1, 4),
         (256, 256, 2048, 128, 128): (4, 16, 64, 64, 1, 4),
         (256, 256, 4096, 16, 16): (4, 32, 16, 64, 1, 1),
-        (256, 256, 4096, 32, 32): (4, 32, 32, 64, 1, 2),
+        (256, 256, 4096, 32, 32): (2, 64, 32, 64, 1, 2),
         (256, 256, 4096, 64, 64): (4, 64, 64, 64, 1, 4),
         (256, 256, 4096, 128, 128): (4, 32, 64, 64, 1, 4),
         (256, 256, 8192, 16, 16): (4, 64, 16, 64, 1, 1),
@@ -1358,33 +2307,33 @@
         (256, 256, 8192, 64, 64): (4, 64, 64, 64, 1, 4),
         (256, 256, 8192, 128, 128): (4, 64, 64, 64, 1, 4),
         (256, 256, 16384, 16, 16): (4, 128, 16, 64, 1, 1),
-        (256, 256, 16384, 32, 32): (4, 16, 32, 64, 1, 2),
+        (256, 256, 16384, 32, 32): (2, 128, 32, 64, 1, 2),
         (256, 256, 16384, 64, 64): (4, 32, 32, 128, 1, 4),
         (256, 256, 16384, 128, 128): (4, 16, 64, 64, 1, 4),
         (256, 256, 32768, 16, 16): (4, 64, 16, 64, 1, 1),
-        (256, 256, 32768, 32, 32): (4, 32, 32, 64, 1, 2),
+        (256, 256, 32768, 32, 32): (2, 256, 32, 64, 1, 2),
         (256, 256, 32768, 64, 64): (4, 32, 32, 128, 1, 4),
         (256, 256, 32768, 128, 128): (4, 32, 64, 64, 1, 4),
         (256, 256, 65536, 16, 16): (4, 128, 16, 64, 1, 1),
-        (256, 256, 65536, 32, 32): (4, 16, 32, 64, 1, 2),
-        (256, 256, 65536, 64, 64): (4, 16, 64, 64, 1, 2),
+        (256, 256, 65536, 32, 32): (4, 1, 32, 64, 1, 2),
+        (256, 256, 65536, 64, 64): (2, 1, 64, 64, 1, 2),
         (256, 256, 65536, 128, 128): (4, 32, 64, 64, 1, 4),
         (256, 256, 131072, 16, 16): (4, 64, 16, 64, 1, 1),
-        (256, 256, 131072, 32, 32): (4, 2, 32, 64, 1, 2),
+        (256, 256, 131072, 32, 32): (2, 1, 32, 64, 1, 2),
         (256, 256, 131072, 64, 64): (4, 32, 32, 128, 1, 4),
         (256, 256, 131072, 128, 128): (4, 32, 64, 64, 1, 4),
         (512, 512, 256, 16, 16): (4, 16, 16, 16, 1, 4),
-        (512, 512, 256, 32, 32): (4, 16, 32, 16, 1, 4),
-        (512, 512, 256, 64, 64): (4, 16, 64, 16, 1, 8),
+        (512, 512, 256, 32, 32): (2, 4, 32, 16, 1, 4),
+        (512, 512, 256, 64, 64): (2, 16, 64, 16, 3, 8),
         (512, 512, 256, 128, 128): (4, 16, 64, 16, 1, 4),
-        (512, 512, 512, 16, 16): (2, 1, 16, 64, 1, 2),
+        (512, 512, 512, 16, 16): (1, 1, 16, 64, 1, 8),
         (512, 512, 512, 32, 32): (2, 4, 16, 32, 1, 1),
         (512, 512, 512, 64, 64): (2, 1, 32, 32, 1, 2),
         (512, 512, 512, 128, 128): (4, 8, 32, 64, 1, 4),
-        (512, 512, 1024, 16, 16): (4, 8, 16, 64, 1, 1),
+        (512, 512, 1024, 16, 16): (2, 8, 16, 64, 1, 8),
         (512, 512, 1024, 32, 32): (4, 16, 32, 64, 1, 2),
         (512, 512, 1024, 64, 64): (4, 16, 64, 64, 1, 4),
-        (512, 512, 1024, 128, 128): (4, 16, 64, 64, 1, 4),
+        (512, 512, 1024, 128, 128): (2, 8, 64, 64, 1, 4),
         (512, 512, 2048, 16, 16): (4, 16, 16, 64, 1, 4),
         (512, 512, 2048, 32, 32): (4, 16, 32, 64, 1, 2),
         (512, 512, 2048, 64, 64): (4, 16, 64, 64, 1, 8),
@@ -1393,7 +2342,7 @@
         (512, 512, 4096, 32, 32): (4, 32, 32, 64, 1, 2),
         (512, 512, 4096, 64, 64): (4, 32, 64, 64, 1, 4),
         (512, 512, 4096, 128, 128): (4, 32, 64, 64, 1, 4),
-        (512, 512, 8192, 16, 16): (3, 16, 16, 128, 1, 2),
+        (512, 512, 8192, 16, 16): (2, 32, 16, 128, 1, 2),
         (512, 512, 8192, 32, 32): (4, 64, 32, 64, 1, 2),
         (512, 512, 8192, 64, 64): (4, 128, 64, 64, 1, 2),
         (512, 512, 8192, 128, 128): (4, 64, 64, 64, 1, 4),
@@ -1401,31 +2350,31 @@
         (512, 512, 16384, 32, 32): (4, 64, 32, 64, 1, 2),
         (512, 512, 16384, 64, 64): (4, 16, 64, 64, 1, 4),
         (512, 512, 16384, 128, 128): (4, 32, 64, 64, 1, 4),
-        (512, 512, 32768, 16, 16): (6, 16, 16, 128, 1, 2),
+        (512, 512, 32768, 16, 16): (7, 16, 16, 128, 1, 2),
         (512, 512, 32768, 32, 32): (4, 64, 32, 64, 1, 2),
-        (512, 512, 32768, 64, 64): (4, 32, 64, 64, 1, 2),
-        (512, 512, 32768, 128, 128): (4, 16, 64, 64, 1, 4),
-        (512, 512, 65536, 16, 16): (4, 32, 16, 64, 1, 1),
+        (512, 512, 32768, 64, 64): (2, 32, 64, 64, 3, 2),
+        (512, 512, 32768, 128, 128): (2, 32, 64, 64, 1, 4),
+        (512, 512, 65536, 16, 16): (2, 32, 16, 64, 1, 1),
         (512, 512, 65536, 32, 32): (4, 64, 32, 64, 1, 2),
-        (512, 512, 65536, 64, 64): (5, 32, 64, 64, 1, 2),
+        (512, 512, 65536, 64, 64): (3, 32, 64, 64, 3, 2),
         (512, 512, 65536, 128, 128): (4, 16, 64, 64, 1, 4),
         (512, 512, 131072, 16, 16): (3, 32, 16, 128, 1, 2),
         (512, 512, 131072, 32, 32): (4, 64, 32, 64, 1, 2),
-        (512, 512, 131072, 64, 64): (4, 32, 64, 64, 1, 2),
-        (512, 512, 131072, 128, 128): (4, 16, 64, 64, 1, 4),
+        (512, 512, 131072, 64, 64): (2, 32, 64, 64, 3, 2),
+        (512, 512, 131072, 128, 128): (3, 1, 64, 64, 1, 4),
         (1024, 1024, 256, 16, 16): (4, 16, 16, 16, 1, 4),
         (1024, 1024, 256, 32, 32): (4, 16, 32, 16, 1, 4),
         (1024, 1024, 256, 64, 64): (4, 4, 64, 32, 1, 16),
         (1024, 1024, 256, 128, 128): (4, 16, 64, 16, 1, 8),
-        (1024, 1024, 512, 16, 16): (4, 8, 16, 64, 1, 1),
-        (1024, 1024, 512, 32, 32): (5, 8, 32, 64, 1, 2),
+        (1024, 1024, 512, 16, 16): (2, 8, 16, 64, 1, 8),
+        (1024, 1024, 512, 32, 32): (3, 2, 32, 64, 1, 2),
         (1024, 1024, 512, 64, 64): (4, 8, 32, 64, 1, 8),
         (1024, 1024, 512, 128, 128): (4, 8, 64, 64, 1, 8),
         (1024, 1024, 1024, 16, 16): (2, 2, 16, 64, 1, 2),
         (1024, 1024, 1024, 32, 32): (2, 8, 32, 64, 1, 2),
         (1024, 1024, 1024, 64, 64): (2, 8, 32, 128, 1, 4),
         (1024, 1024, 1024, 128, 128): (2, 8, 64, 64, 1, 4),
-        (1024, 1024, 2048, 16, 16): (4, 16, 16, 128, 1, 2),
+        (1024, 1024, 2048, 16, 16): (2, 16, 16, 128, 3, 2),
         (1024, 1024, 2048, 32, 32): (4, 32, 32, 64, 1, 2),
         (1024, 1024, 2048, 64, 64): (4, 16, 64, 64, 1, 4),
         (1024, 1024, 2048, 128, 128): (4, 32, 64, 64, 1, 4),
@@ -1434,183 +2383,183 @@
         (1024, 1024, 4096, 64, 64): (4, 32, 64, 64, 1, 4),
         (1024, 1024, 4096, 128, 128): (4, 32, 64, 64, 1, 4),
         (1024, 1024, 8192, 16, 16): (5, 16, 16, 128, 1, 2),
-        (1024, 1024, 8192, 32, 32): (4, 32, 32, 64, 1, 2),
-        (1024, 1024, 8192, 64, 64): (3, 64, 64, 64, 3, 2),
+        (1024, 1024, 8192, 32, 32): (2, 32, 32, 64, 3, 2),
+        (1024, 1024, 8192, 64, 64): (1, 16, 64, 64, 3, 2),
         (1024, 1024, 8192, 128, 128): (4, 32, 64, 64, 1, 4),
         (1024, 1024, 16384, 16, 16): (4, 16, 16, 128, 1, 2),
-        (1024, 1024, 16384, 32, 32): (3, 32, 32, 64, 1, 2),
+        (1024, 1024, 16384, 32, 32): (1, 32, 32, 64, 3, 2),
         (1024, 1024, 16384, 64, 64): (4, 16, 64, 64, 3, 2),
         (1024, 1024, 16384, 128, 128): (4, 32, 128, 64, 1, 4),
-        (1024, 1024, 32768, 16, 16): (4, 16, 16, 128, 1, 2),
-        (1024, 1024, 32768, 32, 32): (3, 32, 32, 64, 1, 2),
+        (1024, 1024, 32768, 16, 16): (3, 16, 16, 128, 1, 2),
+        (1024, 1024, 32768, 32, 32): (1, 8, 32, 64, 3, 2),
         (1024, 1024, 32768, 64, 64): (4, 16, 64, 64, 3, 2),
         (1024, 1024, 32768, 128, 128): (4, 8, 128, 64, 2, 4),
-        (1024, 1024, 65536, 16, 16): (4, 8, 16, 128, 1, 2),
-        (1024, 1024, 65536, 32, 32): (4, 16, 32, 64, 1, 2),
-        (1024, 1024, 65536, 64, 64): (4, 16, 64, 64, 3, 2),
+        (1024, 1024, 65536, 16, 16): (1, 2, 16, 128, 1, 2),
+        (1024, 1024, 65536, 32, 32): (2, 4, 32, 64, 3, 2),
+        (1024, 1024, 65536, 64, 64): (5, 16, 64, 64, 3, 2),
         (1024, 1024, 65536, 128, 128): (5, 8, 128, 64, 2, 4),
-        (1024, 1024, 131072, 16, 16): (4, 8, 16, 128, 1, 2),
-        (1024, 1024, 131072, 32, 32): (4, 16, 32, 64, 1, 2),
+        (1024, 1024, 131072, 16, 16): (5, 2, 16, 128, 1, 2),
+        (1024, 1024, 131072, 32, 32): (1, 2, 32, 64, 3, 2),
         (1024, 1024, 131072, 64, 64): (5, 16, 64, 64, 3, 2),
-        (1024, 1024, 131072, 128, 128): (4, 8, 128, 64, 2, 4),
+        (1024, 1024, 131072, 128, 128): (2, 1, 128, 64, 2, 4),
         (2048, 2048, 256, 16, 16): (4, 4, 16, 64, 1, 8),
         (2048, 2048, 256, 32, 32): (4, 8, 32, 32, 1, 8),
         (2048, 2048, 256, 64, 64): (4, 16, 64, 16, 1, 8),
         (2048, 2048, 256, 128, 128): (4, 4, 128, 32, 3, 8),
-        (2048, 2048, 512, 16, 16): (4, 8, 16, 64, 1, 2),
-        (2048, 2048, 512, 32, 32): (4, 4, 32, 64, 1, 2),
+        (2048, 2048, 512, 16, 16): (2, 2, 16, 64, 1, 2),
+        (2048, 2048, 512, 32, 32): (2, 4, 32, 64, 3, 2),
         (2048, 2048, 512, 64, 64): (4, 4, 64, 64, 1, 8),
         (2048, 2048, 512, 128, 128): (4, 8, 64, 64, 1, 4),
-        (2048, 2048, 1024, 16, 16): (3, 8, 16, 64, 1, 2),
-        (2048, 2048, 1024, 32, 32): (4, 16, 32, 64, 1, 2),
+        (2048, 2048, 1024, 16, 16): (1, 8, 16, 64, 1, 2),
+        (2048, 2048, 1024, 32, 32): (2, 16, 32, 64, 3, 2),
         (2048, 2048, 1024, 64, 64): (4, 8, 64, 64, 1, 4),
         (2048, 2048, 1024, 128, 128): (4, 8, 128, 64, 1, 4),
-        (2048, 2048, 2048, 16, 16): (4, 4, 16, 128, 1, 2),
-        (2048, 2048, 2048, 32, 32): (2, 16, 32, 64, 1, 2),
+        (2048, 2048, 2048, 16, 16): (5, 4, 16, 128, 1, 2),
+        (2048, 2048, 2048, 32, 32): (1, 16, 32, 64, 3, 2),
         (2048, 2048, 2048, 64, 64): (2, 8, 64, 64, 1, 4),
         (2048, 2048, 2048, 128, 128): (2, 8, 128, 64, 1, 4),
         (2048, 2048, 4096, 16, 16): (4, 2, 16, 128, 1, 2),
-        (2048, 2048, 4096, 32, 32): (4, 16, 32, 64, 1, 2),
-        (2048, 2048, 4096, 64, 64): (4, 32, 64, 64, 3, 2),
+        (2048, 2048, 4096, 32, 32): (2, 16, 32, 64, 3, 2),
+        (2048, 2048, 4096, 64, 64): (2, 8, 64, 64, 3, 2),
         (2048, 2048, 4096, 128, 128): (4, 8, 128, 64, 1, 4),
         (2048, 2048, 8192, 16, 16): (5, 4, 16, 128, 1, 2),
-        (2048, 2048, 8192, 32, 32): (4, 64, 32, 64, 1, 2),
+        (2048, 2048, 8192, 32, 32): (2, 8, 32, 64, 3, 2),
         (2048, 2048, 8192, 64, 64): (4, 8, 64, 64, 3, 2),
         (2048, 2048, 8192, 128, 128): (4, 8, 128, 64, 1, 4),
         (2048, 2048, 16384, 16, 16): (3, 2, 16, 128, 1, 2),
-        (2048, 2048, 16384, 32, 32): (4, 8, 32, 64, 1, 2),
+        (2048, 2048, 16384, 32, 32): (2, 4, 32, 128, 3, 2),
         (2048, 2048, 16384, 64, 64): (4, 8, 64, 64, 3, 2),
         (2048, 2048, 16384, 128, 128): (4, 4, 128, 64, 1, 4),
-        (2048, 2048, 32768, 16, 16): (6, 2, 16, 128, 1, 2),
-        (2048, 2048, 32768, 32, 32): (5, 8, 32, 64, 1, 2),
+        (2048, 2048, 32768, 16, 16): (3, 2, 16, 128, 1, 2),
+        (2048, 2048, 32768, 32, 32): (3, 4, 32, 128, 3, 2),
         (2048, 2048, 32768, 64, 64): (6, 4, 64, 64, 3, 2),
         (2048, 2048, 32768, 128, 128): (3, 4, 128, 64, 1, 4),
-        (2048, 2048, 65536, 16, 16): (7, 2, 16, 128, 1, 2),
-        (2048, 2048, 65536, 32, 32): (3, 1, 32, 128, 1, 2),
+        (2048, 2048, 65536, 16, 16): (6, 2, 16, 128, 1, 2),
+        (2048, 2048, 65536, 32, 32): (1, 2, 32, 128, 1, 2),
         (2048, 2048, 65536, 64, 64): (5, 4, 64, 64, 3, 2),
         (2048, 2048, 65536, 128, 128): (5, 1, 128, 64, 2, 4),
         (2048, 2048, 131072, 16, 16): (3, 2, 16, 128, 1, 2),
-        (2048, 2048, 131072, 32, 32): (4, 2, 32, 128, 1, 4),
+        (2048, 2048, 131072, 32, 32): (2, 1, 32, 128, 3, 2),
         (2048, 2048, 131072, 64, 64): (4, 1, 64, 64, 3, 2),
         (2048, 2048, 131072, 128, 128): (3, 1, 128, 64, 2, 4),
         (4096, 4096, 256, 16, 16): (5, 8, 16, 32, 1, 4),
         (4096, 4096, 256, 32, 32): (4, 16, 32, 16, 2, 4),
-        (4096, 4096, 256, 64, 64): (4, 8, 64, 32, 1, 4),
+        (4096, 4096, 256, 64, 64): (2, 1, 64, 64, 3, 4),
         (4096, 4096, 256, 128, 128): (4, 4, 128, 32, 1, 4),
         (4096, 4096, 512, 16, 16): (4, 2, 16, 128, 1, 2),
         (4096, 4096, 512, 32, 32): (4, 8, 32, 64, 1, 2),
         (4096, 4096, 512, 64, 64): (4, 4, 64, 64, 1, 4),
         (4096, 4096, 512, 128, 128): (4, 8, 128, 64, 2, 4),
-        (4096, 4096, 1024, 16, 16): (4, 8, 16, 128, 1, 2),
-        (4096, 4096, 1024, 32, 32): (4, 8, 32, 64, 1, 2),
-        (4096, 4096, 1024, 64, 64): (4, 16, 64, 64, 1, 4),
-        (4096, 4096, 1024, 128, 128): (4, 16, 128, 64, 2, 4),
-        (4096, 4096, 2048, 16, 16): (5, 8, 16, 128, 1, 2),
-        (4096, 4096, 2048, 32, 32): (3, 4, 32, 64, 1, 2),
+        (4096, 4096, 1024, 16, 16): (1, 2, 16, 128, 1, 2),
+        (4096, 4096, 1024, 32, 32): (6, 8, 32, 64, 3, 2),
+        (4096, 4096, 1024, 64, 64): (2, 16, 64, 64, 4, 4),
+        (4096, 4096, 1024, 128, 128): (2, 4, 128, 64, 2, 4),
+        (4096, 4096, 2048, 16, 16): (3, 1, 16, 128, 1, 2),
+        (4096, 4096, 2048, 32, 32): (1, 4, 32, 64, 5, 2),
         (4096, 4096, 2048, 64, 64): (3, 16, 64, 64, 3, 2),
         (4096, 4096, 2048, 128, 128): (4, 32, 128, 64, 2, 4),
         (4096, 4096, 4096, 16, 16): (1, 2, 16, 128, 1, 2),
-        (4096, 4096, 4096, 32, 32): (3, 4, 32, 64, 3, 2),
+        (4096, 4096, 4096, 32, 32): (1, 4, 32, 64, 3, 2),
         (4096, 4096, 4096, 64, 64): (1, 1, 64, 64, 4, 4),
-        (4096, 4096, 4096, 128, 128): (1, 1, 128, 128, 1, 8),
-        (4096, 4096, 8192, 16, 16): (5, 8, 16, 128, 1, 2),
-        (4096, 4096, 8192, 32, 32): (4, 4, 32, 64, 1, 2),
+        (4096, 4096, 4096, 128, 128): (2, 1, 128, 128, 1, 8),
+        (4096, 4096, 8192, 16, 16): (3, 1, 16, 128, 1, 2),
+        (4096, 4096, 8192, 32, 32): (2, 2, 32, 64, 5, 2),
         (4096, 4096, 8192, 64, 64): (4, 16, 64, 64, 3, 2),
         (4096, 4096, 8192, 128, 128): (4, 16, 128, 64, 2, 4),
-        (4096, 4096, 16384, 16, 16): (4, 8, 16, 128, 1, 2),
-        (4096, 4096, 16384, 32, 32): (6, 2, 32, 64, 1, 2),
+        (4096, 4096, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+        (4096, 4096, 16384, 32, 32): (4, 2, 32, 64, 5, 2),
         (4096, 4096, 16384, 64, 64): (4, 16, 64, 64, 3, 2),
         (4096, 4096, 16384, 128, 128): (4, 16, 128, 64, 2, 4),
-        (4096, 4096, 32768, 16, 16): (2, 8, 16, 128, 1, 2),
+        (4096, 4096, 32768, 16, 16): (3, 1, 16, 128, 1, 2),
         (4096, 4096, 32768, 32, 32): (3, 1, 32, 128, 1, 4),
-        (4096, 4096, 32768, 64, 64): (5, 8, 64, 64, 3, 2),
+        (4096, 4096, 32768, 64, 64): (3, 1, 64, 64, 3, 4),
         (4096, 4096, 32768, 128, 128): (5, 16, 128, 64, 2, 4),
-        (4096, 4096, 65536, 16, 16): (6, 8, 16, 128, 1, 2),
+        (4096, 4096, 65536, 16, 16): (5, 1, 16, 128, 1, 2),
         (4096, 4096, 65536, 32, 32): (5, 1, 32, 128, 1, 4),
-        (4096, 4096, 65536, 64, 64): (3, 8, 64, 64, 3, 2),
+        (4096, 4096, 65536, 64, 64): (1, 1, 64, 64, 3, 4),
         (4096, 4096, 65536, 128, 128): (3, 16, 128, 64, 2, 4),
-        (4096, 4096, 131072, 16, 16): (5, 8, 16, 128, 1, 2),
-        (4096, 4096, 131072, 32, 32): (5, 4, 32, 64, 1, 2),
-        (4096, 4096, 131072, 64, 64): (5, 8, 64, 64, 3, 2),
-        (4096, 4096, 131072, 128, 128): (4, 16, 128, 64, 2, 4),
+        (4096, 4096, 131072, 16, 16): (5, 1, 16, 128, 1, 2),
+        (4096, 4096, 131072, 32, 32): (3, 1, 32, 128, 3, 2),
+        (4096, 4096, 131072, 64, 64): (2, 1, 64, 64, 3, 4),
+        (4096, 4096, 131072, 128, 128): (1, 1, 128, 64, 1, 4),
         (8192, 8192, 256, 16, 16): (4, 16, 16, 16, 1, 4),
-        (8192, 8192, 256, 32, 32): (4, 16, 32, 16, 4, 4),
+        (8192, 8192, 256, 32, 32): (1, 16, 32, 16, 4, 4),
         (8192, 8192, 256, 64, 64): (4, 16, 64, 16, 3, 8),
         (8192, 8192, 256, 128, 128): (4, 16, 128, 16, 1, 2),
-        (8192, 8192, 512, 16, 16): (5, 8, 16, 64, 1, 4),
-        (8192, 8192, 512, 32, 32): (4, 4, 32, 64, 1, 2),
-        (8192, 8192, 512, 64, 64): (4, 4, 64, 64, 1, 4),
+        (8192, 8192, 512, 16, 16): (2, 8, 16, 64, 1, 4),
+        (8192, 8192, 512, 32, 32): (4, 8, 32, 64, 3, 2),
+        (8192, 8192, 512, 64, 64): (2, 8, 64, 64, 4, 4),
         (8192, 8192, 512, 128, 128): (4, 8, 128, 64, 2, 4),
         (8192, 8192, 1024, 16, 16): (4, 16, 16, 64, 1, 8),
-        (8192, 8192, 1024, 32, 32): (4, 4, 32, 64, 1, 2),
-        (8192, 8192, 1024, 64, 64): (4, 16, 64, 64, 3, 2),
-        (8192, 8192, 1024, 128, 128): (4, 16, 128, 64, 2, 4),
-        (8192, 8192, 2048, 16, 16): (5, 2, 16, 128, 1, 2),
-        (8192, 8192, 2048, 32, 32): (4, 16, 32, 64, 1, 2),
+        (8192, 8192, 1024, 32, 32): (2, 8, 32, 64, 5, 2),
+        (8192, 8192, 1024, 64, 64): (1, 16, 64, 64, 3, 2),
+        (8192, 8192, 1024, 128, 128): (5, 16, 128, 64, 2, 4),
+        (8192, 8192, 2048, 16, 16): (7, 2, 16, 128, 1, 2),
+        (8192, 8192, 2048, 32, 32): (1, 16, 32, 64, 5, 2),
         (8192, 8192, 2048, 64, 64): (4, 16, 64, 64, 3, 2),
         (8192, 8192, 2048, 128, 128): (6, 16, 128, 64, 2, 4),
         (8192, 8192, 4096, 16, 16): (4, 2, 16, 128, 1, 2),
-        (8192, 8192, 4096, 32, 32): (4, 4, 32, 64, 1, 2),
+        (8192, 8192, 4096, 32, 32): (2, 8, 32, 64, 5, 2),
         (8192, 8192, 4096, 64, 64): (3, 16, 64, 64, 3, 2),
         (8192, 8192, 4096, 128, 128): (3, 64, 128, 64, 2, 4),
-        (8192, 8192, 8192, 16, 16): (3, 2, 16, 128, 1, 2),
-        (8192, 8192, 8192, 32, 32): (2, 4, 32, 128, 1, 4),
+        (8192, 8192, 8192, 16, 16): (4, 2, 16, 128, 1, 2),
+        (8192, 8192, 8192, 32, 32): (1, 4, 32, 128, 5, 4),
         (8192, 8192, 8192, 64, 64): (4, 4, 64, 64, 1, 4),
         (8192, 8192, 8192, 128, 128): (2, 2, 128, 128, 3, 8),
-        (8192, 8192, 16384, 16, 16): (4, 8, 16, 128, 1, 2),
-        (8192, 8192, 16384, 32, 32): (3, 4, 32, 64, 1, 2),
+        (8192, 8192, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+        (8192, 8192, 16384, 32, 32): (4, 8, 32, 64, 5, 2),
         (8192, 8192, 16384, 64, 64): (5, 8, 64, 64, 3, 2),
         (8192, 8192, 16384, 128, 128): (3, 16, 128, 64, 2, 4),
-        (8192, 8192, 32768, 16, 16): (3, 2, 16, 128, 1, 2),
-        (8192, 8192, 32768, 32, 32): (4, 4, 32, 64, 1, 2),
+        (8192, 8192, 32768, 16, 16): (7, 2, 16, 128, 1, 2),
+        (8192, 8192, 32768, 32, 32): (3, 4, 32, 64, 3, 2),
         (8192, 8192, 32768, 64, 64): (2, 8, 64, 64, 3, 2),
         (8192, 8192, 32768, 128, 128): (6, 16, 128, 64, 2, 4),
         (8192, 8192, 65536, 16, 16): (9, 2, 16, 128, 1, 2),
-        (8192, 8192, 65536, 32, 32): (6, 4, 32, 64, 1, 2),
+        (8192, 8192, 65536, 32, 32): (7, 4, 32, 64, 5, 2),
         (8192, 8192, 65536, 64, 64): (4, 8, 64, 64, 3, 2),
         (8192, 8192, 65536, 128, 128): (3, 16, 128, 64, 2, 4),
-        (8192, 8192, 131072, 16, 16): (7, 2, 16, 128, 1, 2),
-        (8192, 8192, 131072, 32, 32): (3, 8, 32, 64, 1, 2),
+        (8192, 8192, 131072, 16, 16): (9, 2, 16, 128, 1, 2),
+        (8192, 8192, 131072, 32, 32): (1, 8, 32, 64, 5, 2),
         (8192, 8192, 131072, 64, 64): (1, 8, 64, 64, 3, 2),
         (8192, 8192, 131072, 128, 128): (4, 16, 128, 64, 2, 4),
         (16384, 16384, 256, 16, 16): (5, 16, 16, 16, 1, 4),
         (16384, 16384, 256, 32, 32): (4, 16, 32, 16, 4, 4),
         (16384, 16384, 256, 64, 64): (4, 16, 64, 16, 3, 8),
         (16384, 16384, 256, 128, 128): (4, 16, 128, 16, 1, 2),
-        (16384, 16384, 512, 16, 16): (4, 4, 16, 64, 1, 4),
-        (16384, 16384, 512, 32, 32): (4, 4, 32, 64, 1, 2),
+        (16384, 16384, 512, 16, 16): (2, 8, 16, 64, 1, 4),
+        (16384, 16384, 512, 32, 32): (1, 4, 32, 64, 5, 2),
         (16384, 16384, 512, 64, 64): (4, 8, 64, 64, 1, 4),
         (16384, 16384, 512, 128, 128): (3, 8, 128, 64, 2, 4),
         (16384, 16384, 1024, 16, 16): (4, 2, 16, 128, 1, 2),
-        (16384, 16384, 1024, 32, 32): (4, 8, 32, 64, 1, 2),
+        (16384, 16384, 1024, 32, 32): (4, 8, 32, 64, 5, 2),
         (16384, 16384, 1024, 64, 64): (6, 16, 64, 64, 3, 2),
         (16384, 16384, 1024, 128, 128): (3, 16, 128, 64, 2, 4),
         (16384, 16384, 2048, 16, 16): (3, 2, 16, 128, 1, 2),
-        (16384, 16384, 2048, 32, 32): (5, 8, 32, 64, 1, 2),
+        (16384, 16384, 2048, 32, 32): (1, 8, 32, 64, 5, 2),
         (16384, 16384, 2048, 64, 64): (5, 16, 64, 64, 3, 2),
-        (16384, 16384, 2048, 128, 128): (3, 32, 128, 64, 2, 4),
-        (16384, 16384, 4096, 16, 16): (3, 2, 16, 128, 1, 2),
-        (16384, 16384, 4096, 32, 32): (5, 4, 32, 64, 1, 2),
-        (16384, 16384, 4096, 64, 64): (4, 16, 64, 64, 3, 2),
+        (16384, 16384, 2048, 128, 128): (2, 32, 128, 64, 2, 4),
+        (16384, 16384, 4096, 16, 16): (2, 2, 16, 128, 1, 2),
+        (16384, 16384, 4096, 32, 32): (1, 4, 32, 64, 3, 2),
+        (16384, 16384, 4096, 64, 64): (2, 8, 64, 64, 3, 2),
         (16384, 16384, 4096, 128, 128): (3, 16, 128, 64, 2, 4),
-        (16384, 16384, 8192, 16, 16): (4, 2, 16, 128, 1, 2),
-        (16384, 16384, 8192, 32, 32): (4, 4, 32, 64, 1, 2),
+        (16384, 16384, 8192, 16, 16): (3, 2, 16, 128, 1, 2),
+        (16384, 16384, 8192, 32, 32): (2, 4, 32, 64, 5, 2),
         (16384, 16384, 8192, 64, 64): (4, 8, 64, 64, 3, 2),
-        (16384, 16384, 8192, 128, 128): (6, 32, 128, 64, 2, 4),
+        (16384, 16384, 8192, 128, 128): (8, 32, 128, 64, 2, 4),
         (16384, 16384, 16384, 16, 16): (1, 2, 16, 256, 1, 4),
-        (16384, 16384, 16384, 32, 32): (2, 4, 32, 128, 1, 4),
+        (16384, 16384, 16384, 32, 32): (1, 4, 32, 128, 3, 4),
         (16384, 16384, 16384, 64, 64): (5, 4, 64, 64, 1, 4),
         (16384, 16384, 16384, 128, 128): (4, 8, 128, 64, 2, 4),
         (16384, 16384, 32768, 16, 16): (2, 2, 16, 128, 1, 2),
-        (16384, 16384, 32768, 32, 32): (2, 4, 32, 64, 1, 2),
+        (16384, 16384, 32768, 32, 32): (1, 4, 32, 64, 3, 2),
         (16384, 16384, 32768, 64, 64): (5, 4, 64, 64, 1, 4),
         (16384, 16384, 32768, 128, 128): (5, 8, 128, 64, 2, 4),
-        (16384, 16384, 65536, 16, 16): (5, 2, 16, 128, 1, 2),
-        (16384, 16384, 65536, 32, 32): (4, 2, 32, 64, 1, 2),
+        (16384, 16384, 65536, 16, 16): (8, 2, 16, 128, 1, 2),
+        (16384, 16384, 65536, 32, 32): (6, 4, 32, 64, 5, 2),
         (16384, 16384, 65536, 64, 64): (2, 4, 64, 64, 1, 4),
         (16384, 16384, 65536, 128, 128): (4, 8, 128, 64, 2, 4),
-        (16384, 16384, 131072, 16, 16): (3, 2, 16, 128, 1, 2),
-        (16384, 16384, 131072, 32, 32): (3, 4, 32, 64, 1, 2),
+        (16384, 16384, 131072, 16, 16): (3, 1, 16, 128, 1, 2),
+        (16384, 16384, 131072, 32, 32): (1, 4, 32, 64, 3, 2),
         (16384, 16384, 131072, 64, 64): (4, 4, 64, 64, 1, 4),
         (16384, 16384, 131072, 128, 128): (1, 8, 128, 64, 2, 4),
         (32768, 32768, 256, 16, 16): (4, 16, 16, 16, 1, 4),
@@ -1622,9 +2571,292 @@
         (32768, 32768, 16384, 16, 16): (4, 4, 16, 64, 1, 1),
         (32768, 32768, 32768, 16, 16): (5, 4, 16, 64, 1, 1),
     },
+    ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): {
+        (256, 256, 256, 16, 16): (1, 1, 16, 16, 1, 8),
+        (256, 256, 256, 32, 32): (1, 1, 16, 16, 1, 4),
+        (256, 256, 256, 64, 64): (1, 1, 16, 16, 1, 4),
+        (256, 256, 256, 128, 128): (1, 1, 16, 16, 1, 1),
+        (256, 256, 512, 16, 16): (1, 1, 16, 16, 1, 4),
+        (256, 256, 512, 32, 32): (1, 16, 16, 16, 1, 1),
+        (256, 256, 512, 64, 64): (1, 1, 16, 16, 1, 1),
+        (256, 256, 512, 128, 128): (1, 1, 32, 32, 1, 4),
+        (256, 256, 1024, 16, 16): (1, 1, 16, 32, 1, 2),
+        (256, 256, 1024, 32, 32): (1, 4, 16, 16, 1, 1),
+        (256, 256, 1024, 64, 64): (1, 1, 32, 32, 1, 4),
+        (256, 256, 1024, 128, 128): (1, 1, 32, 32, 1, 4),
+        (256, 256, 2048, 16, 16): (1, 2, 16, 32, 1, 2),
+        (256, 256, 2048, 32, 32): (1, 1, 16, 32, 1, 2),
+        (256, 256, 2048, 64, 64): (2, 1, 16, 32, 1, 2),
+        (256, 256, 2048, 128, 128): (1, 1, 16, 16, 1, 1),
+        (256, 256, 4096, 16, 16): (1, 1, 16, 32, 1, 2),
+        (256, 256, 4096, 32, 32): (1, 1, 16, 32, 1, 2),
+        (256, 256, 4096, 64, 64): (1, 1, 32, 32, 1, 4),
+        (256, 256, 4096, 128, 128): (3, 1, 32, 64, 1, 4),
+        (256, 256, 8192, 16, 16): (1, 32, 16, 64, 1, 2),
+        (256, 256, 8192, 32, 32): (1, 1, 32, 64, 1, 4),
+        (256, 256, 8192, 64, 64): (1, 1, 32, 64, 1, 4),
+        (256, 256, 8192, 128, 128): (2, 1, 64, 32, 1, 4),
+        (256, 256, 16384, 16, 16): (1, 1, 16, 64, 1, 2),
+        (256, 256, 16384, 32, 32): (1, 1, 32, 64, 1, 4),
+        (256, 256, 16384, 64, 64): (1, 128, 64, 64, 1, 4),
+        (256, 256, 16384, 128, 128): (2, 1, 64, 32, 1, 4),
+        (256, 256, 32768, 16, 16): (2, 128, 16, 64, 1, 1),
+        (256, 256, 32768, 32, 32): (1, 1, 32, 64, 1, 4),
+        (256, 256, 32768, 64, 64): (1, 128, 64, 64, 1, 4),
+        (256, 256, 32768, 128, 128): (2, 1, 64, 64, 1, 4),
+        (256, 256, 65536, 16, 16): (1, 1, 16, 64, 1, 2),
+        (256, 256, 65536, 32, 32): (1, 1, 32, 64, 1, 4),
+        (256, 256, 65536, 64, 64): (2, 1, 64, 64, 1, 4),
+        (256, 256, 65536, 128, 128): (1, 1, 128, 32, 1, 4),
+        (256, 256, 131072, 16, 16): (3, 128, 16, 64, 1, 1),
+        (256, 256, 131072, 32, 32): (1, 1, 32, 64, 1, 4),
+        (256, 256, 131072, 64, 64): (2, 1, 64, 64, 1, 4),
+        (256, 256, 131072, 128, 128): (1, 8192, 64, 16, 1, 4),
+        (512, 512, 256, 16, 16): (1, 2, 16, 16, 1, 1),
+        (512, 512, 256, 32, 32): (1, 4, 16, 16, 1, 1),
+        (512, 512, 256, 64, 64): (1, 16, 16, 16, 1, 1),
+        (512, 512, 256, 128, 128): (1, 1, 16, 32, 1, 4),
+        (512, 512, 512, 16, 16): (1, 8, 16, 32, 1, 2),
+        (512, 512, 512, 32, 32): (1, 8, 16, 32, 1, 2),
+        (512, 512, 512, 64, 64): (1, 2, 16, 32, 1, 2),
+        (512, 512, 512, 128, 128): (1, 1, 32, 32, 1, 4),
+        (512, 512, 1024, 16, 16): (1, 1, 16, 32, 1, 2),
+        (512, 512, 1024, 32, 32): (1, 1, 16, 32, 1, 2),
+        (512, 512, 1024, 64, 64): (1, 1, 16, 32, 1, 2),
+        (512, 512, 1024, 128, 128): (1, 1, 64, 32, 1, 4),
+        (512, 512, 2048, 16, 16): (1, 16, 16, 64, 1, 2),
+        (512, 512, 2048, 32, 32): (1, 1, 32, 32, 1, 4),
+        (512, 512, 2048, 64, 64): (1, 1, 32, 32, 1, 4),
+        (512, 512, 2048, 128, 128): (2, 1, 32, 32, 1, 4),
+        (512, 512, 4096, 16, 16): (2, 64, 16, 64, 1, 1),
+        (512, 512, 4096, 32, 32): (1, 64, 32, 64, 1, 4),
+        (512, 512, 4096, 64, 64): (1, 1, 32, 32, 1, 4),
+        (512, 512, 4096, 128, 128): (1, 1, 64, 32, 1, 4),
+        (512, 512, 8192, 16, 16): (2, 64, 16, 64, 1, 1),
+        (512, 512, 8192, 32, 32): (1, 256, 32, 32, 1, 1),
+        (512, 512, 8192, 64, 64): (1, 64, 64, 64, 1, 4),
+        (512, 512, 8192, 128, 128): (2, 1, 64, 32, 1, 8),
+        (512, 512, 16384, 16, 16): (2, 64, 16, 64, 1, 1),
+        (512, 512, 16384, 32, 32): (1, 128, 32, 32, 1, 1),
+        (512, 512, 16384, 64, 64): (1, 64, 64, 64, 1, 4),
+        (512, 512, 16384, 128, 128): (3, 1, 64, 32, 1, 8),
+        (512, 512, 32768, 16, 16): (2, 64, 16, 64, 1, 1),
+        (512, 512, 32768, 32, 32): (1, 128, 32, 32, 1, 1),
+        (512, 512, 32768, 64, 64): (1, 64, 64, 64, 1, 4),
+        (512, 512, 32768, 128, 128): (2, 1, 64, 32, 1, 8),
+        (512, 512, 65536, 16, 16): (2, 32, 16, 64, 1, 1),
+        (512, 512, 65536, 32, 32): (1, 128, 32, 32, 1, 1),
+        (512, 512, 65536, 64, 64): (1, 64, 64, 64, 1, 4),
+        (512, 512, 65536, 128, 128): (2, 1, 64, 32, 1, 8),
+        (512, 512, 131072, 16, 16): (2, 32, 16, 64, 1, 1),
+        (512, 512, 131072, 32, 32): (1, 128, 32, 32, 1, 1),
+        (512, 512, 131072, 64, 64): (3, 64, 64, 64, 1, 4),
+        (512, 512, 131072, 128, 128): (1, 8192, 64, 16, 1, 4),
+        (1024, 1024, 256, 16, 16): (1, 4, 16, 32, 1, 2),
+        (1024, 1024, 256, 32, 32): (1, 4, 16, 32, 1, 2),
+        (1024, 1024, 256, 64, 64): (1, 1, 16, 32, 1, 2),
+        (1024, 1024, 256, 128, 128): (1, 1, 16, 16, 1, 1),
+        (1024, 1024, 512, 16, 16): (1, 8, 16, 32, 1, 2),
+        (1024, 1024, 512, 32, 32): (1, 8, 16, 32, 1, 1),
+        (1024, 1024, 512, 64, 64): (1, 8, 32, 32, 1, 4),
+        (1024, 1024, 512, 128, 128): (2, 1, 32, 32, 1, 4),
+        (1024, 1024, 1024, 16, 16): (1, 16, 16, 32, 1, 2),
+        (1024, 1024, 1024, 32, 32): (1, 16, 32, 64, 1, 4),
+        (1024, 1024, 1024, 64, 64): (1, 16, 32, 64, 1, 4),
+        (1024, 1024, 1024, 128, 128): (1, 1, 32, 32, 1, 4),
+        (1024, 1024, 2048, 16, 16): (2, 32, 16, 64, 1, 1),
+        (1024, 1024, 2048, 32, 32): (1, 32, 32, 64, 1, 4),
+        (1024, 1024, 2048, 64, 64): (1, 32, 64, 64, 1, 4),
+        (1024, 1024, 2048, 128, 128): (1, 1, 32, 64, 1, 4),
+        (1024, 1024, 4096, 16, 16): (2, 16, 16, 64, 1, 1),
+        (1024, 1024, 4096, 32, 32): (1, 64, 32, 32, 1, 1),
+        (1024, 1024, 4096, 64, 64): (1, 64, 64, 64, 1, 4),
+        (1024, 1024, 4096, 128, 128): (2, 64, 64, 32, 1, 8),
+        (1024, 1024, 8192, 16, 16): (2, 16, 16, 64, 1, 1),
+        (1024, 1024, 8192, 32, 32): (1, 64, 32, 32, 1, 1),
+        (1024, 1024, 8192, 64, 64): (1, 64, 64, 64, 1, 4),
+        (1024, 1024, 8192, 128, 128): (4, 1, 32, 64, 1, 4),
+        (1024, 1024, 16384, 16, 16): (2, 16, 16, 64, 1, 1),
+        (1024, 1024, 16384, 32, 32): (1, 64, 32, 32, 1, 1),
+        (1024, 1024, 16384, 64, 64): (1, 32, 64, 64, 1, 4),
+        (1024, 1024, 16384, 128, 128): (2, 64, 64, 32, 1, 4),
+        (1024, 1024, 32768, 16, 16): (2, 16, 16, 64, 1, 1),
+        (1024, 1024, 32768, 32, 32): (1, 64, 32, 32, 1, 1),
+        (1024, 1024, 32768, 64, 64): (1, 32, 64, 64, 1, 4),
+        (1024, 1024, 32768, 128, 128): (4, 1, 32, 64, 1, 4),
+        (1024, 1024, 65536, 16, 16): (2, 16, 16, 64, 1, 1),
+        (1024, 1024, 65536, 32, 32): (1, 32, 32, 32, 1, 1),
+        (1024, 1024, 65536, 64, 64): (2, 32, 64, 64, 1, 4),
+        (1024, 1024, 65536, 128, 128): (4, 1, 64, 32, 1, 4),
+        (1024, 1024, 131072, 16, 16): (2, 16, 16, 64, 1, 1),
+        (1024, 1024, 131072, 32, 32): (1, 32, 32, 32, 1, 1),
+        (1024, 1024, 131072, 64, 64): (1, 16, 64, 64, 1, 4),
+        (1024, 1024, 131072, 128, 128): (1, 8192, 64, 16, 1, 4),
+        (2048, 2048, 256, 16, 16): (1, 4, 16, 32, 1, 2),
+        (2048, 2048, 256, 32, 32): (1, 8, 16, 32, 1, 1),
+        (2048, 2048, 256, 64, 64): (1, 8, 32, 32, 1, 4),
+        (2048, 2048, 256, 128, 128): (1, 4, 64, 64, 1, 8),
+        (2048, 2048, 512, 16, 16): (2, 8, 16, 32, 1, 2),
+        (2048, 2048, 512, 32, 32): (2, 8, 32, 64, 1, 4),
+        (2048, 2048, 512, 64, 64): (2, 4, 64, 64, 1, 4),
+        (2048, 2048, 512, 128, 128): (1, 8, 32, 64, 1, 4),
+        (2048, 2048, 1024, 16, 16): (2, 16, 16, 64, 3, 1),
+        (2048, 2048, 1024, 32, 32): (1, 32, 32, 32, 1, 1),
+        (2048, 2048, 1024, 64, 64): (1, 16, 64, 64, 1, 4),
+        (2048, 2048, 1024, 128, 128): (2, 4, 64, 64, 1, 8),
+        (2048, 2048, 2048, 16, 16): (2, 16, 16, 64, 1, 1),
+        (2048, 2048, 2048, 32, 32): (1, 32, 32, 32, 1, 1),
+        (2048, 2048, 2048, 64, 64): (1, 16, 64, 64, 1, 4),
+        (2048, 2048, 2048, 128, 128): (2, 32, 32, 64, 1, 4),
+        (2048, 2048, 4096, 16, 16): (3, 2, 16, 64, 1, 1),
+        (2048, 2048, 4096, 32, 32): (3, 4, 32, 32, 1, 1),
+        (2048, 2048, 4096, 64, 64): (1, 16, 64, 64, 1, 4),
+        (2048, 2048, 4096, 128, 128): (2, 32, 64, 32, 1, 4),
+        (2048, 2048, 8192, 16, 16): (3, 4, 16, 64, 1, 1),
+        (2048, 2048, 8192, 32, 32): (2, 4, 32, 32, 1, 1),
+        (2048, 2048, 8192, 64, 64): (2, 32, 64, 32, 1, 2),
+        (2048, 2048, 8192, 128, 128): (4, 1, 32, 64, 1, 4),
+        (2048, 2048, 16384, 16, 16): (3, 4, 16, 64, 1, 1),
+        (2048, 2048, 16384, 32, 32): (1, 4, 32, 32, 1, 1),
+        (2048, 2048, 16384, 64, 64): (2, 8, 64, 32, 1, 2),
+        (2048, 2048, 16384, 128, 128): (2, 8, 64, 32, 1, 4),
+        (2048, 2048, 32768, 16, 16): (2, 4, 16, 64, 1, 1),
+        (2048, 2048, 32768, 32, 32): (2, 8, 32, 32, 1, 1),
+        (2048, 2048, 32768, 64, 64): (1, 16, 64, 32, 1, 2),
+        (2048, 2048, 32768, 128, 128): (4, 1, 32, 64, 1, 4),
+        (2048, 2048, 65536, 16, 16): (3, 4, 16, 64, 1, 1),
+        (2048, 2048, 65536, 32, 32): (1, 8, 32, 32, 1, 1),
+        (2048, 2048, 65536, 64, 64): (1, 8, 64, 32, 1, 2),
+        (2048, 2048, 65536, 128, 128): (4, 1, 64, 32, 1, 4),
+        (2048, 2048, 131072, 16, 16): (2, 4, 16, 64, 1, 1),
+        (2048, 2048, 131072, 32, 32): (1, 8, 32, 32, 1, 1),
+        (2048, 2048, 131072, 64, 64): (3, 1, 64, 32, 1, 2),
+        (2048, 2048, 131072, 128, 128): (1, 8192, 128, 16, 1, 8),
+        (4096, 4096, 256, 16, 16): (2, 4, 16, 32, 1, 2),
+        (4096, 4096, 256, 32, 32): (1, 4, 32, 64, 1, 4),
+        (4096, 4096, 256, 64, 64): (1, 4, 64, 64, 1, 4),
+        (4096, 4096, 256, 128, 128): (1, 4, 32, 64, 1, 4),
+        (4096, 4096, 512, 16, 16): (2, 8, 16, 64, 3, 1),
+        (4096, 4096, 512, 32, 32): (2, 16, 32, 32, 1, 1),
+        (4096, 4096, 512, 64, 64): (1, 8, 64, 64, 1, 4),
+        (4096, 4096, 512, 128, 128): (1, 8, 32, 64, 1, 4),
+        (4096, 4096, 1024, 16, 16): (1, 8, 16, 64, 3, 1),
+        (4096, 4096, 1024, 32, 32): (1, 16, 32, 32, 1, 1),
+        (4096, 4096, 1024, 64, 64): (1, 16, 64, 32, 1, 2),
+        (4096, 4096, 1024, 128, 128): (1, 16, 32, 64, 1, 4),
+        (4096, 4096, 2048, 16, 16): (1, 16, 16, 64, 3, 1),
+        (4096, 4096, 2048, 32, 32): (1, 16, 32, 32, 1, 1),
+        (4096, 4096, 2048, 64, 64): (3, 16, 64, 32, 1, 2),
+        (4096, 4096, 2048, 128, 128): (4, 8, 32, 64, 1, 4),
+        (4096, 4096, 4096, 16, 16): (1, 8, 16, 64, 3, 1),
+        (4096, 4096, 4096, 32, 32): (1, 1, 32, 32, 1, 1),
+        (4096, 4096, 4096, 64, 64): (2, 16, 64, 32, 1, 2),
+        (4096, 4096, 4096, 128, 128): (4, 8, 32, 64, 1, 4),
+        (4096, 4096, 8192, 16, 16): (1, 8, 16, 64, 3, 1),
+        (4096, 4096, 8192, 32, 32): (2, 1, 32, 32, 1, 1),
+        (4096, 4096, 8192, 64, 64): (1, 16, 64, 32, 1, 2),
+        (4096, 4096, 8192, 128, 128): (2, 1, 32, 64, 1, 4),
+        (4096, 4096, 16384, 16, 16): (1, 8, 16, 64, 3, 1),
+        (4096, 4096, 16384, 32, 32): (1, 1, 32, 32, 1, 1),
+        (4096, 4096, 16384, 64, 64): (2, 8, 64, 32, 1, 2),
+        (4096, 4096, 16384, 128, 128): (2, 1, 32, 64, 1, 4),
+        (4096, 4096, 32768, 16, 16): (1, 8, 16, 64, 3, 1),
+        (4096, 4096, 32768, 32, 32): (1, 1, 32, 32, 1, 1),
+        (4096, 4096, 32768, 64, 64): (1, 8, 64, 32, 1, 2),
+        (4096, 4096, 32768, 128, 128): (2, 1, 32, 64, 1, 4),
+        (4096, 4096, 65536, 16, 16): (1, 8, 16, 64, 3, 1),
+        (4096, 4096, 65536, 32, 32): (3, 1, 32, 32, 1, 1),
+        (4096, 4096, 65536, 64, 64): (3, 4, 64, 32, 1, 2),
+        (4096, 4096, 65536, 128, 128): (2, 1, 32, 64, 1, 4),
+        (4096, 4096, 131072, 16, 16): (1, 8, 16, 64, 3, 1),
+        (4096, 4096, 131072, 32, 32): (1, 1, 32, 32, 1, 1),
+        (4096, 4096, 131072, 64, 64): (2, 8, 64, 32, 1, 2),
+        (4096, 4096, 131072, 128, 128): (1, 8192, 128, 16, 1, 8),
+        (8192, 8192, 256, 16, 16): (2, 4, 16, 64, 3, 1),
+        (8192, 8192, 256, 32, 32): (1, 8, 32, 32, 1, 1),
+        (8192, 8192, 256, 64, 64): (1, 4, 64, 64, 1, 4),
+        (8192, 8192, 256, 128, 128): (1, 4, 32, 64, 1, 4),
+        (8192, 8192, 512, 16, 16): (1, 4, 16, 64, 3, 1),
+        (8192, 8192, 512, 32, 32): (1, 16, 32, 32, 1, 1),
+        (8192, 8192, 512, 64, 64): (2, 4, 64, 64, 1, 4),
+        (8192, 8192, 512, 128, 128): (2, 1, 32, 64, 1, 4),
+        (8192, 8192, 1024, 16, 16): (3, 8, 16, 64, 3, 1),
+        (8192, 8192, 1024, 32, 32): (1, 16, 32, 32, 1, 1),
+        (8192, 8192, 1024, 64, 64): (1, 8, 64, 32, 1, 2),
+        (8192, 8192, 1024, 128, 128): (2, 4, 32, 64, 1, 4),
+        (8192, 8192, 2048, 16, 16): (1, 8, 16, 64, 3, 1),
+        (8192, 8192, 2048, 32, 32): (1, 16, 32, 32, 1, 1),
+        (8192, 8192, 2048, 64, 64): (2, 8, 64, 32, 1, 2),
+        (8192, 8192, 2048, 128, 128): (4, 1, 32, 64, 1, 4),
+        (8192, 8192, 4096, 16, 16): (1, 8, 16, 64, 3, 1),
+        (8192, 8192, 4096, 32, 32): (1, 16, 32, 32, 1, 1),
+        (8192, 8192, 4096, 64, 64): (1, 4, 64, 32, 1, 2),
+        (8192, 8192, 4096, 128, 128): (3, 1, 32, 64, 1, 4),
+        (8192, 8192, 8192, 16, 16): (1, 8, 16, 64, 3, 1),
+        (8192, 8192, 8192, 32, 32): (1, 8, 32, 32, 1, 1),
+        (8192, 8192, 8192, 64, 64): (1, 8, 64, 32, 1, 2),
+        (8192, 8192, 8192, 128, 128): (4, 1, 32, 64, 1, 4),
+        (8192, 8192, 16384, 16, 16): (3, 4, 16, 64, 3, 1),
+        (8192, 8192, 16384, 32, 32): (1, 8, 32, 32, 1, 1),
+        (8192, 8192, 16384, 64, 64): (2, 2, 64, 32, 1, 2),
+        (8192, 8192, 16384, 128, 128): (7, 1, 32, 64, 1, 4),
+        (8192, 8192, 32768, 16, 16): (1, 4, 16, 64, 3, 1),
+        (8192, 8192, 32768, 32, 32): (1, 8, 32, 32, 1, 1),
+        (8192, 8192, 32768, 64, 64): (3, 2, 64, 32, 1, 2),
+        (8192, 8192, 32768, 128, 128): (6, 1, 32, 64, 1, 4),
+        (8192, 8192, 65536, 16, 16): (1, 4, 16, 64, 3, 1),
+        (8192, 8192, 65536, 32, 32): (4, 8, 32, 32, 1, 1),
+        (8192, 8192, 65536, 64, 64): (1, 2, 64, 32, 1, 2),
+        (8192, 8192, 65536, 128, 128): (4, 1, 32, 64, 1, 4),
+        (8192, 8192, 131072, 16, 16): (1, 4, 16, 64, 3, 1),
+        (8192, 8192, 131072, 32, 32): (1, 8, 32, 32, 1, 1),
+        (8192, 8192, 131072, 64, 64): (5, 4, 64, 32, 1, 2),
+        (8192, 8192, 131072, 128, 128): (1, 4096, 128, 16, 1, 8),
+        (16384, 16384, 256, 16, 16): (1, 4, 16, 64, 3, 1),
+        (16384, 16384, 256, 32, 32): (1, 8, 32, 32, 1, 1),
+        (16384, 16384, 256, 64, 64): (1, 4, 64, 32, 1, 2),
+        (16384, 16384, 256, 128, 128): (1, 4, 32, 64, 1, 4),
+        (16384, 16384, 512, 16, 16): (1, 8, 16, 64, 3, 1),
+        (16384, 16384, 512, 32, 32): (1, 16, 32, 32, 1, 1),
+        (16384, 16384, 512, 64, 64): (1, 4, 64, 32, 1, 2),
+        (16384, 16384, 512, 128, 128): (3, 1, 32, 64, 1, 4),
+        (16384, 16384, 1024, 16, 16): (1, 8, 16, 64, 3, 1),
+        (16384, 16384, 1024, 32, 32): (1, 16, 32, 32, 1, 1),
+        (16384, 16384, 1024, 64, 64): (2, 4, 64, 32, 1, 2),
+        (16384, 16384, 1024, 128, 128): (1, 2, 32, 64, 1, 4),
+        (16384, 16384, 2048, 16, 16): (1, 4, 16, 64, 3, 1),
+        (16384, 16384, 2048, 32, 32): (1, 16, 32, 32, 1, 1),
+        (16384, 16384, 2048, 64, 64): (3, 4, 64, 32, 1, 2),
+        (16384, 16384, 2048, 128, 128): (2, 1, 32, 64, 1, 4),
+        (16384, 16384, 4096, 16, 16): (4, 8, 16, 64, 3, 1),
+        (16384, 16384, 4096, 32, 32): (5, 16, 32, 32, 1, 1),
+        (16384, 16384, 4096, 64, 64): (3, 2, 64, 32, 1, 2),
+        (16384, 16384, 4096, 128, 128): (2, 1, 32, 64, 1, 4),
+        (16384, 16384, 8192, 16, 16): (1, 4, 16, 64, 3, 1),
+        (16384, 16384, 8192, 32, 32): (1, 4, 32, 32, 1, 1),
+        (16384, 16384, 8192, 64, 64): (1, 2, 64, 32, 1, 2),
+        (16384, 16384, 8192, 128, 128): (2, 1, 32, 64, 1, 4),
+        (16384, 16384, 16384, 16, 16): (1, 8, 16, 64, 3, 1),
+        (16384, 16384, 16384, 32, 32): (1, 4, 32, 32, 1, 1),
+        (16384, 16384, 16384, 64, 64): (1, 2, 64, 32, 1, 2),
+        (16384, 16384, 16384, 128, 128): (3, 1, 32, 64, 1, 4),
+        (16384, 16384, 32768, 16, 16): (1, 4, 16, 64, 3, 1),
+        (16384, 16384, 32768, 32, 32): (1, 2, 32, 32, 1, 1),
+        (16384, 16384, 32768, 64, 64): (3, 2, 64, 32, 1, 2),
+        (16384, 16384, 32768, 128, 128): (3, 1, 32, 64, 1, 4),
+        (16384, 16384, 65536, 16, 16): (1, 8, 16, 64, 3, 1),
+        (16384, 16384, 65536, 32, 32): (1, 4, 32, 32, 1, 1),
+        (16384, 16384, 65536, 64, 64): (4, 4, 64, 32, 1, 2),
+        (16384, 16384, 65536, 128, 128): (5, 1, 32, 64, 1, 4),
+        (16384, 16384, 131072, 16, 16): (1, 2, 16, 64, 3, 1),
+        (16384, 16384, 131072, 32, 32): (1, 4, 32, 32, 1, 1),
+        (16384, 16384, 131072, 64, 64): (1, 2, 64, 32, 1, 2),
+        (16384, 16384, 131072, 128, 128): (1, 4096, 128, 16, 1, 8),
+    },
     # END GENERATED DATA
 }
 
 if __name__ == "__main__":
-    for op in ["scatter_mm", "bsr_dense_mm"]:
-        main(op=op, force=False)
+    for dtype in [torch.float16, torch.bfloat16, torch.float32]:
+        for op in ["scatter_mm", "bsr_dense_mm"]:
+            main(op=op, force=False, dtype=dtype)