Add optimal triton kernel parameters to bsr_dense_mm and scatter_mm for bfloat16 and float32 dtypes (#113553)
As in the title.
This PR is a follow-up to PR https://github.com/pytorch/pytorch/pull/112737 to address bfloat16 and float32 dtype cases. The performance increase is as follows (`NVIDIA A100-SXM4-80GB`):
- bsr_scatter_mm and bfloat16
  - for blocksize 16x16, the average/maximum speed up is about 29/75 %.
  - for blocksize 32x32, the average/maximum speed up is about 23/58 %.
  - for blocksize 64x64, the average/maximum speed up is about 27/66 %.
  - for blocksize 128x128, the average/maximum speed up is about 33/72 %.
- bsr_dense_mm and bfloat16
  - for blocksize 16x16, the average/maximum speed up is about 47/61 %.
  - for blocksize 32x32, the average/maximum speed up is about 29/43 %.
  - for blocksize 64x64, the average/maximum speed up is about 21/41 %.
  - for blocksize 128x128, the average/maximum speed up is about 12/29 %.
- bsr_dense_mm and float32
  - for blocksize 16x16, the average/maximum speed up is about 35/49 %.
  - for blocksize 32x32, the average/maximum speed up is about 2/5 %.
  - for blocksize 64x64, the average/maximum speed up is about 2/21 %.
  - for blocksize 128x128, the average/maximum speed up is about 79/84 %.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113553
Approved by: https://github.com/cpuhrsch
diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py
index 93b1a27..6ff12bb 100644
--- a/torch/sparse/_triton_ops.py
+++ b/torch/sparse/_triton_ops.py
@@ -1202,7 +1202,8 @@
out: Optional[torch.Tensor] = None,
skip_checks: bool = False,
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
- meta: Optional[dict] = None
+ meta: Optional[dict] = None,
+ enable_bsr_scatter_mm: bool = True
):
f_name = "bsr_dense_mm"
m, kl = bsr.shape[-2:]
@@ -1249,13 +1250,19 @@
blocksize = bsr.values().shape[-2:]
- if max(blocksize) == 16 and bsr.dense_dim() == 0 and bsr.ndim == 2:
+ if enable_bsr_scatter_mm and max(blocksize) == 16 and bsr.dense_dim() == 0 and bsr.ndim == 2:
+ dtype = bsr.dtype
# bsr_scatter_mm is more performant than bsr_dense_mm for
# 16x16 blocksizes and large enough input shapes:
if (
- (m >= 4096 and n >= 8192)
- or (m == 2048 and n >= 32768)
- or (n >= 131072)
+ (dtype in {torch.float16, torch.bfloat16}
+ and ((m >= 4096 and n >= 8192)
+ or (m == 2048 and n >= 32768)
+ or (n >= 131072))) or
+ (dtype == torch.float32
+ and (m >= 1024
+ or (m == 512 and n >= 512)
+ or (m == 256 and n >= 2048)))
):
return bsr_scatter_mm(bsr, dense, out=out)
diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py
index 80d642c..a014d1c 100644
--- a/torch/sparse/_triton_ops_meta.py
+++ b/torch/sparse/_triton_ops_meta.py
@@ -89,10 +89,18 @@
" BEGIN/END GENERATED DATA comment blocks appear to be corrupted"
)
return
+
+ def sort_key(key):
+ op, device_name, version = key
+ version = tuple(
+ (str(item) if isinstance(item, torch.dtype) else item) for item in version
+ )
+ return (op, device_name, version)
+
part1 = current_content[: begin_data_index + len(begin_data_str)]
part2 = current_content[end_data_index:]
data_part = []
- for op_key in sorted(_operation_device_version_data):
+ for op_key in sorted(_operation_device_version_data, key=sort_key):
data_part.append(" " + repr(op_key).replace("'", '"') + ": {")
op_data = _operation_device_version_data[op_key]
for key in sorted(op_data):
@@ -105,7 +113,9 @@
f.close()
-def minimize(target_func, initial_parameters, step_func):
+def minimize(
+ target_func, initial_parameters, reference_parameters, step_func, max_step=2
+):
"""Find a dict of parameters that minimizes the target function using
the initial dict of parameters and a step function that progresses
a specified parameter in a dict of parameters.
@@ -116,6 +126,8 @@
``target_func(parameters: dict) -> float``
initial_parameters (dict): a set of parameters used as an initial
value to the minimization process.
+ reference_parameters (dict): a set of parameters used as a
+ reference value with respect to which the speed up is computed.
step_func (callable): a functional with the signature
``step_func(parameter_name:str, parameter_value:int, direction:int, parameters:dict) -> int``
that increments or decrements (when ``direction`` is positive or
@@ -128,6 +140,7 @@
parameters (dict): a set of parameters that minimizes the target
function.
speedup_incr (float): a speedup change given in percentage
+
"""
def to_key(parameters):
@@ -137,13 +150,34 @@
return dict(zip(sorted(parameters), key))
all_values = dict()
+
+ try:
+ reference_target = target_func(reference_parameters)
+ except Exception as msg:
+ print(f"{reference_parameters=} lead to failure: {msg}.")
+ reference_target = None
+ if reference_target is not None:
+ all_values[to_key(reference_parameters)] = reference_target
+
parameters = initial_parameters
try:
initial_target = target_func(parameters)
except Exception as msg:
- print(f"{parameters=} lead to failure: {msg}. Skipping.")
- return parameters, -1
- all_values[to_key(parameters)] = initial_target
+ if reference_target is None:
+ print(f"{initial_parameters=} lead to failure: {msg}. Optimization failed!")
+ return {}, -1, None
+ print(
+ f"{initial_parameters=} lead to failure: {msg}. Using reference parameters instead of initial parameters."
+ )
+ parameters = reference_parameters
+ initial_target = reference_target
+
+ if reference_target is None:
+ print("Using initial parameters instead of reference parameters.")
+ reference_target = initial_target
+
+ initial_key = to_key(parameters)
+ all_values[initial_key] = initial_target
while True:
current_key = to_key(parameters)
@@ -152,7 +186,9 @@
new_minimizer = False
for name in parameters:
value = parameters[name]
- for direction in [1, -1]:
+ for direction in range(-max_step, max_step + 1):
+ if direction == 0:
+ continue
next_value = step_func(name, value, direction, parameters)
if next_value == value:
continue
@@ -175,8 +211,30 @@
if new_minimizer:
parameters = from_key(minimizer_key, parameters)
else:
- speedup_incr = (1 - minimizer_target / initial_target) * 100
- return parameters, speedup_incr
+ # ensure stable minimizer:
+ minimizer_keys = {
+ k
+ for k, v in all_values.items()
+ if isinstance(v, float) and abs(1 - v / minimizer_target) < 0.001
+ }
+ minimizer_key = (
+ initial_key if initial_key in minimizer_keys else min(minimizer_keys)
+ )
+ minimizer_target = all_values[minimizer_key]
+ parameters = from_key(minimizer_key, parameters)
+ speedup_incr = (1 - minimizer_target / reference_target) * 100
+ if speedup_incr < 0:
+ print(
+ f"{speedup_incr=} is negative. Rerunning minimize with reference parameters as initial parameters."
+ )
+ return minimize(
+ target_func,
+ reference_parameters,
+ reference_parameters,
+ step_func,
+ max_step=max_step,
+ )
+ return parameters, speedup_incr, minimizer_target
def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
@@ -212,20 +270,37 @@
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
key = (m, k, n, bm, bk)
- version = (0, dtype, sparsity)
- initial_meta = get_meta("scatter_mm", key, version=version, exact=True)
+ version = (0, dtype, sparsity)
+ device_name = torch.cuda.get_device_name()
+
+ reference_meta = dict(
+ GROUP_SIZE=1,
+ TILE_M=16,
+ TILE_N=16,
+ SPLIT_N=n // 16,
+ num_stages=1,
+ num_warps=1,
+ )
+
+ initial_meta = get_meta(
+ "scatter_mm", key, device_name=device_name, version=version, exact=True
+ )
if initial_meta is None:
- initial_meta = get_meta("scatter_mm", key, version=(0, dtype, 0.5), exact=True)
+ initial_meta = get_meta(
+ "bsr_dense_mm",
+ key,
+ device_name=device_name,
+ version=(0, dtype, 0.5),
+ exact=True,
+ )
if initial_meta is None:
- initial_meta = dict(
- GROUP_SIZE=1, TILE_M=16, TILE_N=16, SPLIT_N=1, num_warps=1, num_stages=1
- )
+ initial_meta = reference_meta
elif not force:
return
- print(f"{m, k, n, bm, bk=}")
+ print(f"{m, k, n, bm, bk, initial_meta, reference_meta=}")
torch.manual_seed(0)
bsr = create_blocked_tensor(
0, m, k, (bm, bk), sparsity, dtype, device
@@ -250,6 +325,7 @@
# return next value in positive or negative direction, or
# input value if the step will result an invalid
# value. The input value is assumed to be valid.
+
is_log = name in {"SPLIT_N", "TILE_M", "TILE_N", "num_warps"}
min_value = dict(
SPLIT_N=1, TILE_M=16, TILE_N=16, num_warps=1, num_stages=1, GROUP_SIZE=1
@@ -261,22 +337,39 @@
SPLIT_N=2, TILE_M=2, TILE_N=2, num_warps=2, num_stages=1, GROUP_SIZE=1
)[name]
if is_log:
- next_value = value * value_step if direction > 0 else value // value_step
+ next_value = (
+ value * value_step**direction
+ if direction > 0
+ else value // (value_step ** abs(direction))
+ )
else:
- next_value = value + value_step if direction > 0 else value - value_step
+ next_value = value + value_step * direction
if min_value is not None:
next_value = max(next_value, min_value)
if max_value is not None:
next_value = min(next_value, max_value)
if name == "SPLIT_N" and n % next_value != 0:
return value
+ # Hard-skip parameter combinations that break CUDA state for pytorch:
+ if (dtype, name, next_value, m, n, k, bm, bk) in {
+ (torch.float32, "num_warps", 32, 256, 256, 256, 16, 16),
+ (torch.float32, "num_warps", 16, 256, 256, 256, 32, 32),
+ (torch.float32, "num_warps", 16, 256, 256, 256, 64, 64),
+ (torch.float32, "num_warps", 16, 256, 256, 256, 128, 128),
+ (torch.float32, "num_warps", 16, 512, 512, 256, 128, 128),
+ } and re.match(r"NVIDIA A100[^\d]", device_name) is not None:
+ return value
return next_value
- meta, speedup = minimize(bench, initial_meta, step_meta_parameter)
- if speedup < 3 and 0:
- # don't bother updating parameters when the speed up change is less than 3 %
+ meta, speedup, timing = minimize(
+ bench, initial_meta, reference_meta, step_meta_parameter
+ )
+ print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms")
+ if initial_meta is not reference_meta and initial_meta == meta:
return
- print(f"{meta=} {speedup=:.1f} %")
+ print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms")
+ if speedup < 0:
+ return
device_name = torch.cuda.get_device_name()
update(
@@ -294,6 +387,8 @@
key = (m, k, n, bm, bk)
version = (0, dtype, sparsity)
+ reference_meta = dict(GROUP_SIZE_ROW=1, num_stages=1, num_warps=4)
+
initial_meta = get_meta("bsr_dense_mm", key, version=version, exact=True)
if initial_meta is None:
@@ -301,11 +396,11 @@
"bsr_dense_mm", key, version=(0, dtype, 0.5), exact=True
)
if initial_meta is None:
- initial_meta = dict(GROUP_SIZE_ROW=1, num_stages=1, num_warps=1)
+ initial_meta = reference_meta
elif not force:
return
- print(f"{m, k, n, bm, bk=}")
+ print(f"{m, k, n, bm, bk, initial_meta=}")
torch.manual_seed(0)
bsr = create_blocked_tensor(
0, m, k, (bm, bk), sparsity, dtype, device
@@ -331,20 +426,27 @@
max_value = dict().get(name)
value_step = dict(num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name]
if is_log:
- next_value = value * value_step if direction > 0 else value // value_step
+ next_value = (
+ value * value_step**direction
+ if direction > 0
+ else value // (value_step ** abs(direction))
+ )
else:
- next_value = value + value_step if direction > 0 else value - value_step
+ next_value = value + value_step * direction
if min_value is not None:
next_value = max(next_value, min_value)
if max_value is not None:
next_value = min(next_value, max_value)
return next_value
- meta, speedup = minimize(bench, initial_meta, step_meta_parameter)
- if speedup < 3 and 0:
- # don't bother updating parameters when the speed up change is less than 3 %
+ meta, speedup, timing = minimize(
+ bench, initial_meta, reference_meta, step_meta_parameter, max_step=2
+ )
+ if initial_meta is not reference_meta and initial_meta == meta:
return
- print(f"{meta=} {speedup=:.1f} %")
+ print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms")
+ if speedup < 0:
+ return
device_name = torch.cuda.get_device_name()
update(
@@ -352,10 +454,9 @@
)
-def main(op="scatter_mm", force=False):
+def main(op="scatter_mm", force=False, dtype=torch.float16):
import itertools
- dtype = torch.float16
sizes_lst = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
shapes_lst = [(sz, sz) for sz in sizes_lst[:-3]]
blocksize_lst = [(16, 16), (32, 32), (64, 64), (128, 128)]
@@ -367,10 +468,12 @@
shapes_lst, sizes_lst, blocksize_lst
):
if op == "scatter_mm":
- optimize_scatter_mm(M, K, N, BM, BK, force=force, sparsity=sparsity)
+ optimize_scatter_mm(
+ M, K, N, BM, BK, force=force, sparsity=sparsity, dtype=dtype
+ )
elif op == "bsr_dense_mm":
optimize_bsr_dense_mm(
- M, K, N, BM, BK, force=force, sparsity=sparsity
+ M, K, N, BM, BK, force=force, sparsity=sparsity, dtype=dtype
)
else:
raise NotImplementedError(op)
@@ -378,7 +481,7 @@
break
except Exception as msg:
dump()
- print(msg)
+ raise
dump()
if 0:
@@ -486,6 +589,288 @@
# bsr_dense_mm : M, K, N, Ms, Ks -> GROUP_SIZE_ROW, num_stages, num_warps
#
# BEGIN GENERATED DATA
+ ("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): {
+ (256, 256, 256, 16, 16): (3, 3, 1),
+ (256, 256, 256, 32, 32): (1, 5, 1),
+ (256, 256, 256, 64, 64): (2, 3, 2),
+ (256, 256, 256, 128, 128): (2, 2, 8),
+ (256, 256, 512, 16, 16): (2, 5, 1),
+ (256, 256, 512, 32, 32): (1, 4, 1),
+ (256, 256, 512, 64, 64): (1, 3, 2),
+ (256, 256, 512, 128, 128): (2, 2, 8),
+ (256, 256, 1024, 16, 16): (2, 3, 1),
+ (256, 256, 1024, 32, 32): (1, 4, 1),
+ (256, 256, 1024, 64, 64): (3, 3, 2),
+ (256, 256, 1024, 128, 128): (1, 2, 8),
+ (256, 256, 2048, 16, 16): (2, 3, 1),
+ (256, 256, 2048, 32, 32): (1, 3, 1),
+ (256, 256, 2048, 64, 64): (2, 3, 4),
+ (256, 256, 2048, 128, 128): (2, 2, 8),
+ (256, 256, 4096, 16, 16): (3, 3, 1),
+ (256, 256, 4096, 32, 32): (4, 3, 1),
+ (256, 256, 4096, 64, 64): (2, 3, 2),
+ (256, 256, 4096, 128, 128): (2, 2, 8),
+ (256, 256, 8192, 16, 16): (1, 3, 1),
+ (256, 256, 8192, 32, 32): (1, 1, 1),
+ (256, 256, 8192, 64, 64): (1, 1, 4),
+ (256, 256, 8192, 128, 128): (1, 1, 4),
+ (256, 256, 16384, 16, 16): (1, 3, 1),
+ (256, 256, 16384, 32, 32): (3, 3, 1),
+ (256, 256, 16384, 64, 64): (1, 1, 4),
+ (256, 256, 16384, 128, 128): (2, 1, 4),
+ (256, 256, 32768, 16, 16): (1, 3, 1),
+ (256, 256, 32768, 32, 32): (1, 3, 1),
+ (256, 256, 32768, 64, 64): (1, 2, 4),
+ (256, 256, 32768, 128, 128): (2, 1, 4),
+ (256, 256, 65536, 16, 16): (1, 3, 1),
+ (256, 256, 65536, 32, 32): (1, 3, 1),
+ (256, 256, 65536, 64, 64): (1, 2, 4),
+ (256, 256, 65536, 128, 128): (1, 1, 4),
+ (256, 256, 131072, 16, 16): (1, 1, 2),
+ (256, 256, 131072, 32, 32): (1, 1, 1),
+ (256, 256, 131072, 64, 64): (1, 2, 4),
+ (256, 256, 131072, 128, 128): (1, 1, 4),
+ (512, 512, 256, 16, 16): (2, 5, 1),
+ (512, 512, 256, 32, 32): (1, 5, 1),
+ (512, 512, 256, 64, 64): (3, 3, 4),
+ (512, 512, 256, 128, 128): (1, 2, 8),
+ (512, 512, 512, 16, 16): (2, 4, 1),
+ (512, 512, 512, 32, 32): (1, 5, 2),
+ (512, 512, 512, 64, 64): (1, 5, 4),
+ (512, 512, 512, 128, 128): (2, 2, 8),
+ (512, 512, 1024, 16, 16): (1, 3, 1),
+ (512, 512, 1024, 32, 32): (1, 4, 1),
+ (512, 512, 1024, 64, 64): (2, 4, 4),
+ (512, 512, 1024, 128, 128): (1, 2, 8),
+ (512, 512, 2048, 16, 16): (3, 3, 1),
+ (512, 512, 2048, 32, 32): (3, 3, 1),
+ (512, 512, 2048, 64, 64): (3, 3, 4),
+ (512, 512, 2048, 128, 128): (2, 2, 8),
+ (512, 512, 4096, 16, 16): (2, 3, 1),
+ (512, 512, 4096, 32, 32): (1, 3, 1),
+ (512, 512, 4096, 64, 64): (1, 3, 4),
+ (512, 512, 4096, 128, 128): (2, 1, 4),
+ (512, 512, 8192, 16, 16): (2, 3, 1),
+ (512, 512, 8192, 32, 32): (1, 3, 1),
+ (512, 512, 8192, 64, 64): (1, 3, 2),
+ (512, 512, 8192, 128, 128): (6, 2, 8),
+ (512, 512, 16384, 16, 16): (1, 3, 1),
+ (512, 512, 16384, 32, 32): (1, 3, 1),
+ (512, 512, 16384, 64, 64): (3, 3, 2),
+ (512, 512, 16384, 128, 128): (2, 1, 4),
+ (512, 512, 32768, 16, 16): (1, 2, 1),
+ (512, 512, 32768, 32, 32): (1, 3, 1),
+ (512, 512, 32768, 64, 64): (1, 3, 2),
+ (512, 512, 32768, 128, 128): (2, 1, 4),
+ (512, 512, 65536, 16, 16): (4, 3, 1),
+ (512, 512, 65536, 32, 32): (1, 3, 1),
+ (512, 512, 65536, 64, 64): (1, 3, 2),
+ (512, 512, 65536, 128, 128): (1, 1, 4),
+ (512, 512, 131072, 16, 16): (1, 2, 2),
+ (512, 512, 131072, 32, 32): (1, 1, 1),
+ (512, 512, 131072, 64, 64): (2, 3, 2),
+ (512, 512, 131072, 128, 128): (1, 1, 4),
+ (1024, 1024, 256, 16, 16): (1, 4, 1),
+ (1024, 1024, 256, 32, 32): (1, 5, 2),
+ (1024, 1024, 256, 64, 64): (3, 3, 2),
+ (1024, 1024, 256, 128, 128): (1, 2, 8),
+ (1024, 1024, 512, 16, 16): (2, 3, 1),
+ (1024, 1024, 512, 32, 32): (2, 4, 1),
+ (1024, 1024, 512, 64, 64): (2, 3, 4),
+ (1024, 1024, 512, 128, 128): (1, 2, 8),
+ (1024, 1024, 1024, 16, 16): (1, 3, 1),
+ (1024, 1024, 1024, 32, 32): (1, 3, 1),
+ (1024, 1024, 1024, 64, 64): (3, 3, 4),
+ (1024, 1024, 1024, 128, 128): (2, 2, 8),
+ (1024, 1024, 2048, 16, 16): (2, 3, 1),
+ (1024, 1024, 2048, 32, 32): (1, 4, 1),
+ (1024, 1024, 2048, 64, 64): (1, 3, 2),
+ (1024, 1024, 2048, 128, 128): (4, 2, 8),
+ (1024, 1024, 4096, 16, 16): (2, 3, 1),
+ (1024, 1024, 4096, 32, 32): (1, 4, 1),
+ (1024, 1024, 4096, 64, 64): (1, 3, 4),
+ (1024, 1024, 4096, 128, 128): (4, 2, 8),
+ (1024, 1024, 8192, 16, 16): (2, 3, 1),
+ (1024, 1024, 8192, 32, 32): (2, 3, 1),
+ (1024, 1024, 8192, 64, 64): (4, 3, 2),
+ (1024, 1024, 8192, 128, 128): (4, 1, 4),
+ (1024, 1024, 16384, 16, 16): (1, 2, 1),
+ (1024, 1024, 16384, 32, 32): (4, 3, 1),
+ (1024, 1024, 16384, 64, 64): (1, 3, 2),
+ (1024, 1024, 16384, 128, 128): (4, 1, 4),
+ (1024, 1024, 32768, 16, 16): (1, 2, 1),
+ (1024, 1024, 32768, 32, 32): (1, 3, 1),
+ (1024, 1024, 32768, 64, 64): (1, 3, 2),
+ (1024, 1024, 32768, 128, 128): (2, 1, 4),
+ (1024, 1024, 65536, 16, 16): (2, 2, 1),
+ (1024, 1024, 65536, 32, 32): (9, 3, 1),
+ (1024, 1024, 65536, 64, 64): (7, 3, 2),
+ (1024, 1024, 65536, 128, 128): (4, 1, 4),
+ (1024, 1024, 131072, 16, 16): (1, 1, 4),
+ (1024, 1024, 131072, 32, 32): (1, 1, 1),
+ (1024, 1024, 131072, 64, 64): (2, 3, 2),
+ (1024, 1024, 131072, 128, 128): (4, 1, 4),
+ (2048, 2048, 256, 16, 16): (4, 5, 1),
+ (2048, 2048, 256, 32, 32): (1, 4, 1),
+ (2048, 2048, 256, 64, 64): (1, 3, 4),
+ (2048, 2048, 256, 128, 128): (3, 2, 8),
+ (2048, 2048, 512, 16, 16): (2, 4, 1),
+ (2048, 2048, 512, 32, 32): (1, 3, 1),
+ (2048, 2048, 512, 64, 64): (1, 3, 2),
+ (2048, 2048, 512, 128, 128): (1, 2, 8),
+ (2048, 2048, 1024, 16, 16): (1, 3, 1),
+ (2048, 2048, 1024, 32, 32): (1, 4, 1),
+ (2048, 2048, 1024, 64, 64): (1, 3, 4),
+ (2048, 2048, 1024, 128, 128): (6, 2, 8),
+ (2048, 2048, 2048, 16, 16): (2, 3, 1),
+ (2048, 2048, 2048, 32, 32): (1, 4, 1),
+ (2048, 2048, 2048, 64, 64): (1, 3, 2),
+ (2048, 2048, 2048, 128, 128): (6, 2, 8),
+ (2048, 2048, 4096, 16, 16): (4, 3, 1),
+ (2048, 2048, 4096, 32, 32): (2, 4, 2),
+ (2048, 2048, 4096, 64, 64): (1, 3, 2),
+ (2048, 2048, 4096, 128, 128): (2, 1, 4),
+ (2048, 2048, 8192, 16, 16): (8, 2, 1),
+ (2048, 2048, 8192, 32, 32): (4, 3, 2),
+ (2048, 2048, 8192, 64, 64): (4, 3, 2),
+ (2048, 2048, 8192, 128, 128): (2, 1, 4),
+ (2048, 2048, 16384, 16, 16): (4, 2, 1),
+ (2048, 2048, 16384, 32, 32): (4, 4, 1),
+ (2048, 2048, 16384, 64, 64): (4, 3, 2),
+ (2048, 2048, 16384, 128, 128): (2, 1, 4),
+ (2048, 2048, 32768, 16, 16): (2, 2, 1),
+ (2048, 2048, 32768, 32, 32): (11, 4, 1),
+ (2048, 2048, 32768, 64, 64): (1, 1, 1),
+ (2048, 2048, 32768, 128, 128): (2, 1, 4),
+ (2048, 2048, 65536, 16, 16): (8, 2, 1),
+ (2048, 2048, 65536, 32, 32): (9, 3, 1),
+ (2048, 2048, 65536, 64, 64): (1, 1, 1),
+ (2048, 2048, 65536, 128, 128): (2, 1, 4),
+ (2048, 2048, 131072, 16, 16): (1, 1, 4),
+ (2048, 2048, 131072, 32, 32): (1, 1, 1),
+ (2048, 2048, 131072, 64, 64): (2, 3, 2),
+ (2048, 2048, 131072, 128, 128): (2, 1, 4),
+ (4096, 4096, 256, 16, 16): (4, 4, 1),
+ (4096, 4096, 256, 32, 32): (1, 3, 1),
+ (4096, 4096, 256, 64, 64): (1, 3, 4),
+ (4096, 4096, 256, 128, 128): (1, 2, 8),
+ (4096, 4096, 512, 16, 16): (1, 4, 1),
+ (4096, 4096, 512, 32, 32): (1, 5, 2),
+ (4096, 4096, 512, 64, 64): (1, 3, 4),
+ (4096, 4096, 512, 128, 128): (2, 2, 8),
+ (4096, 4096, 1024, 16, 16): (1, 3, 1),
+ (4096, 4096, 1024, 32, 32): (1, 4, 2),
+ (4096, 4096, 1024, 64, 64): (1, 4, 4),
+ (4096, 4096, 1024, 128, 128): (2, 2, 8),
+ (4096, 4096, 2048, 16, 16): (1, 3, 1),
+ (4096, 4096, 2048, 32, 32): (3, 4, 2),
+ (4096, 4096, 2048, 64, 64): (1, 3, 2),
+ (4096, 4096, 2048, 128, 128): (2, 2, 8),
+ (4096, 4096, 4096, 16, 16): (2, 3, 1),
+ (4096, 4096, 4096, 32, 32): (2, 4, 2),
+ (4096, 4096, 4096, 64, 64): (1, 3, 2),
+ (4096, 4096, 4096, 128, 128): (4, 1, 4),
+ (4096, 4096, 8192, 16, 16): (4, 2, 1),
+ (4096, 4096, 8192, 32, 32): (2, 4, 2),
+ (4096, 4096, 8192, 64, 64): (4, 3, 2),
+ (4096, 4096, 8192, 128, 128): (4, 1, 4),
+ (4096, 4096, 16384, 16, 16): (1, 1, 1),
+ (4096, 4096, 16384, 32, 32): (4, 4, 1),
+ (4096, 4096, 16384, 64, 64): (4, 3, 2),
+ (4096, 4096, 16384, 128, 128): (4, 1, 4),
+ (4096, 4096, 32768, 16, 16): (1, 1, 1),
+ (4096, 4096, 32768, 32, 32): (3, 4, 1),
+ (4096, 4096, 32768, 64, 64): (3, 3, 2),
+ (4096, 4096, 32768, 128, 128): (4, 1, 4),
+ (4096, 4096, 65536, 16, 16): (1, 1, 1),
+ (4096, 4096, 65536, 32, 32): (3, 4, 1),
+ (4096, 4096, 65536, 64, 64): (2, 3, 2),
+ (4096, 4096, 65536, 128, 128): (4, 1, 4),
+ (4096, 4096, 131072, 16, 16): (1, 1, 4),
+ (4096, 4096, 131072, 32, 32): (1, 1, 1),
+ (4096, 4096, 131072, 64, 64): (3, 3, 2),
+ (4096, 4096, 131072, 128, 128): (4, 1, 4),
+ (8192, 8192, 256, 16, 16): (4, 4, 1),
+ (8192, 8192, 256, 32, 32): (2, 1, 1),
+ (8192, 8192, 256, 64, 64): (4, 3, 8),
+ (8192, 8192, 256, 128, 128): (1, 1, 4),
+ (8192, 8192, 512, 16, 16): (2, 2, 1),
+ (8192, 8192, 512, 32, 32): (4, 4, 2),
+ (8192, 8192, 512, 64, 64): (4, 4, 4),
+ (8192, 8192, 512, 128, 128): (4, 2, 8),
+ (8192, 8192, 1024, 16, 16): (4, 5, 1),
+ (8192, 8192, 1024, 32, 32): (4, 4, 2),
+ (8192, 8192, 1024, 64, 64): (4, 3, 4),
+ (8192, 8192, 1024, 128, 128): (4, 2, 8),
+ (8192, 8192, 2048, 16, 16): (4, 5, 1),
+ (8192, 8192, 2048, 32, 32): (4, 4, 2),
+ (8192, 8192, 2048, 64, 64): (1, 1, 1),
+ (8192, 8192, 2048, 128, 128): (4, 1, 4),
+ (8192, 8192, 4096, 16, 16): (4, 3, 1),
+ (8192, 8192, 4096, 32, 32): (4, 4, 2),
+ (8192, 8192, 4096, 64, 64): (4, 3, 2),
+ (8192, 8192, 4096, 128, 128): (4, 1, 4),
+ (8192, 8192, 8192, 16, 16): (4, 2, 1),
+ (8192, 8192, 8192, 32, 32): (4, 4, 1),
+ (8192, 8192, 8192, 64, 64): (4, 3, 2),
+ (8192, 8192, 8192, 128, 128): (4, 1, 4),
+ (8192, 8192, 16384, 16, 16): (4, 2, 1),
+ (8192, 8192, 16384, 32, 32): (4, 4, 1),
+ (8192, 8192, 16384, 64, 64): (4, 3, 2),
+ (8192, 8192, 16384, 128, 128): (4, 1, 4),
+ (8192, 8192, 32768, 16, 16): (4, 2, 1),
+ (8192, 8192, 32768, 32, 32): (4, 4, 1),
+ (8192, 8192, 32768, 64, 64): (4, 3, 2),
+ (8192, 8192, 32768, 128, 128): (4, 1, 4),
+ (8192, 8192, 65536, 16, 16): (4, 2, 1),
+ (8192, 8192, 65536, 32, 32): (4, 3, 1),
+ (8192, 8192, 65536, 64, 64): (4, 4, 2),
+ (8192, 8192, 65536, 128, 128): (4, 1, 4),
+ (8192, 8192, 131072, 16, 16): (4, 1, 4),
+ (8192, 8192, 131072, 32, 32): (4, 2, 1),
+ (8192, 8192, 131072, 64, 64): (4, 3, 2),
+ (8192, 8192, 131072, 128, 128): (4, 1, 4),
+ (16384, 16384, 256, 16, 16): (4, 8, 1),
+ (16384, 16384, 256, 32, 32): (4, 4, 2),
+ (16384, 16384, 256, 64, 64): (4, 4, 4),
+ (16384, 16384, 256, 128, 128): (6, 2, 8),
+ (16384, 16384, 512, 16, 16): (4, 7, 1),
+ (16384, 16384, 512, 32, 32): (4, 5, 2),
+ (16384, 16384, 512, 64, 64): (4, 3, 2),
+ (16384, 16384, 512, 128, 128): (4, 2, 8),
+ (16384, 16384, 1024, 16, 16): (4, 9, 1),
+ (16384, 16384, 1024, 32, 32): (4, 4, 1),
+ (16384, 16384, 1024, 64, 64): (4, 3, 2),
+ (16384, 16384, 1024, 128, 128): (4, 1, 4),
+ (16384, 16384, 2048, 16, 16): (4, 9, 1),
+ (16384, 16384, 2048, 32, 32): (4, 4, 1),
+ (16384, 16384, 2048, 64, 64): (4, 3, 2),
+ (16384, 16384, 2048, 128, 128): (4, 1, 4),
+ (16384, 16384, 4096, 16, 16): (4, 2, 1),
+ (16384, 16384, 4096, 32, 32): (4, 5, 1),
+ (16384, 16384, 4096, 64, 64): (4, 3, 2),
+ (16384, 16384, 4096, 128, 128): (4, 1, 4),
+ (16384, 16384, 8192, 16, 16): (4, 2, 1),
+ (16384, 16384, 8192, 32, 32): (4, 5, 1),
+ (16384, 16384, 8192, 64, 64): (4, 3, 2),
+ (16384, 16384, 8192, 128, 128): (4, 1, 4),
+ (16384, 16384, 16384, 16, 16): (4, 2, 1),
+ (16384, 16384, 16384, 32, 32): (4, 4, 1),
+ (16384, 16384, 16384, 64, 64): (4, 3, 2),
+ (16384, 16384, 16384, 128, 128): (4, 1, 4),
+ (16384, 16384, 32768, 16, 16): (4, 2, 1),
+ (16384, 16384, 32768, 32, 32): (4, 5, 1),
+ (16384, 16384, 32768, 64, 64): (4, 3, 2),
+ (16384, 16384, 32768, 128, 128): (4, 1, 4),
+ (16384, 16384, 65536, 16, 16): (4, 2, 1),
+ (16384, 16384, 65536, 32, 32): (4, 6, 1),
+ (16384, 16384, 65536, 64, 64): (4, 3, 2),
+ (16384, 16384, 65536, 128, 128): (4, 1, 4),
+ (16384, 16384, 131072, 16, 16): (4, 2, 2),
+ (16384, 16384, 131072, 32, 32): (4, 2, 1),
+ (16384, 16384, 131072, 64, 64): (4, 3, 2),
+ (16384, 16384, 131072, 128, 128): (4, 1, 4),
+ },
("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.3)): {
(256, 256, 256, 16, 16): (5, 1, 2),
(256, 256, 256, 32, 32): (4, 3, 4),
@@ -769,80 +1154,80 @@
(16384, 16384, 131072, 128, 128): (4, 1, 4),
},
("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): {
- (256, 256, 256, 16, 16): (8, 1, 4),
- (256, 256, 256, 32, 32): (4, 3, 4),
+ (256, 256, 256, 16, 16): (8, 6, 1),
+ (256, 256, 256, 32, 32): (4, 5, 2),
(256, 256, 256, 64, 64): (3, 3, 4),
(256, 256, 256, 128, 128): (4, 2, 8),
- (256, 256, 512, 16, 16): (3, 1, 4),
- (256, 256, 512, 32, 32): (4, 1, 4),
- (256, 256, 512, 64, 64): (4, 3, 4),
- (256, 256, 512, 128, 128): (4, 2, 4),
+ (256, 256, 512, 16, 16): (4, 1, 4),
+ (256, 256, 512, 32, 32): (4, 3, 4),
+ (256, 256, 512, 64, 64): (6, 3, 4),
+ (256, 256, 512, 128, 128): (2, 2, 8),
(256, 256, 1024, 16, 16): (4, 1, 2),
- (256, 256, 1024, 32, 32): (4, 3, 4),
+ (256, 256, 1024, 32, 32): (1, 3, 2),
(256, 256, 1024, 64, 64): (4, 3, 4),
(256, 256, 1024, 128, 128): (4, 2, 8),
- (256, 256, 2048, 16, 16): (4, 1, 1),
- (256, 256, 2048, 32, 32): (5, 1, 4),
- (256, 256, 2048, 64, 64): (4, 3, 4),
+ (256, 256, 2048, 16, 16): (2, 3, 1),
+ (256, 256, 2048, 32, 32): (5, 4, 1),
+ (256, 256, 2048, 64, 64): (3, 3, 4),
(256, 256, 2048, 128, 128): (4, 2, 8),
- (256, 256, 4096, 16, 16): (4, 1, 1),
+ (256, 256, 4096, 16, 16): (4, 3, 1),
(256, 256, 4096, 32, 32): (4, 1, 4),
(256, 256, 4096, 64, 64): (4, 3, 4),
(256, 256, 4096, 128, 128): (4, 2, 8),
(256, 256, 8192, 16, 16): (5, 3, 1),
- (256, 256, 8192, 32, 32): (4, 3, 2),
+ (256, 256, 8192, 32, 32): (4, 3, 1),
(256, 256, 8192, 64, 64): (4, 2, 4),
- (256, 256, 8192, 128, 128): (4, 1, 4),
- (256, 256, 16384, 16, 16): (5, 2, 1),
- (256, 256, 16384, 32, 32): (3, 2, 2),
- (256, 256, 16384, 64, 64): (4, 1, 4),
+ (256, 256, 8192, 128, 128): (2, 1, 4),
+ (256, 256, 16384, 16, 16): (7, 3, 1),
+ (256, 256, 16384, 32, 32): (3, 3, 1),
+ (256, 256, 16384, 64, 64): (4, 2, 4),
(256, 256, 16384, 128, 128): (4, 1, 4),
- (256, 256, 32768, 16, 16): (5, 3, 1),
- (256, 256, 32768, 32, 32): (4, 3, 2),
+ (256, 256, 32768, 16, 16): (7, 3, 1),
+ (256, 256, 32768, 32, 32): (1, 3, 1),
(256, 256, 32768, 64, 64): (4, 2, 4),
(256, 256, 32768, 128, 128): (4, 1, 4),
(256, 256, 65536, 16, 16): (5, 3, 1),
(256, 256, 65536, 32, 32): (4, 3, 1),
- (256, 256, 65536, 64, 64): (5, 2, 4),
- (256, 256, 65536, 128, 128): (4, 1, 4),
+ (256, 256, 65536, 64, 64): (6, 2, 4),
+ (256, 256, 65536, 128, 128): (1, 1, 4),
(256, 256, 131072, 16, 16): (4, 1, 2),
(256, 256, 131072, 32, 32): (4, 2, 2),
(256, 256, 131072, 64, 64): (4, 2, 4),
(256, 256, 131072, 128, 128): (4, 1, 4),
- (512, 512, 256, 16, 16): (4, 1, 4),
- (512, 512, 256, 32, 32): (4, 1, 4),
+ (512, 512, 256, 16, 16): (4, 5, 1),
+ (512, 512, 256, 32, 32): (4, 5, 2),
(512, 512, 256, 64, 64): (4, 3, 4),
(512, 512, 256, 128, 128): (4, 2, 8),
- (512, 512, 512, 16, 16): (4, 1, 1),
- (512, 512, 512, 32, 32): (4, 1, 4),
- (512, 512, 512, 64, 64): (4, 3, 4),
+ (512, 512, 512, 16, 16): (2, 4, 1),
+ (512, 512, 512, 32, 32): (4, 4, 4),
+ (512, 512, 512, 64, 64): (4, 5, 4),
(512, 512, 512, 128, 128): (4, 2, 8),
- (512, 512, 1024, 16, 16): (5, 4, 1),
- (512, 512, 1024, 32, 32): (4, 1, 4),
+ (512, 512, 1024, 16, 16): (1, 4, 1),
+ (512, 512, 1024, 32, 32): (2, 3, 1),
(512, 512, 1024, 64, 64): (4, 4, 4),
(512, 512, 1024, 128, 128): (4, 2, 8),
- (512, 512, 2048, 16, 16): (4, 3, 1),
- (512, 512, 2048, 32, 32): (4, 3, 2),
- (512, 512, 2048, 64, 64): (3, 3, 4),
+ (512, 512, 2048, 16, 16): (5, 3, 1),
+ (512, 512, 2048, 32, 32): (2, 3, 2),
+ (512, 512, 2048, 64, 64): (1, 3, 2),
(512, 512, 2048, 128, 128): (4, 2, 8),
(512, 512, 4096, 16, 16): (4, 3, 1),
(512, 512, 4096, 32, 32): (4, 4, 2),
- (512, 512, 4096, 64, 64): (4, 3, 4),
- (512, 512, 4096, 128, 128): (4, 2, 8),
- (512, 512, 8192, 16, 16): (4, 3, 1),
- (512, 512, 8192, 32, 32): (5, 3, 2),
- (512, 512, 8192, 64, 64): (5, 2, 4),
- (512, 512, 8192, 128, 128): (4, 1, 4),
+ (512, 512, 4096, 64, 64): (5, 3, 4),
+ (512, 512, 4096, 128, 128): (2, 2, 8),
+ (512, 512, 8192, 16, 16): (2, 3, 1),
+ (512, 512, 8192, 32, 32): (1, 3, 1),
+ (512, 512, 8192, 64, 64): (5, 3, 2),
+ (512, 512, 8192, 128, 128): (4, 1, 16),
(512, 512, 16384, 16, 16): (4, 3, 1),
(512, 512, 16384, 32, 32): (4, 3, 1),
(512, 512, 16384, 64, 64): (4, 3, 4),
(512, 512, 16384, 128, 128): (4, 1, 4),
- (512, 512, 32768, 16, 16): (4, 2, 1),
+ (512, 512, 32768, 16, 16): (1, 2, 1),
(512, 512, 32768, 32, 32): (5, 3, 1),
(512, 512, 32768, 64, 64): (4, 3, 2),
(512, 512, 32768, 128, 128): (5, 1, 4),
(512, 512, 65536, 16, 16): (4, 3, 1),
- (512, 512, 65536, 32, 32): (4, 3, 1),
+ (512, 512, 65536, 32, 32): (1, 3, 1),
(512, 512, 65536, 64, 64): (4, 3, 2),
(512, 512, 65536, 128, 128): (5, 1, 4),
(512, 512, 131072, 16, 16): (4, 1, 4),
@@ -850,40 +1235,40 @@
(512, 512, 131072, 64, 64): (4, 3, 2),
(512, 512, 131072, 128, 128): (4, 1, 4),
(1024, 1024, 256, 16, 16): (4, 4, 1),
- (1024, 1024, 256, 32, 32): (4, 1, 4),
+ (1024, 1024, 256, 32, 32): (2, 4, 2),
(1024, 1024, 256, 64, 64): (4, 4, 4),
(1024, 1024, 256, 128, 128): (4, 2, 8),
- (1024, 1024, 512, 16, 16): (4, 4, 1),
+ (1024, 1024, 512, 16, 16): (3, 3, 1),
(1024, 1024, 512, 32, 32): (5, 4, 2),
(1024, 1024, 512, 64, 64): (4, 3, 4),
(1024, 1024, 512, 128, 128): (3, 2, 8),
- (1024, 1024, 1024, 16, 16): (5, 3, 1),
- (1024, 1024, 1024, 32, 32): (1, 3, 1),
+ (1024, 1024, 1024, 16, 16): (7, 3, 1),
+ (1024, 1024, 1024, 32, 32): (2, 3, 1),
(1024, 1024, 1024, 64, 64): (1, 3, 2),
(1024, 1024, 1024, 128, 128): (1, 2, 8),
- (1024, 1024, 2048, 16, 16): (4, 3, 1),
- (1024, 1024, 2048, 32, 32): (4, 4, 1),
+ (1024, 1024, 2048, 16, 16): (2, 3, 1),
+ (1024, 1024, 2048, 32, 32): (1, 4, 1),
(1024, 1024, 2048, 64, 64): (3, 3, 4),
(1024, 1024, 2048, 128, 128): (4, 2, 8),
(1024, 1024, 4096, 16, 16): (4, 3, 1),
- (1024, 1024, 4096, 32, 32): (4, 3, 1),
+ (1024, 1024, 4096, 32, 32): (4, 4, 1),
(1024, 1024, 4096, 64, 64): (4, 3, 4),
(1024, 1024, 4096, 128, 128): (4, 2, 8),
- (1024, 1024, 8192, 16, 16): (4, 3, 1),
+ (1024, 1024, 8192, 16, 16): (2, 3, 1),
(1024, 1024, 8192, 32, 32): (4, 3, 1),
- (1024, 1024, 8192, 64, 64): (4, 3, 4),
+ (1024, 1024, 8192, 64, 64): (4, 3, 2),
(1024, 1024, 8192, 128, 128): (4, 1, 4),
(1024, 1024, 16384, 16, 16): (4, 2, 1),
(1024, 1024, 16384, 32, 32): (4, 3, 1),
(1024, 1024, 16384, 64, 64): (4, 3, 2),
(1024, 1024, 16384, 128, 128): (4, 1, 4),
(1024, 1024, 32768, 16, 16): (4, 2, 1),
- (1024, 1024, 32768, 32, 32): (5, 3, 2),
- (1024, 1024, 32768, 64, 64): (5, 3, 2),
+ (1024, 1024, 32768, 32, 32): (8, 3, 1),
+ (1024, 1024, 32768, 64, 64): (8, 3, 2),
(1024, 1024, 32768, 128, 128): (4, 1, 4),
- (1024, 1024, 65536, 16, 16): (4, 2, 1),
- (1024, 1024, 65536, 32, 32): (5, 3, 1),
- (1024, 1024, 65536, 64, 64): (5, 3, 2),
+ (1024, 1024, 65536, 16, 16): (4, 4, 1),
+ (1024, 1024, 65536, 32, 32): (7, 3, 1),
+ (1024, 1024, 65536, 64, 64): (7, 3, 2),
(1024, 1024, 65536, 128, 128): (4, 1, 4),
(1024, 1024, 131072, 16, 16): (4, 1, 4),
(1024, 1024, 131072, 32, 32): (4, 2, 1),
@@ -891,66 +1276,66 @@
(1024, 1024, 131072, 128, 128): (4, 1, 4),
(2048, 2048, 256, 16, 16): (5, 4, 1),
(2048, 2048, 256, 32, 32): (4, 4, 2),
- (2048, 2048, 256, 64, 64): (5, 3, 4),
+ (2048, 2048, 256, 64, 64): (2, 3, 4),
(2048, 2048, 256, 128, 128): (4, 2, 8),
(2048, 2048, 512, 16, 16): (4, 4, 1),
(2048, 2048, 512, 32, 32): (5, 3, 2),
- (2048, 2048, 512, 64, 64): (6, 3, 4),
+ (2048, 2048, 512, 64, 64): (8, 3, 4),
(2048, 2048, 512, 128, 128): (4, 2, 8),
(2048, 2048, 1024, 16, 16): (4, 3, 1),
- (2048, 2048, 1024, 32, 32): (4, 3, 4),
+ (2048, 2048, 1024, 32, 32): (2, 4, 1),
(2048, 2048, 1024, 64, 64): (5, 3, 4),
- (2048, 2048, 1024, 128, 128): (4, 2, 8),
+ (2048, 2048, 1024, 128, 128): (6, 2, 8),
(2048, 2048, 2048, 16, 16): (3, 3, 1),
- (2048, 2048, 2048, 32, 32): (1, 4, 1),
+ (2048, 2048, 2048, 32, 32): (2, 4, 1),
(2048, 2048, 2048, 64, 64): (3, 3, 2),
(2048, 2048, 2048, 128, 128): (4, 2, 8),
(2048, 2048, 4096, 16, 16): (4, 3, 1),
(2048, 2048, 4096, 32, 32): (4, 4, 2),
- (2048, 2048, 4096, 64, 64): (4, 3, 2),
- (2048, 2048, 4096, 128, 128): (4, 1, 4),
- (2048, 2048, 8192, 16, 16): (4, 2, 1),
- (2048, 2048, 8192, 32, 32): (4, 4, 2),
- (2048, 2048, 8192, 64, 64): (4, 3, 2),
- (2048, 2048, 8192, 128, 128): (4, 1, 4),
+ (2048, 2048, 4096, 64, 64): (6, 3, 2),
+ (2048, 2048, 4096, 128, 128): (2, 1, 4),
+ (2048, 2048, 8192, 16, 16): (6, 2, 1),
+ (2048, 2048, 8192, 32, 32): (6, 4, 2),
+ (2048, 2048, 8192, 64, 64): (6, 3, 2),
+ (2048, 2048, 8192, 128, 128): (2, 1, 4),
(2048, 2048, 16384, 16, 16): (4, 2, 1),
- (2048, 2048, 16384, 32, 32): (4, 4, 2),
+ (2048, 2048, 16384, 32, 32): (4, 4, 1),
(2048, 2048, 16384, 64, 64): (4, 3, 2),
- (2048, 2048, 16384, 128, 128): (4, 1, 4),
- (2048, 2048, 32768, 16, 16): (6, 2, 1),
- (2048, 2048, 32768, 32, 32): (5, 4, 1),
- (2048, 2048, 32768, 64, 64): (5, 3, 2),
+ (2048, 2048, 16384, 128, 128): (2, 1, 4),
+ (2048, 2048, 32768, 16, 16): (8, 2, 1),
+ (2048, 2048, 32768, 32, 32): (7, 4, 1),
+ (2048, 2048, 32768, 64, 64): (9, 3, 2),
(2048, 2048, 32768, 128, 128): (4, 1, 4),
(2048, 2048, 65536, 16, 16): (3, 2, 1),
- (2048, 2048, 65536, 32, 32): (5, 4, 1),
+ (2048, 2048, 65536, 32, 32): (9, 3, 1),
(2048, 2048, 65536, 64, 64): (4, 3, 2),
- (2048, 2048, 65536, 128, 128): (4, 1, 4),
+ (2048, 2048, 65536, 128, 128): (2, 1, 4),
(2048, 2048, 131072, 16, 16): (4, 1, 4),
(2048, 2048, 131072, 32, 32): (4, 1, 1),
- (2048, 2048, 131072, 64, 64): (5, 3, 2),
+ (2048, 2048, 131072, 64, 64): (4, 3, 2),
(2048, 2048, 131072, 128, 128): (4, 1, 4),
(4096, 4096, 256, 16, 16): (4, 4, 1),
- (4096, 4096, 256, 32, 32): (4, 3, 2),
+ (4096, 4096, 256, 32, 32): (1, 3, 2),
(4096, 4096, 256, 64, 64): (3, 3, 4),
(4096, 4096, 256, 128, 128): (3, 2, 8),
(4096, 4096, 512, 16, 16): (1, 3, 1),
- (4096, 4096, 512, 32, 32): (4, 3, 4),
+ (4096, 4096, 512, 32, 32): (1, 3, 4),
(4096, 4096, 512, 64, 64): (6, 3, 4),
- (4096, 4096, 512, 128, 128): (4, 2, 8),
- (4096, 4096, 1024, 16, 16): (3, 3, 1),
+ (4096, 4096, 512, 128, 128): (2, 2, 8),
+ (4096, 4096, 1024, 16, 16): (1, 3, 1),
(4096, 4096, 1024, 32, 32): (4, 4, 2),
(4096, 4096, 1024, 64, 64): (4, 4, 4),
- (4096, 4096, 1024, 128, 128): (4, 2, 8),
+ (4096, 4096, 1024, 128, 128): (2, 2, 8),
(4096, 4096, 2048, 16, 16): (1, 3, 1),
(4096, 4096, 2048, 32, 32): (3, 4, 2),
- (4096, 4096, 2048, 64, 64): (4, 3, 4),
+ (4096, 4096, 2048, 64, 64): (4, 3, 2),
(4096, 4096, 2048, 128, 128): (4, 1, 4),
(4096, 4096, 4096, 16, 16): (2, 3, 1),
- (4096, 4096, 4096, 32, 32): (4, 4, 2),
+ (4096, 4096, 4096, 32, 32): (2, 4, 2),
(4096, 4096, 4096, 64, 64): (1, 3, 2),
(4096, 4096, 4096, 128, 128): (4, 1, 4),
- (4096, 4096, 8192, 16, 16): (4, 2, 1),
- (4096, 4096, 8192, 32, 32): (4, 3, 2),
+ (4096, 4096, 8192, 16, 16): (8, 2, 1),
+ (4096, 4096, 8192, 32, 32): (2, 4, 2),
(4096, 4096, 8192, 64, 64): (4, 3, 2),
(4096, 4096, 8192, 128, 128): (4, 1, 4),
(4096, 4096, 16384, 16, 16): (1, 1, 1),
@@ -959,18 +1344,18 @@
(4096, 4096, 16384, 128, 128): (4, 1, 4),
(4096, 4096, 32768, 16, 16): (4, 2, 1),
(4096, 4096, 32768, 32, 32): (5, 3, 1),
- (4096, 4096, 32768, 64, 64): (5, 3, 2),
+ (4096, 4096, 32768, 64, 64): (3, 3, 2),
(4096, 4096, 32768, 128, 128): (4, 1, 4),
(4096, 4096, 65536, 16, 16): (4, 2, 1),
- (4096, 4096, 65536, 32, 32): (2, 3, 1),
- (4096, 4096, 65536, 64, 64): (2, 3, 2),
+ (4096, 4096, 65536, 32, 32): (2, 4, 1),
+ (4096, 4096, 65536, 64, 64): (3, 3, 2),
(4096, 4096, 65536, 128, 128): (4, 1, 4),
(4096, 4096, 131072, 16, 16): (1, 1, 4),
(4096, 4096, 131072, 32, 32): (4, 2, 1),
- (4096, 4096, 131072, 64, 64): (5, 3, 2),
+ (4096, 4096, 131072, 64, 64): (7, 3, 2),
(4096, 4096, 131072, 128, 128): (4, 1, 4),
(8192, 8192, 256, 16, 16): (4, 4, 1),
- (8192, 8192, 256, 32, 32): (4, 3, 4),
+ (8192, 8192, 256, 32, 32): (4, 5, 2),
(8192, 8192, 256, 64, 64): (4, 3, 4),
(8192, 8192, 256, 128, 128): (5, 1, 4),
(8192, 8192, 512, 16, 16): (4, 5, 1),
@@ -981,9 +1366,9 @@
(8192, 8192, 1024, 32, 32): (4, 4, 2),
(8192, 8192, 1024, 64, 64): (4, 4, 4),
(8192, 8192, 1024, 128, 128): (4, 1, 4),
- (8192, 8192, 2048, 16, 16): (4, 3, 1),
- (8192, 8192, 2048, 32, 32): (4, 4, 1),
- (8192, 8192, 2048, 64, 64): (4, 3, 4),
+ (8192, 8192, 2048, 16, 16): (4, 5, 1),
+ (8192, 8192, 2048, 32, 32): (4, 4, 2),
+ (8192, 8192, 2048, 64, 64): (4, 3, 2),
(8192, 8192, 2048, 128, 128): (4, 1, 4),
(8192, 8192, 4096, 16, 16): (4, 3, 1),
(8192, 8192, 4096, 32, 32): (4, 4, 1),
@@ -999,29 +1384,29 @@
(8192, 8192, 16384, 128, 128): (4, 1, 4),
(8192, 8192, 32768, 16, 16): (4, 2, 1),
(8192, 8192, 32768, 32, 32): (4, 4, 1),
- (8192, 8192, 32768, 64, 64): (4, 4, 2),
+ (8192, 8192, 32768, 64, 64): (4, 3, 2),
(8192, 8192, 32768, 128, 128): (4, 1, 4),
(8192, 8192, 65536, 16, 16): (4, 2, 1),
- (8192, 8192, 65536, 32, 32): (4, 3, 1),
- (8192, 8192, 65536, 64, 64): (4, 3, 2),
+ (8192, 8192, 65536, 32, 32): (4, 4, 1),
+ (8192, 8192, 65536, 64, 64): (4, 4, 2),
(8192, 8192, 65536, 128, 128): (4, 1, 4),
(8192, 8192, 131072, 16, 16): (4, 1, 4),
(8192, 8192, 131072, 32, 32): (4, 2, 1),
(8192, 8192, 131072, 64, 64): (4, 3, 2),
(8192, 8192, 131072, 128, 128): (4, 1, 4),
- (16384, 16384, 256, 16, 16): (4, 4, 1),
+ (16384, 16384, 256, 16, 16): (4, 7, 1),
(16384, 16384, 256, 32, 32): (4, 4, 2),
(16384, 16384, 256, 64, 64): (4, 4, 4),
- (16384, 16384, 256, 128, 128): (4, 2, 8),
- (16384, 16384, 512, 16, 16): (4, 3, 1),
+ (16384, 16384, 256, 128, 128): (6, 2, 8),
+ (16384, 16384, 512, 16, 16): (4, 7, 1),
(16384, 16384, 512, 32, 32): (4, 5, 2),
- (16384, 16384, 512, 64, 64): (4, 4, 4),
+ (16384, 16384, 512, 64, 64): (4, 3, 2),
(16384, 16384, 512, 128, 128): (4, 2, 8),
- (16384, 16384, 1024, 16, 16): (4, 3, 1),
+ (16384, 16384, 1024, 16, 16): (4, 9, 1),
(16384, 16384, 1024, 32, 32): (4, 4, 1),
(16384, 16384, 1024, 64, 64): (4, 4, 4),
(16384, 16384, 1024, 128, 128): (4, 1, 4),
- (16384, 16384, 2048, 16, 16): (4, 3, 1),
+ (16384, 16384, 2048, 16, 16): (4, 9, 1),
(16384, 16384, 2048, 32, 32): (4, 4, 1),
(16384, 16384, 2048, 64, 64): (4, 4, 4),
(16384, 16384, 2048, 128, 128): (4, 1, 4),
@@ -1038,11 +1423,11 @@
(16384, 16384, 16384, 64, 64): (4, 3, 2),
(16384, 16384, 16384, 128, 128): (4, 1, 4),
(16384, 16384, 32768, 16, 16): (4, 2, 1),
- (16384, 16384, 32768, 32, 32): (4, 6, 1),
+ (16384, 16384, 32768, 32, 32): (4, 5, 1),
(16384, 16384, 32768, 64, 64): (4, 3, 2),
(16384, 16384, 32768, 128, 128): (4, 1, 4),
(16384, 16384, 65536, 16, 16): (4, 2, 1),
- (16384, 16384, 65536, 32, 32): (4, 2, 1),
+ (16384, 16384, 65536, 32, 32): (4, 5, 1),
(16384, 16384, 65536, 64, 64): (4, 3, 2),
(16384, 16384, 65536, 128, 128): (4, 1, 4),
(16384, 16384, 131072, 16, 16): (4, 1, 4),
@@ -1332,12 +1717,576 @@
(16384, 16384, 131072, 64, 64): (4, 3, 2),
(16384, 16384, 131072, 128, 128): (4, 1, 4),
},
+ ("bsr_dense_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): {
+ (256, 256, 256, 16, 16): (1, 1, 8),
+ (256, 256, 256, 32, 32): (1, 3, 4),
+ (256, 256, 256, 64, 64): (1, 1, 8),
+ (256, 256, 256, 128, 128): (1, 1, 16),
+ (256, 256, 512, 16, 16): (2, 1, 4),
+ (256, 256, 512, 32, 32): (3, 1, 4),
+ (256, 256, 512, 64, 64): (2, 1, 8),
+ (256, 256, 512, 128, 128): (1, 1, 32),
+ (256, 256, 1024, 16, 16): (2, 1, 1),
+ (256, 256, 1024, 32, 32): (1, 1, 4),
+ (256, 256, 1024, 64, 64): (4, 1, 8),
+ (256, 256, 1024, 128, 128): (1, 1, 32),
+ (256, 256, 2048, 16, 16): (1, 1, 1),
+ (256, 256, 2048, 32, 32): (1, 1, 4),
+ (256, 256, 2048, 64, 64): (2, 1, 8),
+ (256, 256, 2048, 128, 128): (1, 1, 32),
+ (256, 256, 4096, 16, 16): (1, 1, 1),
+ (256, 256, 4096, 32, 32): (1, 1, 4),
+ (256, 256, 4096, 64, 64): (4, 1, 8),
+ (256, 256, 4096, 128, 128): (1, 1, 32),
+ (256, 256, 8192, 16, 16): (1, 1, 1),
+ (256, 256, 8192, 32, 32): (1, 1, 4),
+ (256, 256, 8192, 64, 64): (1, 1, 4),
+ (256, 256, 8192, 128, 128): (1, 1, 32),
+ (256, 256, 16384, 16, 16): (2, 1, 1),
+ (256, 256, 16384, 32, 32): (2, 1, 4),
+ (256, 256, 16384, 64, 64): (1, 1, 4),
+ (256, 256, 16384, 128, 128): (1, 1, 32),
+ (256, 256, 32768, 16, 16): (1, 1, 1),
+ (256, 256, 32768, 32, 32): (1, 1, 4),
+ (256, 256, 32768, 64, 64): (1, 1, 4),
+ (256, 256, 32768, 128, 128): (1, 1, 32),
+ (256, 256, 65536, 16, 16): (4, 1, 1),
+ (256, 256, 65536, 32, 32): (1, 1, 4),
+ (256, 256, 65536, 64, 64): (1, 1, 4),
+ (256, 256, 65536, 128, 128): (1, 1, 32),
+ (256, 256, 131072, 16, 16): (2, 1, 1),
+ (256, 256, 131072, 32, 32): (2, 1, 4),
+ (256, 256, 131072, 64, 64): (1, 1, 4),
+ (256, 256, 131072, 128, 128): (1, 1, 32),
+ (512, 512, 256, 16, 16): (1, 1, 4),
+ (512, 512, 256, 32, 32): (1, 1, 4),
+ (512, 512, 256, 64, 64): (1, 1, 8),
+ (512, 512, 256, 128, 128): (1, 1, 32),
+ (512, 512, 512, 16, 16): (1, 1, 1),
+ (512, 512, 512, 32, 32): (1, 1, 4),
+ (512, 512, 512, 64, 64): (1, 1, 8),
+ (512, 512, 512, 128, 128): (1, 1, 32),
+ (512, 512, 1024, 16, 16): (1, 1, 1),
+ (512, 512, 1024, 32, 32): (3, 1, 4),
+ (512, 512, 1024, 64, 64): (4, 1, 8),
+ (512, 512, 1024, 128, 128): (1, 1, 32),
+ (512, 512, 2048, 16, 16): (2, 1, 1),
+ (512, 512, 2048, 32, 32): (3, 1, 4),
+ (512, 512, 2048, 64, 64): (5, 1, 8),
+ (512, 512, 2048, 128, 128): (1, 1, 32),
+ (512, 512, 4096, 16, 16): (1, 1, 1),
+ (512, 512, 4096, 32, 32): (2, 1, 4),
+ (512, 512, 4096, 64, 64): (1, 1, 8),
+ (512, 512, 4096, 128, 128): (1, 1, 32),
+ (512, 512, 8192, 16, 16): (1, 1, 1),
+ (512, 512, 8192, 32, 32): (2, 1, 4),
+ (512, 512, 8192, 64, 64): (1, 1, 4),
+ (512, 512, 8192, 128, 128): (2, 1, 32),
+ (512, 512, 16384, 16, 16): (1, 1, 1),
+ (512, 512, 16384, 32, 32): (2, 1, 2),
+ (512, 512, 16384, 64, 64): (1, 1, 4),
+ (512, 512, 16384, 128, 128): (1, 1, 32),
+ (512, 512, 32768, 16, 16): (4, 1, 1),
+ (512, 512, 32768, 32, 32): (4, 1, 2),
+ (512, 512, 32768, 64, 64): (1, 1, 4),
+ (512, 512, 32768, 128, 128): (1, 1, 32),
+ (512, 512, 65536, 16, 16): (4, 1, 1),
+ (512, 512, 65536, 32, 32): (4, 1, 2),
+ (512, 512, 65536, 64, 64): (1, 1, 4),
+ (512, 512, 65536, 128, 128): (1, 1, 32),
+ (512, 512, 131072, 16, 16): (4, 1, 1),
+ (512, 512, 131072, 32, 32): (4, 1, 2),
+ (512, 512, 131072, 64, 64): (1, 1, 4),
+ (512, 512, 131072, 128, 128): (1, 1, 32),
+ (1024, 1024, 256, 16, 16): (1, 1, 1),
+ (1024, 1024, 256, 32, 32): (1, 1, 4),
+ (1024, 1024, 256, 64, 64): (1, 1, 8),
+ (1024, 1024, 256, 128, 128): (1, 1, 32),
+ (1024, 1024, 512, 16, 16): (1, 1, 1),
+ (1024, 1024, 512, 32, 32): (1, 1, 4),
+ (1024, 1024, 512, 64, 64): (4, 1, 8),
+ (1024, 1024, 512, 128, 128): (1, 1, 32),
+ (1024, 1024, 1024, 16, 16): (1, 1, 1),
+ (1024, 1024, 1024, 32, 32): (1, 1, 4),
+ (1024, 1024, 1024, 64, 64): (3, 1, 8),
+ (1024, 1024, 1024, 128, 128): (1, 1, 32),
+ (1024, 1024, 2048, 16, 16): (1, 1, 1),
+ (1024, 1024, 2048, 32, 32): (3, 1, 4),
+ (1024, 1024, 2048, 64, 64): (1, 1, 8),
+ (1024, 1024, 2048, 128, 128): (4, 1, 32),
+ (1024, 1024, 4096, 16, 16): (1, 1, 1),
+ (1024, 1024, 4096, 32, 32): (4, 1, 2),
+ (1024, 1024, 4096, 64, 64): (1, 1, 8),
+ (1024, 1024, 4096, 128, 128): (4, 1, 32),
+ (1024, 1024, 8192, 16, 16): (1, 1, 1),
+ (1024, 1024, 8192, 32, 32): (4, 1, 2),
+ (1024, 1024, 8192, 64, 64): (1, 1, 4),
+ (1024, 1024, 8192, 128, 128): (1, 1, 32),
+ (1024, 1024, 16384, 16, 16): (1, 1, 1),
+ (1024, 1024, 16384, 32, 32): (2, 1, 2),
+ (1024, 1024, 16384, 64, 64): (1, 1, 4),
+ (1024, 1024, 16384, 128, 128): (1, 1, 32),
+ (1024, 1024, 32768, 16, 16): (4, 1, 1),
+ (1024, 1024, 32768, 32, 32): (4, 1, 1),
+ (1024, 1024, 32768, 64, 64): (1, 1, 4),
+ (1024, 1024, 32768, 128, 128): (1, 1, 32),
+ (1024, 1024, 65536, 16, 16): (4, 1, 1),
+ (1024, 1024, 65536, 32, 32): (4, 1, 1),
+ (1024, 1024, 65536, 64, 64): (1, 1, 4),
+ (1024, 1024, 65536, 128, 128): (1, 1, 32),
+ (1024, 1024, 131072, 16, 16): (4, 1, 1),
+ (1024, 1024, 131072, 32, 32): (4, 1, 1),
+ (1024, 1024, 131072, 64, 64): (1, 1, 4),
+ (1024, 1024, 131072, 128, 128): (1, 1, 32),
+ (2048, 2048, 256, 16, 16): (1, 1, 1),
+ (2048, 2048, 256, 32, 32): (3, 1, 4),
+ (2048, 2048, 256, 64, 64): (2, 1, 8),
+ (2048, 2048, 256, 128, 128): (1, 1, 32),
+ (2048, 2048, 512, 16, 16): (1, 1, 1),
+ (2048, 2048, 512, 32, 32): (1, 1, 4),
+ (2048, 2048, 512, 64, 64): (4, 1, 8),
+ (2048, 2048, 512, 128, 128): (1, 1, 32),
+ (2048, 2048, 1024, 16, 16): (1, 1, 1),
+ (2048, 2048, 1024, 32, 32): (1, 1, 4),
+ (2048, 2048, 1024, 64, 64): (2, 1, 8),
+ (2048, 2048, 1024, 128, 128): (4, 1, 32),
+ (2048, 2048, 2048, 16, 16): (1, 1, 1),
+ (2048, 2048, 2048, 32, 32): (4, 1, 2),
+ (2048, 2048, 2048, 64, 64): (2, 1, 4),
+ (2048, 2048, 2048, 128, 128): (4, 1, 32),
+ (2048, 2048, 4096, 16, 16): (1, 1, 1),
+ (2048, 2048, 4096, 32, 32): (4, 1, 1),
+ (2048, 2048, 4096, 64, 64): (2, 1, 4),
+ (2048, 2048, 4096, 128, 128): (1, 1, 32),
+ (2048, 2048, 8192, 16, 16): (1, 1, 1),
+ (2048, 2048, 8192, 32, 32): (4, 1, 1),
+ (2048, 2048, 8192, 64, 64): (2, 1, 4),
+ (2048, 2048, 8192, 128, 128): (1, 1, 32),
+ (2048, 2048, 16384, 16, 16): (2, 1, 1),
+ (2048, 2048, 16384, 32, 32): (4, 1, 1),
+ (2048, 2048, 16384, 64, 64): (2, 1, 4),
+ (2048, 2048, 16384, 128, 128): (1, 1, 32),
+ (2048, 2048, 32768, 16, 16): (2, 1, 1),
+ (2048, 2048, 32768, 32, 32): (4, 1, 1),
+ (2048, 2048, 32768, 64, 64): (3, 1, 4),
+ (2048, 2048, 32768, 128, 128): (1, 1, 32),
+ (2048, 2048, 65536, 16, 16): (2, 1, 1),
+ (2048, 2048, 65536, 32, 32): (4, 1, 1),
+ (2048, 2048, 65536, 64, 64): (11, 1, 4),
+ (2048, 2048, 65536, 128, 128): (4, 1, 32),
+ (2048, 2048, 131072, 16, 16): (4, 1, 1),
+ (2048, 2048, 131072, 32, 32): (4, 1, 1),
+ (2048, 2048, 131072, 64, 64): (3, 1, 4),
+ (2048, 2048, 131072, 128, 128): (1, 1, 32),
+ (4096, 4096, 256, 16, 16): (1, 1, 1),
+ (4096, 4096, 256, 32, 32): (1, 1, 4),
+ (4096, 4096, 256, 64, 64): (4, 1, 8),
+ (4096, 4096, 256, 128, 128): (1, 1, 32),
+ (4096, 4096, 512, 16, 16): (1, 1, 1),
+ (4096, 4096, 512, 32, 32): (1, 1, 4),
+ (4096, 4096, 512, 64, 64): (4, 1, 4),
+ (4096, 4096, 512, 128, 128): (4, 1, 32),
+ (4096, 4096, 1024, 16, 16): (1, 1, 1),
+ (4096, 4096, 1024, 32, 32): (1, 1, 2),
+ (4096, 4096, 1024, 64, 64): (4, 1, 4),
+ (4096, 4096, 1024, 128, 128): (4, 1, 32),
+ (4096, 4096, 2048, 16, 16): (1, 1, 1),
+ (4096, 4096, 2048, 32, 32): (1, 1, 2),
+ (4096, 4096, 2048, 64, 64): (1, 1, 4),
+ (4096, 4096, 2048, 128, 128): (4, 1, 32),
+ (4096, 4096, 4096, 16, 16): (2, 1, 1),
+ (4096, 4096, 4096, 32, 32): (4, 1, 2),
+ (4096, 4096, 4096, 64, 64): (1, 1, 4),
+ (4096, 4096, 4096, 128, 128): (4, 1, 32),
+ (4096, 4096, 8192, 16, 16): (2, 1, 1),
+ (4096, 4096, 8192, 32, 32): (4, 1, 2),
+ (4096, 4096, 8192, 64, 64): (1, 1, 4),
+ (4096, 4096, 8192, 128, 128): (1, 1, 32),
+ (4096, 4096, 16384, 16, 16): (2, 1, 1),
+ (4096, 4096, 16384, 32, 32): (4, 1, 2),
+ (4096, 4096, 16384, 64, 64): (1, 1, 4),
+ (4096, 4096, 16384, 128, 128): (1, 1, 32),
+ (4096, 4096, 32768, 16, 16): (4, 1, 1),
+ (4096, 4096, 32768, 32, 32): (4, 1, 2),
+ (4096, 4096, 32768, 64, 64): (1, 1, 4),
+ (4096, 4096, 32768, 128, 128): (1, 1, 32),
+ (4096, 4096, 65536, 16, 16): (4, 1, 1),
+ (4096, 4096, 65536, 32, 32): (4, 1, 2),
+ (4096, 4096, 65536, 64, 64): (1, 1, 4),
+ (4096, 4096, 65536, 128, 128): (1, 1, 32),
+ (4096, 4096, 131072, 16, 16): (4, 1, 1),
+ (4096, 4096, 131072, 32, 32): (4, 1, 2),
+ (4096, 4096, 131072, 64, 64): (1, 1, 4),
+ (4096, 4096, 131072, 128, 128): (1, 1, 32),
+ (8192, 8192, 256, 16, 16): (4, 1, 1),
+ (8192, 8192, 256, 32, 32): (4, 1, 4),
+ (8192, 8192, 256, 64, 64): (4, 1, 4),
+ (8192, 8192, 256, 128, 128): (1, 1, 32),
+ (8192, 8192, 512, 16, 16): (4, 1, 1),
+ (8192, 8192, 512, 32, 32): (4, 1, 2),
+ (8192, 8192, 512, 64, 64): (4, 1, 4),
+ (8192, 8192, 512, 128, 128): (4, 1, 32),
+ (8192, 8192, 1024, 16, 16): (4, 1, 1),
+ (8192, 8192, 1024, 32, 32): (4, 1, 2),
+ (8192, 8192, 1024, 64, 64): (1, 1, 4),
+ (8192, 8192, 1024, 128, 128): (4, 1, 32),
+ (8192, 8192, 2048, 16, 16): (4, 1, 1),
+ (8192, 8192, 2048, 32, 32): (4, 1, 2),
+ (8192, 8192, 2048, 64, 64): (1, 1, 4),
+ (8192, 8192, 2048, 128, 128): (4, 1, 32),
+ (8192, 8192, 4096, 16, 16): (4, 1, 1),
+ (8192, 8192, 4096, 32, 32): (4, 1, 2),
+ (8192, 8192, 4096, 64, 64): (1, 1, 4),
+ (8192, 8192, 4096, 128, 128): (4, 1, 32),
+ (8192, 8192, 8192, 16, 16): (4, 1, 1),
+ (8192, 8192, 8192, 32, 32): (4, 1, 2),
+ (8192, 8192, 8192, 64, 64): (1, 1, 4),
+ (8192, 8192, 8192, 128, 128): (4, 1, 32),
+ (8192, 8192, 16384, 16, 16): (4, 1, 1),
+ (8192, 8192, 16384, 32, 32): (4, 1, 2),
+ (8192, 8192, 16384, 64, 64): (1, 1, 4),
+ (8192, 8192, 16384, 128, 128): (1, 1, 32),
+ (8192, 8192, 32768, 16, 16): (4, 1, 1),
+ (8192, 8192, 32768, 32, 32): (4, 1, 2),
+ (8192, 8192, 32768, 64, 64): (1, 1, 4),
+ (8192, 8192, 32768, 128, 128): (1, 1, 32),
+ (8192, 8192, 65536, 16, 16): (4, 1, 1),
+ (8192, 8192, 65536, 32, 32): (4, 1, 2),
+ (8192, 8192, 65536, 64, 64): (1, 1, 4),
+ (8192, 8192, 65536, 128, 128): (1, 1, 32),
+ (8192, 8192, 131072, 16, 16): (4, 1, 1),
+ (8192, 8192, 131072, 32, 32): (4, 1, 2),
+ (8192, 8192, 131072, 64, 64): (1, 1, 4),
+ (8192, 8192, 131072, 128, 128): (1, 1, 32),
+ (16384, 16384, 256, 16, 16): (4, 1, 1),
+ (16384, 16384, 256, 32, 32): (4, 1, 4),
+ (16384, 16384, 256, 64, 64): (2, 1, 4),
+ (16384, 16384, 256, 128, 128): (1, 1, 32),
+ (16384, 16384, 512, 16, 16): (4, 1, 1),
+ (16384, 16384, 512, 32, 32): (4, 1, 2),
+ (16384, 16384, 512, 64, 64): (4, 1, 4),
+ (16384, 16384, 512, 128, 128): (4, 1, 32),
+ (16384, 16384, 1024, 16, 16): (4, 1, 1),
+ (16384, 16384, 1024, 32, 32): (4, 1, 2),
+ (16384, 16384, 1024, 64, 64): (1, 1, 4),
+ (16384, 16384, 1024, 128, 128): (4, 1, 32),
+ (16384, 16384, 2048, 16, 16): (4, 1, 1),
+ (16384, 16384, 2048, 32, 32): (4, 1, 2),
+ (16384, 16384, 2048, 64, 64): (1, 1, 4),
+ (16384, 16384, 2048, 128, 128): (4, 1, 32),
+ (16384, 16384, 4096, 16, 16): (4, 1, 1),
+ (16384, 16384, 4096, 32, 32): (4, 1, 1),
+ (16384, 16384, 4096, 64, 64): (1, 1, 4),
+ (16384, 16384, 4096, 128, 128): (1, 1, 32),
+ (16384, 16384, 8192, 16, 16): (4, 1, 1),
+ (16384, 16384, 8192, 32, 32): (4, 1, 1),
+ (16384, 16384, 8192, 64, 64): (1, 1, 4),
+ (16384, 16384, 8192, 128, 128): (1, 1, 32),
+ (16384, 16384, 16384, 16, 16): (4, 1, 1),
+ (16384, 16384, 16384, 32, 32): (4, 1, 1),
+ (16384, 16384, 16384, 64, 64): (1, 1, 4),
+ (16384, 16384, 16384, 128, 128): (1, 1, 32),
+ (16384, 16384, 32768, 16, 16): (4, 1, 1),
+ (16384, 16384, 32768, 32, 32): (4, 1, 1),
+ (16384, 16384, 32768, 64, 64): (1, 1, 4),
+ (16384, 16384, 32768, 128, 128): (1, 1, 32),
+ (16384, 16384, 65536, 16, 16): (4, 1, 1),
+ (16384, 16384, 65536, 32, 32): (4, 1, 1),
+ (16384, 16384, 65536, 64, 64): (1, 1, 4),
+ (16384, 16384, 65536, 128, 128): (1, 1, 32),
+ (16384, 16384, 131072, 16, 16): (2, 1, 1),
+ (16384, 16384, 131072, 32, 32): (4, 1, 1),
+ (16384, 16384, 131072, 64, 64): (1, 1, 4),
+ (16384, 16384, 131072, 128, 128): (1, 1, 32),
+ },
+ ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.bfloat16, 0.5)): {
+ (256, 256, 256, 16, 16): (1, 1, 16, 16, 1, 2),
+ (256, 256, 256, 32, 32): (1, 1, 16, 16, 1, 4),
+ (256, 256, 256, 64, 64): (1, 1, 16, 16, 1, 1),
+ (256, 256, 256, 128, 128): (2, 4, 16, 64, 1, 4),
+ (256, 256, 512, 16, 16): (1, 1, 16, 16, 1, 4),
+ (256, 256, 512, 32, 32): (1, 1, 16, 32, 1, 4),
+ (256, 256, 512, 64, 64): (1, 1, 16, 32, 1, 1),
+ (256, 256, 512, 128, 128): (1, 1, 32, 32, 1, 4),
+ (256, 256, 1024, 16, 16): (1, 1, 16, 16, 1, 4),
+ (256, 256, 1024, 32, 32): (1, 2, 16, 32, 1, 1),
+ (256, 256, 1024, 64, 64): (1, 1, 32, 32, 1, 2),
+ (256, 256, 1024, 128, 128): (1, 1, 32, 64, 1, 4),
+ (256, 256, 2048, 16, 16): (1, 1, 16, 64, 1, 8),
+ (256, 256, 2048, 32, 32): (2, 1, 32, 64, 1, 2),
+ (256, 256, 2048, 64, 64): (1, 1, 32, 32, 1, 1),
+ (256, 256, 2048, 128, 128): (1, 1, 64, 64, 1, 4),
+ (256, 256, 4096, 16, 16): (1, 1, 16, 64, 1, 1),
+ (256, 256, 4096, 32, 32): (2, 2, 32, 64, 1, 2),
+ (256, 256, 4096, 64, 64): (1, 1, 32, 128, 1, 4),
+ (256, 256, 4096, 128, 128): (1, 1, 64, 64, 1, 4),
+ (256, 256, 8192, 16, 16): (1, 2, 16, 64, 1, 2),
+ (256, 256, 8192, 32, 32): (1, 1, 32, 64, 1, 2),
+ (256, 256, 8192, 64, 64): (1, 1, 32, 64, 1, 2),
+ (256, 256, 8192, 128, 128): (1, 1, 64, 64, 1, 4),
+ (256, 256, 16384, 16, 16): (1, 1, 16, 64, 1, 2),
+ (256, 256, 16384, 32, 32): (1, 1, 32, 64, 1, 2),
+ (256, 256, 16384, 64, 64): (1, 1, 64, 64, 1, 2),
+ (256, 256, 16384, 128, 128): (2, 16, 64, 64, 1, 4),
+ (256, 256, 32768, 16, 16): (1, 1, 16, 128, 1, 2),
+ (256, 256, 32768, 32, 32): (1, 1, 32, 64, 1, 2),
+ (256, 256, 32768, 64, 64): (1, 1, 64, 64, 1, 2),
+ (256, 256, 32768, 128, 128): (2, 32, 64, 64, 1, 4),
+ (256, 256, 65536, 16, 16): (1, 1, 16, 64, 1, 1),
+ (256, 256, 65536, 32, 32): (1, 1, 32, 64, 1, 2),
+ (256, 256, 65536, 64, 64): (1, 1, 64, 32, 1, 1),
+ (256, 256, 65536, 128, 128): (2, 32, 64, 64, 1, 4),
+ (256, 256, 131072, 16, 16): (1, 1, 16, 64, 1, 1),
+ (256, 256, 131072, 32, 32): (1, 1, 32, 64, 1, 2),
+ (256, 256, 131072, 64, 64): (4, 1, 64, 32, 1, 1),
+ (256, 256, 131072, 128, 128): (2, 64, 64, 64, 1, 4),
+ (512, 512, 256, 16, 16): (1, 1, 16, 16, 1, 2),
+ (512, 512, 256, 32, 32): (1, 1, 16, 32, 1, 1),
+ (512, 512, 256, 64, 64): (1, 2, 16, 32, 1, 1),
+ (512, 512, 256, 128, 128): (2, 16, 64, 16, 2, 4),
+ (512, 512, 512, 16, 16): (1, 1, 16, 16, 1, 4),
+ (512, 512, 512, 32, 32): (1, 1, 16, 32, 1, 1),
+ (512, 512, 512, 64, 64): (1, 1, 32, 32, 1, 2),
+ (512, 512, 512, 128, 128): (2, 8, 32, 64, 1, 4),
+ (512, 512, 1024, 16, 16): (1, 1, 16, 64, 1, 8),
+ (512, 512, 1024, 32, 32): (1, 1, 32, 32, 3, 1),
+ (512, 512, 1024, 64, 64): (1, 4, 32, 64, 1, 2),
+ (512, 512, 1024, 128, 128): (1, 4, 64, 64, 1, 4),
+ (512, 512, 2048, 16, 16): (1, 1, 16, 64, 1, 2),
+ (512, 512, 2048, 32, 32): (1, 1, 32, 64, 1, 2),
+ (512, 512, 2048, 64, 64): (1, 1, 64, 64, 3, 4),
+ (512, 512, 2048, 128, 128): (1, 1, 64, 64, 1, 4),
+ (512, 512, 4096, 16, 16): (1, 1, 16, 64, 1, 2),
+ (512, 512, 4096, 32, 32): (2, 64, 32, 64, 1, 2),
+ (512, 512, 4096, 64, 64): (1, 1, 64, 64, 3, 4),
+ (512, 512, 4096, 128, 128): (1, 1, 64, 64, 1, 4),
+ (512, 512, 8192, 16, 16): (1, 2, 16, 128, 1, 2),
+ (512, 512, 8192, 32, 32): (1, 1, 32, 64, 1, 2),
+ (512, 512, 8192, 64, 64): (1, 1, 64, 64, 1, 2),
+ (512, 512, 8192, 128, 128): (1, 1, 64, 64, 1, 4),
+ (512, 512, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+ (512, 512, 16384, 32, 32): (1, 1, 32, 64, 1, 2),
+ (512, 512, 16384, 64, 64): (1, 1, 64, 64, 3, 2),
+ (512, 512, 16384, 128, 128): (2, 1, 64, 64, 1, 4),
+ (512, 512, 32768, 16, 16): (1, 2, 16, 128, 1, 2),
+ (512, 512, 32768, 32, 32): (1, 1, 32, 64, 1, 2),
+ (512, 512, 32768, 64, 64): (1, 1, 64, 64, 3, 4),
+ (512, 512, 32768, 128, 128): (2, 1, 64, 64, 1, 4),
+ (512, 512, 65536, 16, 16): (1, 2, 16, 128, 1, 2),
+ (512, 512, 65536, 32, 32): (1, 1, 32, 64, 1, 2),
+ (512, 512, 65536, 64, 64): (1, 1, 64, 64, 3, 4),
+ (512, 512, 65536, 128, 128): (2, 1, 64, 64, 1, 4),
+ (512, 512, 131072, 16, 16): (1, 1, 16, 64, 1, 1),
+ (512, 512, 131072, 32, 32): (1, 1, 32, 64, 1, 2),
+ (512, 512, 131072, 64, 64): (1, 1, 64, 64, 3, 4),
+ (512, 512, 131072, 128, 128): (2, 4, 64, 64, 1, 4),
+ (1024, 1024, 256, 16, 16): (1, 1, 16, 16, 1, 4),
+ (1024, 1024, 256, 32, 32): (2, 16, 32, 16, 3, 4),
+ (1024, 1024, 256, 64, 64): (1, 4, 32, 32, 1, 2),
+ (1024, 1024, 256, 128, 128): (1, 4, 128, 16, 3, 16),
+ (1024, 1024, 512, 16, 16): (1, 1, 16, 64, 1, 2),
+ (1024, 1024, 512, 32, 32): (2, 2, 32, 64, 1, 2),
+ (1024, 1024, 512, 64, 64): (2, 8, 64, 64, 3, 4),
+ (1024, 1024, 512, 128, 128): (1, 4, 64, 64, 1, 8),
+ (1024, 1024, 1024, 16, 16): (1, 1, 16, 64, 1, 2),
+ (1024, 1024, 1024, 32, 32): (1, 1, 32, 64, 1, 2),
+ (1024, 1024, 1024, 64, 64): (1, 8, 64, 64, 3, 4),
+ (1024, 1024, 1024, 128, 128): (1, 8, 64, 64, 1, 4),
+ (1024, 1024, 2048, 16, 16): (1, 2, 16, 64, 1, 2),
+ (1024, 1024, 2048, 32, 32): (1, 1, 32, 64, 1, 2),
+ (1024, 1024, 2048, 64, 64): (2, 16, 64, 64, 2, 2),
+ (1024, 1024, 2048, 128, 128): (2, 32, 64, 64, 1, 4),
+ (1024, 1024, 4096, 16, 16): (2, 16, 16, 128, 1, 2),
+ (1024, 1024, 4096, 32, 32): (1, 16, 32, 64, 3, 2),
+ (1024, 1024, 4096, 64, 64): (1, 1, 64, 64, 3, 4),
+ (1024, 1024, 4096, 128, 128): (2, 64, 128, 64, 1, 4),
+ (1024, 1024, 8192, 16, 16): (2, 16, 16, 128, 1, 2),
+ (1024, 1024, 8192, 32, 32): (1, 16, 32, 64, 3, 2),
+ (1024, 1024, 8192, 64, 64): (1, 1, 64, 64, 3, 4),
+ (1024, 1024, 8192, 128, 128): (2, 1, 64, 64, 1, 4),
+ (1024, 1024, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+ (1024, 1024, 16384, 32, 32): (1, 16, 32, 64, 3, 2),
+ (1024, 1024, 16384, 64, 64): (1, 1, 64, 64, 3, 4),
+ (1024, 1024, 16384, 128, 128): (2, 16, 128, 64, 1, 4),
+ (1024, 1024, 32768, 16, 16): (1, 1, 16, 128, 1, 2),
+ (1024, 1024, 32768, 32, 32): (1, 1, 32, 128, 1, 2),
+ (1024, 1024, 32768, 64, 64): (1, 32, 64, 32, 2, 1),
+ (1024, 1024, 32768, 128, 128): (2, 8, 128, 64, 1, 4),
+ (1024, 1024, 65536, 16, 16): (3, 2, 16, 128, 1, 2),
+ (1024, 1024, 65536, 32, 32): (1, 1, 32, 128, 1, 2),
+ (1024, 1024, 65536, 64, 64): (2, 4, 64, 32, 2, 1),
+ (1024, 1024, 65536, 128, 128): (2, 8, 128, 64, 1, 4),
+ (1024, 1024, 131072, 16, 16): (2, 1, 16, 128, 1, 2),
+ (1024, 1024, 131072, 32, 32): (1, 1, 32, 128, 1, 2),
+ (1024, 1024, 131072, 64, 64): (1, 4, 64, 32, 2, 1),
+ (1024, 1024, 131072, 128, 128): (4, 1, 128, 64, 1, 4),
+ (2048, 2048, 256, 16, 16): (1, 1, 16, 64, 1, 8),
+ (2048, 2048, 256, 32, 32): (1, 1, 32, 32, 3, 1),
+ (2048, 2048, 256, 64, 64): (1, 1, 32, 32, 2, 1),
+ (2048, 2048, 256, 128, 128): (1, 4, 64, 64, 1, 8),
+ (2048, 2048, 512, 16, 16): (1, 2, 16, 64, 1, 2),
+ (2048, 2048, 512, 32, 32): (1, 2, 32, 64, 1, 4),
+ (2048, 2048, 512, 64, 64): (1, 4, 64, 64, 1, 8),
+ (2048, 2048, 512, 128, 128): (1, 4, 64, 64, 1, 4),
+ (2048, 2048, 1024, 16, 16): (1, 2, 16, 128, 1, 2),
+ (2048, 2048, 1024, 32, 32): (1, 1, 32, 64, 1, 2),
+ (2048, 2048, 1024, 64, 64): (1, 8, 64, 64, 1, 4),
+ (2048, 2048, 1024, 128, 128): (1, 8, 128, 64, 1, 4),
+ (2048, 2048, 2048, 16, 16): (3, 4, 16, 128, 1, 2),
+ (2048, 2048, 2048, 32, 32): (1, 16, 32, 64, 5, 2),
+ (2048, 2048, 2048, 64, 64): (1, 1, 64, 64, 3, 4),
+ (2048, 2048, 2048, 128, 128): (1, 8, 128, 64, 1, 4),
+ (2048, 2048, 4096, 16, 16): (1, 2, 16, 128, 1, 2),
+ (2048, 2048, 4096, 32, 32): (1, 8, 32, 64, 3, 2),
+ (2048, 2048, 4096, 64, 64): (1, 1, 64, 64, 3, 4),
+ (2048, 2048, 4096, 128, 128): (1, 8, 128, 64, 1, 4),
+ (2048, 2048, 8192, 16, 16): (2, 4, 16, 128, 1, 2),
+ (2048, 2048, 8192, 32, 32): (1, 4, 32, 128, 3, 2),
+ (2048, 2048, 8192, 64, 64): (1, 8, 64, 64, 3, 2),
+ (2048, 2048, 8192, 128, 128): (1, 8, 128, 64, 1, 4),
+ (2048, 2048, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+ (2048, 2048, 16384, 32, 32): (1, 4, 32, 128, 3, 2),
+ (2048, 2048, 16384, 64, 64): (1, 8, 64, 64, 3, 2),
+ (2048, 2048, 16384, 128, 128): (1, 4, 128, 64, 1, 4),
+ (2048, 2048, 32768, 16, 16): (3, 2, 16, 128, 1, 2),
+ (2048, 2048, 32768, 32, 32): (1, 1, 32, 128, 3, 2),
+ (2048, 2048, 32768, 64, 64): (1, 1, 64, 64, 3, 2),
+ (2048, 2048, 32768, 128, 128): (1, 4, 128, 64, 1, 4),
+ (2048, 2048, 65536, 16, 16): (1, 2, 16, 128, 1, 2),
+ (2048, 2048, 65536, 32, 32): (1, 4, 32, 128, 1, 2),
+ (2048, 2048, 65536, 64, 64): (1, 1, 64, 64, 3, 2),
+ (2048, 2048, 65536, 128, 128): (1, 2, 128, 64, 1, 4),
+ (2048, 2048, 131072, 16, 16): (4, 2, 16, 128, 1, 2),
+ (2048, 2048, 131072, 32, 32): (1, 1, 32, 128, 3, 2),
+ (2048, 2048, 131072, 64, 64): (1, 1, 64, 64, 3, 2),
+ (2048, 2048, 131072, 128, 128): (1, 2, 128, 64, 1, 4),
+ (4096, 4096, 256, 16, 16): (1, 1, 16, 64, 1, 2),
+ (4096, 4096, 256, 32, 32): (1, 1, 32, 64, 3, 4),
+ (4096, 4096, 256, 64, 64): (1, 1, 64, 64, 3, 4),
+ (4096, 4096, 256, 128, 128): (3, 4, 128, 32, 1, 4),
+ (4096, 4096, 512, 16, 16): (1, 2, 16, 128, 1, 2),
+ (4096, 4096, 512, 32, 32): (1, 2, 32, 64, 3, 2),
+ (4096, 4096, 512, 64, 64): (1, 4, 64, 64, 1, 4),
+ (4096, 4096, 512, 128, 128): (1, 4, 128, 64, 1, 4),
+ (4096, 4096, 1024, 16, 16): (1, 2, 16, 128, 1, 2),
+ (4096, 4096, 1024, 32, 32): (1, 8, 32, 64, 3, 2),
+ (4096, 4096, 1024, 64, 64): (1, 4, 64, 64, 1, 4),
+ (4096, 4096, 1024, 128, 128): (2, 4, 128, 64, 1, 4),
+ (4096, 4096, 2048, 16, 16): (1, 1, 16, 128, 1, 2),
+ (4096, 4096, 2048, 32, 32): (1, 4, 32, 128, 1, 4),
+ (4096, 4096, 2048, 64, 64): (1, 1, 64, 64, 3, 4),
+ (4096, 4096, 2048, 128, 128): (1, 16, 128, 64, 1, 4),
+ (4096, 4096, 4096, 16, 16): (1, 1, 16, 64, 3, 1),
+ (4096, 4096, 4096, 32, 32): (1, 4, 32, 64, 3, 2),
+ (4096, 4096, 4096, 64, 64): (1, 1, 64, 64, 3, 4),
+ (4096, 4096, 4096, 128, 128): (5, 1, 128, 64, 1, 4),
+ (4096, 4096, 8192, 16, 16): (1, 1, 16, 128, 1, 2),
+ (4096, 4096, 8192, 32, 32): (1, 1, 32, 128, 3, 2),
+ (4096, 4096, 8192, 64, 64): (1, 1, 64, 64, 3, 4),
+ (4096, 4096, 8192, 128, 128): (2, 1, 128, 64, 1, 4),
+ (4096, 4096, 16384, 16, 16): (1, 1, 16, 128, 1, 2),
+ (4096, 4096, 16384, 32, 32): (1, 1, 32, 128, 3, 2),
+ (4096, 4096, 16384, 64, 64): (1, 1, 64, 64, 4, 4),
+ (4096, 4096, 16384, 128, 128): (2, 1, 128, 64, 1, 4),
+ (4096, 4096, 32768, 16, 16): (3, 1, 16, 128, 1, 2),
+ (4096, 4096, 32768, 32, 32): (1, 1, 32, 128, 3, 2),
+ (4096, 4096, 32768, 64, 64): (1, 1, 64, 64, 3, 4),
+ (4096, 4096, 32768, 128, 128): (2, 1, 128, 64, 1, 4),
+ (4096, 4096, 65536, 16, 16): (2, 2, 16, 128, 1, 2),
+ (4096, 4096, 65536, 32, 32): (1, 1, 32, 128, 4, 2),
+ (4096, 4096, 65536, 64, 64): (1, 1, 64, 64, 4, 4),
+ (4096, 4096, 65536, 128, 128): (2, 1, 128, 64, 1, 4),
+ (4096, 4096, 131072, 16, 16): (2, 1, 16, 128, 1, 2),
+ (4096, 4096, 131072, 32, 32): (1, 1, 32, 128, 3, 2),
+ (4096, 4096, 131072, 64, 64): (1, 1, 64, 64, 3, 4),
+ (4096, 4096, 131072, 128, 128): (2, 1, 128, 64, 1, 4),
+ (8192, 8192, 256, 16, 16): (1, 2, 16, 64, 1, 2),
+ (8192, 8192, 256, 32, 32): (1, 1, 32, 64, 1, 2),
+ (8192, 8192, 256, 64, 64): (1, 2, 64, 64, 1, 4),
+ (8192, 8192, 256, 128, 128): (3, 16, 128, 16, 1, 2),
+ (8192, 8192, 512, 16, 16): (1, 2, 16, 128, 1, 2),
+ (8192, 8192, 512, 32, 32): (1, 4, 32, 64, 3, 2),
+ (8192, 8192, 512, 64, 64): (2, 8, 64, 64, 4, 4),
+ (8192, 8192, 512, 128, 128): (1, 8, 128, 64, 1, 4),
+ (8192, 8192, 1024, 16, 16): (4, 2, 16, 128, 1, 2),
+ (8192, 8192, 1024, 32, 32): (1, 8, 32, 128, 1, 2),
+ (8192, 8192, 1024, 64, 64): (1, 16, 64, 64, 3, 2),
+ (8192, 8192, 1024, 128, 128): (2, 16, 128, 64, 2, 4),
+ (8192, 8192, 2048, 16, 16): (2, 1, 16, 64, 4, 1),
+ (8192, 8192, 2048, 32, 32): (1, 16, 32, 64, 5, 2),
+ (8192, 8192, 2048, 64, 64): (1, 16, 64, 64, 3, 2),
+ (8192, 8192, 2048, 128, 128): (2, 16, 128, 64, 2, 4),
+ (8192, 8192, 4096, 16, 16): (1, 1, 16, 64, 4, 1),
+ (8192, 8192, 4096, 32, 32): (1, 16, 32, 64, 5, 2),
+ (8192, 8192, 4096, 64, 64): (1, 16, 64, 64, 3, 2),
+ (8192, 8192, 4096, 128, 128): (2, 64, 128, 64, 2, 4),
+ (8192, 8192, 8192, 16, 16): (1, 1, 16, 64, 4, 1),
+ (8192, 8192, 8192, 32, 32): (1, 8, 32, 128, 5, 4),
+ (8192, 8192, 8192, 64, 64): (1, 8, 64, 64, 3, 2),
+ (8192, 8192, 8192, 128, 128): (2, 8, 128, 64, 1, 4),
+ (8192, 8192, 16384, 16, 16): (1, 1, 16, 64, 4, 1),
+ (8192, 8192, 16384, 32, 32): (1, 8, 32, 64, 5, 2),
+ (8192, 8192, 16384, 64, 64): (1, 8, 64, 64, 3, 2),
+ (8192, 8192, 16384, 128, 128): (1, 8, 128, 64, 1, 4),
+ (8192, 8192, 32768, 16, 16): (1, 1, 16, 64, 4, 1),
+ (8192, 8192, 32768, 32, 32): (1, 8, 32, 64, 5, 2),
+ (8192, 8192, 32768, 64, 64): (3, 8, 64, 64, 3, 2),
+ (8192, 8192, 32768, 128, 128): (2, 8, 128, 64, 1, 4),
+ (8192, 8192, 65536, 16, 16): (1, 1, 16, 64, 4, 1),
+ (8192, 8192, 65536, 32, 32): (5, 4, 32, 64, 3, 2),
+ (8192, 8192, 65536, 64, 64): (1, 8, 64, 64, 3, 2),
+ (8192, 8192, 65536, 128, 128): (2, 8, 128, 64, 1, 4),
+ (8192, 8192, 131072, 16, 16): (2, 1, 16, 64, 4, 1),
+ (8192, 8192, 131072, 32, 32): (1, 4, 32, 64, 5, 2),
+ (8192, 8192, 131072, 64, 64): (1, 4, 64, 128, 3, 4),
+ (8192, 8192, 131072, 128, 128): (2, 8, 128, 64, 1, 4),
+ (16384, 16384, 256, 16, 16): (1, 2, 16, 128, 1, 2),
+ (16384, 16384, 256, 32, 32): (1, 4, 32, 64, 3, 2),
+ (16384, 16384, 256, 64, 64): (2, 4, 64, 64, 4, 4),
+ (16384, 16384, 256, 128, 128): (1, 4, 128, 64, 1, 16),
+ (16384, 16384, 512, 16, 16): (1, 2, 16, 128, 3, 2),
+ (16384, 16384, 512, 32, 32): (1, 4, 32, 128, 5, 4),
+ (16384, 16384, 512, 64, 64): (1, 8, 64, 64, 3, 2),
+ (16384, 16384, 512, 128, 128): (2, 8, 128, 64, 1, 4),
+ (16384, 16384, 1024, 16, 16): (1, 2, 16, 128, 1, 2),
+ (16384, 16384, 1024, 32, 32): (1, 8, 32, 64, 5, 2),
+ (16384, 16384, 1024, 64, 64): (1, 16, 64, 64, 3, 2),
+ (16384, 16384, 1024, 128, 128): (5, 16, 128, 64, 2, 4),
+ (16384, 16384, 2048, 16, 16): (1, 2, 16, 128, 1, 2),
+ (16384, 16384, 2048, 32, 32): (1, 8, 32, 64, 5, 2),
+ (16384, 16384, 2048, 64, 64): (1, 16, 64, 64, 3, 2),
+ (16384, 16384, 2048, 128, 128): (4, 32, 128, 64, 2, 4),
+ (16384, 16384, 4096, 16, 16): (3, 2, 16, 128, 1, 2),
+ (16384, 16384, 4096, 32, 32): (1, 4, 32, 64, 5, 2),
+ (16384, 16384, 4096, 64, 64): (2, 16, 64, 64, 3, 2),
+ (16384, 16384, 4096, 128, 128): (3, 32, 128, 64, 2, 4),
+ (16384, 16384, 8192, 16, 16): (1, 2, 16, 128, 1, 2),
+ (16384, 16384, 8192, 32, 32): (1, 4, 32, 64, 5, 2),
+ (16384, 16384, 8192, 64, 64): (4, 8, 64, 64, 3, 2),
+ (16384, 16384, 8192, 128, 128): (5, 8, 128, 64, 1, 4),
+ (16384, 16384, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+ (16384, 16384, 16384, 32, 32): (1, 4, 32, 64, 5, 2),
+ (16384, 16384, 16384, 64, 64): (2, 4, 64, 128, 3, 4),
+ (16384, 16384, 16384, 128, 128): (4, 8, 128, 64, 1, 4),
+ (16384, 16384, 32768, 16, 16): (4, 2, 16, 128, 1, 2),
+ (16384, 16384, 32768, 32, 32): (1, 4, 32, 64, 5, 2),
+ (16384, 16384, 32768, 64, 64): (1, 8, 64, 64, 3, 2),
+ (16384, 16384, 32768, 128, 128): (2, 512, 128, 64, 2, 4),
+ (16384, 16384, 65536, 16, 16): (3, 2, 16, 128, 1, 2),
+ (16384, 16384, 65536, 32, 32): (1, 4, 32, 64, 5, 2),
+ (16384, 16384, 65536, 64, 64): (1, 4, 64, 128, 3, 4),
+ (16384, 16384, 65536, 128, 128): (2, 1024, 128, 64, 2, 4),
+ (16384, 16384, 131072, 16, 16): (1, 2, 16, 128, 1, 2),
+ (16384, 16384, 131072, 32, 32): (1, 4, 32, 64, 5, 2),
+ (16384, 16384, 131072, 64, 64): (3, 4, 64, 128, 3, 4),
+ (16384, 16384, 131072, 128, 128): (4, 2048, 128, 64, 2, 4),
+ },
("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float16, 0.5)): {
(256, 256, 256, 16, 16): (5, 4, 16, 16, 1, 4),
(256, 256, 256, 32, 32): (5, 2, 32, 16, 1, 4),
(256, 256, 256, 64, 64): (4, 1, 32, 32, 1, 8),
(256, 256, 256, 128, 128): (2, 1, 32, 32, 1, 4),
- (256, 256, 512, 16, 16): (4, 8, 16, 32, 1, 4),
+ (256, 256, 512, 16, 16): (2, 2, 16, 32, 1, 4),
(256, 256, 512, 32, 32): (4, 8, 32, 32, 1, 8),
(256, 256, 512, 64, 64): (4, 8, 32, 64, 1, 4),
(256, 256, 512, 128, 128): (4, 8, 32, 64, 1, 4),
@@ -1345,12 +2294,12 @@
(256, 256, 1024, 32, 32): (4, 16, 32, 64, 1, 2),
(256, 256, 1024, 64, 64): (4, 16, 32, 64, 1, 4),
(256, 256, 1024, 128, 128): (4, 16, 64, 64, 1, 8),
- (256, 256, 2048, 16, 16): (4, 16, 16, 64, 1, 1),
+ (256, 256, 2048, 16, 16): (2, 16, 16, 64, 1, 8),
(256, 256, 2048, 32, 32): (4, 16, 32, 64, 1, 2),
(256, 256, 2048, 64, 64): (4, 16, 32, 64, 1, 4),
(256, 256, 2048, 128, 128): (4, 16, 64, 64, 1, 4),
(256, 256, 4096, 16, 16): (4, 32, 16, 64, 1, 1),
- (256, 256, 4096, 32, 32): (4, 32, 32, 64, 1, 2),
+ (256, 256, 4096, 32, 32): (2, 64, 32, 64, 1, 2),
(256, 256, 4096, 64, 64): (4, 64, 64, 64, 1, 4),
(256, 256, 4096, 128, 128): (4, 32, 64, 64, 1, 4),
(256, 256, 8192, 16, 16): (4, 64, 16, 64, 1, 1),
@@ -1358,33 +2307,33 @@
(256, 256, 8192, 64, 64): (4, 64, 64, 64, 1, 4),
(256, 256, 8192, 128, 128): (4, 64, 64, 64, 1, 4),
(256, 256, 16384, 16, 16): (4, 128, 16, 64, 1, 1),
- (256, 256, 16384, 32, 32): (4, 16, 32, 64, 1, 2),
+ (256, 256, 16384, 32, 32): (2, 128, 32, 64, 1, 2),
(256, 256, 16384, 64, 64): (4, 32, 32, 128, 1, 4),
(256, 256, 16384, 128, 128): (4, 16, 64, 64, 1, 4),
(256, 256, 32768, 16, 16): (4, 64, 16, 64, 1, 1),
- (256, 256, 32768, 32, 32): (4, 32, 32, 64, 1, 2),
+ (256, 256, 32768, 32, 32): (2, 256, 32, 64, 1, 2),
(256, 256, 32768, 64, 64): (4, 32, 32, 128, 1, 4),
(256, 256, 32768, 128, 128): (4, 32, 64, 64, 1, 4),
(256, 256, 65536, 16, 16): (4, 128, 16, 64, 1, 1),
- (256, 256, 65536, 32, 32): (4, 16, 32, 64, 1, 2),
- (256, 256, 65536, 64, 64): (4, 16, 64, 64, 1, 2),
+ (256, 256, 65536, 32, 32): (4, 1, 32, 64, 1, 2),
+ (256, 256, 65536, 64, 64): (2, 1, 64, 64, 1, 2),
(256, 256, 65536, 128, 128): (4, 32, 64, 64, 1, 4),
(256, 256, 131072, 16, 16): (4, 64, 16, 64, 1, 1),
- (256, 256, 131072, 32, 32): (4, 2, 32, 64, 1, 2),
+ (256, 256, 131072, 32, 32): (2, 1, 32, 64, 1, 2),
(256, 256, 131072, 64, 64): (4, 32, 32, 128, 1, 4),
(256, 256, 131072, 128, 128): (4, 32, 64, 64, 1, 4),
(512, 512, 256, 16, 16): (4, 16, 16, 16, 1, 4),
- (512, 512, 256, 32, 32): (4, 16, 32, 16, 1, 4),
- (512, 512, 256, 64, 64): (4, 16, 64, 16, 1, 8),
+ (512, 512, 256, 32, 32): (2, 4, 32, 16, 1, 4),
+ (512, 512, 256, 64, 64): (2, 16, 64, 16, 3, 8),
(512, 512, 256, 128, 128): (4, 16, 64, 16, 1, 4),
- (512, 512, 512, 16, 16): (2, 1, 16, 64, 1, 2),
+ (512, 512, 512, 16, 16): (1, 1, 16, 64, 1, 8),
(512, 512, 512, 32, 32): (2, 4, 16, 32, 1, 1),
(512, 512, 512, 64, 64): (2, 1, 32, 32, 1, 2),
(512, 512, 512, 128, 128): (4, 8, 32, 64, 1, 4),
- (512, 512, 1024, 16, 16): (4, 8, 16, 64, 1, 1),
+ (512, 512, 1024, 16, 16): (2, 8, 16, 64, 1, 8),
(512, 512, 1024, 32, 32): (4, 16, 32, 64, 1, 2),
(512, 512, 1024, 64, 64): (4, 16, 64, 64, 1, 4),
- (512, 512, 1024, 128, 128): (4, 16, 64, 64, 1, 4),
+ (512, 512, 1024, 128, 128): (2, 8, 64, 64, 1, 4),
(512, 512, 2048, 16, 16): (4, 16, 16, 64, 1, 4),
(512, 512, 2048, 32, 32): (4, 16, 32, 64, 1, 2),
(512, 512, 2048, 64, 64): (4, 16, 64, 64, 1, 8),
@@ -1393,7 +2342,7 @@
(512, 512, 4096, 32, 32): (4, 32, 32, 64, 1, 2),
(512, 512, 4096, 64, 64): (4, 32, 64, 64, 1, 4),
(512, 512, 4096, 128, 128): (4, 32, 64, 64, 1, 4),
- (512, 512, 8192, 16, 16): (3, 16, 16, 128, 1, 2),
+ (512, 512, 8192, 16, 16): (2, 32, 16, 128, 1, 2),
(512, 512, 8192, 32, 32): (4, 64, 32, 64, 1, 2),
(512, 512, 8192, 64, 64): (4, 128, 64, 64, 1, 2),
(512, 512, 8192, 128, 128): (4, 64, 64, 64, 1, 4),
@@ -1401,31 +2350,31 @@
(512, 512, 16384, 32, 32): (4, 64, 32, 64, 1, 2),
(512, 512, 16384, 64, 64): (4, 16, 64, 64, 1, 4),
(512, 512, 16384, 128, 128): (4, 32, 64, 64, 1, 4),
- (512, 512, 32768, 16, 16): (6, 16, 16, 128, 1, 2),
+ (512, 512, 32768, 16, 16): (7, 16, 16, 128, 1, 2),
(512, 512, 32768, 32, 32): (4, 64, 32, 64, 1, 2),
- (512, 512, 32768, 64, 64): (4, 32, 64, 64, 1, 2),
- (512, 512, 32768, 128, 128): (4, 16, 64, 64, 1, 4),
- (512, 512, 65536, 16, 16): (4, 32, 16, 64, 1, 1),
+ (512, 512, 32768, 64, 64): (2, 32, 64, 64, 3, 2),
+ (512, 512, 32768, 128, 128): (2, 32, 64, 64, 1, 4),
+ (512, 512, 65536, 16, 16): (2, 32, 16, 64, 1, 1),
(512, 512, 65536, 32, 32): (4, 64, 32, 64, 1, 2),
- (512, 512, 65536, 64, 64): (5, 32, 64, 64, 1, 2),
+ (512, 512, 65536, 64, 64): (3, 32, 64, 64, 3, 2),
(512, 512, 65536, 128, 128): (4, 16, 64, 64, 1, 4),
(512, 512, 131072, 16, 16): (3, 32, 16, 128, 1, 2),
(512, 512, 131072, 32, 32): (4, 64, 32, 64, 1, 2),
- (512, 512, 131072, 64, 64): (4, 32, 64, 64, 1, 2),
- (512, 512, 131072, 128, 128): (4, 16, 64, 64, 1, 4),
+ (512, 512, 131072, 64, 64): (2, 32, 64, 64, 3, 2),
+ (512, 512, 131072, 128, 128): (3, 1, 64, 64, 1, 4),
(1024, 1024, 256, 16, 16): (4, 16, 16, 16, 1, 4),
(1024, 1024, 256, 32, 32): (4, 16, 32, 16, 1, 4),
(1024, 1024, 256, 64, 64): (4, 4, 64, 32, 1, 16),
(1024, 1024, 256, 128, 128): (4, 16, 64, 16, 1, 8),
- (1024, 1024, 512, 16, 16): (4, 8, 16, 64, 1, 1),
- (1024, 1024, 512, 32, 32): (5, 8, 32, 64, 1, 2),
+ (1024, 1024, 512, 16, 16): (2, 8, 16, 64, 1, 8),
+ (1024, 1024, 512, 32, 32): (3, 2, 32, 64, 1, 2),
(1024, 1024, 512, 64, 64): (4, 8, 32, 64, 1, 8),
(1024, 1024, 512, 128, 128): (4, 8, 64, 64, 1, 8),
(1024, 1024, 1024, 16, 16): (2, 2, 16, 64, 1, 2),
(1024, 1024, 1024, 32, 32): (2, 8, 32, 64, 1, 2),
(1024, 1024, 1024, 64, 64): (2, 8, 32, 128, 1, 4),
(1024, 1024, 1024, 128, 128): (2, 8, 64, 64, 1, 4),
- (1024, 1024, 2048, 16, 16): (4, 16, 16, 128, 1, 2),
+ (1024, 1024, 2048, 16, 16): (2, 16, 16, 128, 3, 2),
(1024, 1024, 2048, 32, 32): (4, 32, 32, 64, 1, 2),
(1024, 1024, 2048, 64, 64): (4, 16, 64, 64, 1, 4),
(1024, 1024, 2048, 128, 128): (4, 32, 64, 64, 1, 4),
@@ -1434,183 +2383,183 @@
(1024, 1024, 4096, 64, 64): (4, 32, 64, 64, 1, 4),
(1024, 1024, 4096, 128, 128): (4, 32, 64, 64, 1, 4),
(1024, 1024, 8192, 16, 16): (5, 16, 16, 128, 1, 2),
- (1024, 1024, 8192, 32, 32): (4, 32, 32, 64, 1, 2),
- (1024, 1024, 8192, 64, 64): (3, 64, 64, 64, 3, 2),
+ (1024, 1024, 8192, 32, 32): (2, 32, 32, 64, 3, 2),
+ (1024, 1024, 8192, 64, 64): (1, 16, 64, 64, 3, 2),
(1024, 1024, 8192, 128, 128): (4, 32, 64, 64, 1, 4),
(1024, 1024, 16384, 16, 16): (4, 16, 16, 128, 1, 2),
- (1024, 1024, 16384, 32, 32): (3, 32, 32, 64, 1, 2),
+ (1024, 1024, 16384, 32, 32): (1, 32, 32, 64, 3, 2),
(1024, 1024, 16384, 64, 64): (4, 16, 64, 64, 3, 2),
(1024, 1024, 16384, 128, 128): (4, 32, 128, 64, 1, 4),
- (1024, 1024, 32768, 16, 16): (4, 16, 16, 128, 1, 2),
- (1024, 1024, 32768, 32, 32): (3, 32, 32, 64, 1, 2),
+ (1024, 1024, 32768, 16, 16): (3, 16, 16, 128, 1, 2),
+ (1024, 1024, 32768, 32, 32): (1, 8, 32, 64, 3, 2),
(1024, 1024, 32768, 64, 64): (4, 16, 64, 64, 3, 2),
(1024, 1024, 32768, 128, 128): (4, 8, 128, 64, 2, 4),
- (1024, 1024, 65536, 16, 16): (4, 8, 16, 128, 1, 2),
- (1024, 1024, 65536, 32, 32): (4, 16, 32, 64, 1, 2),
- (1024, 1024, 65536, 64, 64): (4, 16, 64, 64, 3, 2),
+ (1024, 1024, 65536, 16, 16): (1, 2, 16, 128, 1, 2),
+ (1024, 1024, 65536, 32, 32): (2, 4, 32, 64, 3, 2),
+ (1024, 1024, 65536, 64, 64): (5, 16, 64, 64, 3, 2),
(1024, 1024, 65536, 128, 128): (5, 8, 128, 64, 2, 4),
- (1024, 1024, 131072, 16, 16): (4, 8, 16, 128, 1, 2),
- (1024, 1024, 131072, 32, 32): (4, 16, 32, 64, 1, 2),
+ (1024, 1024, 131072, 16, 16): (5, 2, 16, 128, 1, 2),
+ (1024, 1024, 131072, 32, 32): (1, 2, 32, 64, 3, 2),
(1024, 1024, 131072, 64, 64): (5, 16, 64, 64, 3, 2),
- (1024, 1024, 131072, 128, 128): (4, 8, 128, 64, 2, 4),
+ (1024, 1024, 131072, 128, 128): (2, 1, 128, 64, 2, 4),
(2048, 2048, 256, 16, 16): (4, 4, 16, 64, 1, 8),
(2048, 2048, 256, 32, 32): (4, 8, 32, 32, 1, 8),
(2048, 2048, 256, 64, 64): (4, 16, 64, 16, 1, 8),
(2048, 2048, 256, 128, 128): (4, 4, 128, 32, 3, 8),
- (2048, 2048, 512, 16, 16): (4, 8, 16, 64, 1, 2),
- (2048, 2048, 512, 32, 32): (4, 4, 32, 64, 1, 2),
+ (2048, 2048, 512, 16, 16): (2, 2, 16, 64, 1, 2),
+ (2048, 2048, 512, 32, 32): (2, 4, 32, 64, 3, 2),
(2048, 2048, 512, 64, 64): (4, 4, 64, 64, 1, 8),
(2048, 2048, 512, 128, 128): (4, 8, 64, 64, 1, 4),
- (2048, 2048, 1024, 16, 16): (3, 8, 16, 64, 1, 2),
- (2048, 2048, 1024, 32, 32): (4, 16, 32, 64, 1, 2),
+ (2048, 2048, 1024, 16, 16): (1, 8, 16, 64, 1, 2),
+ (2048, 2048, 1024, 32, 32): (2, 16, 32, 64, 3, 2),
(2048, 2048, 1024, 64, 64): (4, 8, 64, 64, 1, 4),
(2048, 2048, 1024, 128, 128): (4, 8, 128, 64, 1, 4),
- (2048, 2048, 2048, 16, 16): (4, 4, 16, 128, 1, 2),
- (2048, 2048, 2048, 32, 32): (2, 16, 32, 64, 1, 2),
+ (2048, 2048, 2048, 16, 16): (5, 4, 16, 128, 1, 2),
+ (2048, 2048, 2048, 32, 32): (1, 16, 32, 64, 3, 2),
(2048, 2048, 2048, 64, 64): (2, 8, 64, 64, 1, 4),
(2048, 2048, 2048, 128, 128): (2, 8, 128, 64, 1, 4),
(2048, 2048, 4096, 16, 16): (4, 2, 16, 128, 1, 2),
- (2048, 2048, 4096, 32, 32): (4, 16, 32, 64, 1, 2),
- (2048, 2048, 4096, 64, 64): (4, 32, 64, 64, 3, 2),
+ (2048, 2048, 4096, 32, 32): (2, 16, 32, 64, 3, 2),
+ (2048, 2048, 4096, 64, 64): (2, 8, 64, 64, 3, 2),
(2048, 2048, 4096, 128, 128): (4, 8, 128, 64, 1, 4),
(2048, 2048, 8192, 16, 16): (5, 4, 16, 128, 1, 2),
- (2048, 2048, 8192, 32, 32): (4, 64, 32, 64, 1, 2),
+ (2048, 2048, 8192, 32, 32): (2, 8, 32, 64, 3, 2),
(2048, 2048, 8192, 64, 64): (4, 8, 64, 64, 3, 2),
(2048, 2048, 8192, 128, 128): (4, 8, 128, 64, 1, 4),
(2048, 2048, 16384, 16, 16): (3, 2, 16, 128, 1, 2),
- (2048, 2048, 16384, 32, 32): (4, 8, 32, 64, 1, 2),
+ (2048, 2048, 16384, 32, 32): (2, 4, 32, 128, 3, 2),
(2048, 2048, 16384, 64, 64): (4, 8, 64, 64, 3, 2),
(2048, 2048, 16384, 128, 128): (4, 4, 128, 64, 1, 4),
- (2048, 2048, 32768, 16, 16): (6, 2, 16, 128, 1, 2),
- (2048, 2048, 32768, 32, 32): (5, 8, 32, 64, 1, 2),
+ (2048, 2048, 32768, 16, 16): (3, 2, 16, 128, 1, 2),
+ (2048, 2048, 32768, 32, 32): (3, 4, 32, 128, 3, 2),
(2048, 2048, 32768, 64, 64): (6, 4, 64, 64, 3, 2),
(2048, 2048, 32768, 128, 128): (3, 4, 128, 64, 1, 4),
- (2048, 2048, 65536, 16, 16): (7, 2, 16, 128, 1, 2),
- (2048, 2048, 65536, 32, 32): (3, 1, 32, 128, 1, 2),
+ (2048, 2048, 65536, 16, 16): (6, 2, 16, 128, 1, 2),
+ (2048, 2048, 65536, 32, 32): (1, 2, 32, 128, 1, 2),
(2048, 2048, 65536, 64, 64): (5, 4, 64, 64, 3, 2),
(2048, 2048, 65536, 128, 128): (5, 1, 128, 64, 2, 4),
(2048, 2048, 131072, 16, 16): (3, 2, 16, 128, 1, 2),
- (2048, 2048, 131072, 32, 32): (4, 2, 32, 128, 1, 4),
+ (2048, 2048, 131072, 32, 32): (2, 1, 32, 128, 3, 2),
(2048, 2048, 131072, 64, 64): (4, 1, 64, 64, 3, 2),
(2048, 2048, 131072, 128, 128): (3, 1, 128, 64, 2, 4),
(4096, 4096, 256, 16, 16): (5, 8, 16, 32, 1, 4),
(4096, 4096, 256, 32, 32): (4, 16, 32, 16, 2, 4),
- (4096, 4096, 256, 64, 64): (4, 8, 64, 32, 1, 4),
+ (4096, 4096, 256, 64, 64): (2, 1, 64, 64, 3, 4),
(4096, 4096, 256, 128, 128): (4, 4, 128, 32, 1, 4),
(4096, 4096, 512, 16, 16): (4, 2, 16, 128, 1, 2),
(4096, 4096, 512, 32, 32): (4, 8, 32, 64, 1, 2),
(4096, 4096, 512, 64, 64): (4, 4, 64, 64, 1, 4),
(4096, 4096, 512, 128, 128): (4, 8, 128, 64, 2, 4),
- (4096, 4096, 1024, 16, 16): (4, 8, 16, 128, 1, 2),
- (4096, 4096, 1024, 32, 32): (4, 8, 32, 64, 1, 2),
- (4096, 4096, 1024, 64, 64): (4, 16, 64, 64, 1, 4),
- (4096, 4096, 1024, 128, 128): (4, 16, 128, 64, 2, 4),
- (4096, 4096, 2048, 16, 16): (5, 8, 16, 128, 1, 2),
- (4096, 4096, 2048, 32, 32): (3, 4, 32, 64, 1, 2),
+ (4096, 4096, 1024, 16, 16): (1, 2, 16, 128, 1, 2),
+ (4096, 4096, 1024, 32, 32): (6, 8, 32, 64, 3, 2),
+ (4096, 4096, 1024, 64, 64): (2, 16, 64, 64, 4, 4),
+ (4096, 4096, 1024, 128, 128): (2, 4, 128, 64, 2, 4),
+ (4096, 4096, 2048, 16, 16): (3, 1, 16, 128, 1, 2),
+ (4096, 4096, 2048, 32, 32): (1, 4, 32, 64, 5, 2),
(4096, 4096, 2048, 64, 64): (3, 16, 64, 64, 3, 2),
(4096, 4096, 2048, 128, 128): (4, 32, 128, 64, 2, 4),
(4096, 4096, 4096, 16, 16): (1, 2, 16, 128, 1, 2),
- (4096, 4096, 4096, 32, 32): (3, 4, 32, 64, 3, 2),
+ (4096, 4096, 4096, 32, 32): (1, 4, 32, 64, 3, 2),
(4096, 4096, 4096, 64, 64): (1, 1, 64, 64, 4, 4),
- (4096, 4096, 4096, 128, 128): (1, 1, 128, 128, 1, 8),
- (4096, 4096, 8192, 16, 16): (5, 8, 16, 128, 1, 2),
- (4096, 4096, 8192, 32, 32): (4, 4, 32, 64, 1, 2),
+ (4096, 4096, 4096, 128, 128): (2, 1, 128, 128, 1, 8),
+ (4096, 4096, 8192, 16, 16): (3, 1, 16, 128, 1, 2),
+ (4096, 4096, 8192, 32, 32): (2, 2, 32, 64, 5, 2),
(4096, 4096, 8192, 64, 64): (4, 16, 64, 64, 3, 2),
(4096, 4096, 8192, 128, 128): (4, 16, 128, 64, 2, 4),
- (4096, 4096, 16384, 16, 16): (4, 8, 16, 128, 1, 2),
- (4096, 4096, 16384, 32, 32): (6, 2, 32, 64, 1, 2),
+ (4096, 4096, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+ (4096, 4096, 16384, 32, 32): (4, 2, 32, 64, 5, 2),
(4096, 4096, 16384, 64, 64): (4, 16, 64, 64, 3, 2),
(4096, 4096, 16384, 128, 128): (4, 16, 128, 64, 2, 4),
- (4096, 4096, 32768, 16, 16): (2, 8, 16, 128, 1, 2),
+ (4096, 4096, 32768, 16, 16): (3, 1, 16, 128, 1, 2),
(4096, 4096, 32768, 32, 32): (3, 1, 32, 128, 1, 4),
- (4096, 4096, 32768, 64, 64): (5, 8, 64, 64, 3, 2),
+ (4096, 4096, 32768, 64, 64): (3, 1, 64, 64, 3, 4),
(4096, 4096, 32768, 128, 128): (5, 16, 128, 64, 2, 4),
- (4096, 4096, 65536, 16, 16): (6, 8, 16, 128, 1, 2),
+ (4096, 4096, 65536, 16, 16): (5, 1, 16, 128, 1, 2),
(4096, 4096, 65536, 32, 32): (5, 1, 32, 128, 1, 4),
- (4096, 4096, 65536, 64, 64): (3, 8, 64, 64, 3, 2),
+ (4096, 4096, 65536, 64, 64): (1, 1, 64, 64, 3, 4),
(4096, 4096, 65536, 128, 128): (3, 16, 128, 64, 2, 4),
- (4096, 4096, 131072, 16, 16): (5, 8, 16, 128, 1, 2),
- (4096, 4096, 131072, 32, 32): (5, 4, 32, 64, 1, 2),
- (4096, 4096, 131072, 64, 64): (5, 8, 64, 64, 3, 2),
- (4096, 4096, 131072, 128, 128): (4, 16, 128, 64, 2, 4),
+ (4096, 4096, 131072, 16, 16): (5, 1, 16, 128, 1, 2),
+ (4096, 4096, 131072, 32, 32): (3, 1, 32, 128, 3, 2),
+ (4096, 4096, 131072, 64, 64): (2, 1, 64, 64, 3, 4),
+ (4096, 4096, 131072, 128, 128): (1, 1, 128, 64, 1, 4),
(8192, 8192, 256, 16, 16): (4, 16, 16, 16, 1, 4),
- (8192, 8192, 256, 32, 32): (4, 16, 32, 16, 4, 4),
+ (8192, 8192, 256, 32, 32): (1, 16, 32, 16, 4, 4),
(8192, 8192, 256, 64, 64): (4, 16, 64, 16, 3, 8),
(8192, 8192, 256, 128, 128): (4, 16, 128, 16, 1, 2),
- (8192, 8192, 512, 16, 16): (5, 8, 16, 64, 1, 4),
- (8192, 8192, 512, 32, 32): (4, 4, 32, 64, 1, 2),
- (8192, 8192, 512, 64, 64): (4, 4, 64, 64, 1, 4),
+ (8192, 8192, 512, 16, 16): (2, 8, 16, 64, 1, 4),
+ (8192, 8192, 512, 32, 32): (4, 8, 32, 64, 3, 2),
+ (8192, 8192, 512, 64, 64): (2, 8, 64, 64, 4, 4),
(8192, 8192, 512, 128, 128): (4, 8, 128, 64, 2, 4),
(8192, 8192, 1024, 16, 16): (4, 16, 16, 64, 1, 8),
- (8192, 8192, 1024, 32, 32): (4, 4, 32, 64, 1, 2),
- (8192, 8192, 1024, 64, 64): (4, 16, 64, 64, 3, 2),
- (8192, 8192, 1024, 128, 128): (4, 16, 128, 64, 2, 4),
- (8192, 8192, 2048, 16, 16): (5, 2, 16, 128, 1, 2),
- (8192, 8192, 2048, 32, 32): (4, 16, 32, 64, 1, 2),
+ (8192, 8192, 1024, 32, 32): (2, 8, 32, 64, 5, 2),
+ (8192, 8192, 1024, 64, 64): (1, 16, 64, 64, 3, 2),
+ (8192, 8192, 1024, 128, 128): (5, 16, 128, 64, 2, 4),
+ (8192, 8192, 2048, 16, 16): (7, 2, 16, 128, 1, 2),
+ (8192, 8192, 2048, 32, 32): (1, 16, 32, 64, 5, 2),
(8192, 8192, 2048, 64, 64): (4, 16, 64, 64, 3, 2),
(8192, 8192, 2048, 128, 128): (6, 16, 128, 64, 2, 4),
(8192, 8192, 4096, 16, 16): (4, 2, 16, 128, 1, 2),
- (8192, 8192, 4096, 32, 32): (4, 4, 32, 64, 1, 2),
+ (8192, 8192, 4096, 32, 32): (2, 8, 32, 64, 5, 2),
(8192, 8192, 4096, 64, 64): (3, 16, 64, 64, 3, 2),
(8192, 8192, 4096, 128, 128): (3, 64, 128, 64, 2, 4),
- (8192, 8192, 8192, 16, 16): (3, 2, 16, 128, 1, 2),
- (8192, 8192, 8192, 32, 32): (2, 4, 32, 128, 1, 4),
+ (8192, 8192, 8192, 16, 16): (4, 2, 16, 128, 1, 2),
+ (8192, 8192, 8192, 32, 32): (1, 4, 32, 128, 5, 4),
(8192, 8192, 8192, 64, 64): (4, 4, 64, 64, 1, 4),
(8192, 8192, 8192, 128, 128): (2, 2, 128, 128, 3, 8),
- (8192, 8192, 16384, 16, 16): (4, 8, 16, 128, 1, 2),
- (8192, 8192, 16384, 32, 32): (3, 4, 32, 64, 1, 2),
+ (8192, 8192, 16384, 16, 16): (1, 2, 16, 128, 1, 2),
+ (8192, 8192, 16384, 32, 32): (4, 8, 32, 64, 5, 2),
(8192, 8192, 16384, 64, 64): (5, 8, 64, 64, 3, 2),
(8192, 8192, 16384, 128, 128): (3, 16, 128, 64, 2, 4),
- (8192, 8192, 32768, 16, 16): (3, 2, 16, 128, 1, 2),
- (8192, 8192, 32768, 32, 32): (4, 4, 32, 64, 1, 2),
+ (8192, 8192, 32768, 16, 16): (7, 2, 16, 128, 1, 2),
+ (8192, 8192, 32768, 32, 32): (3, 4, 32, 64, 3, 2),
(8192, 8192, 32768, 64, 64): (2, 8, 64, 64, 3, 2),
(8192, 8192, 32768, 128, 128): (6, 16, 128, 64, 2, 4),
(8192, 8192, 65536, 16, 16): (9, 2, 16, 128, 1, 2),
- (8192, 8192, 65536, 32, 32): (6, 4, 32, 64, 1, 2),
+ (8192, 8192, 65536, 32, 32): (7, 4, 32, 64, 5, 2),
(8192, 8192, 65536, 64, 64): (4, 8, 64, 64, 3, 2),
(8192, 8192, 65536, 128, 128): (3, 16, 128, 64, 2, 4),
- (8192, 8192, 131072, 16, 16): (7, 2, 16, 128, 1, 2),
- (8192, 8192, 131072, 32, 32): (3, 8, 32, 64, 1, 2),
+ (8192, 8192, 131072, 16, 16): (9, 2, 16, 128, 1, 2),
+ (8192, 8192, 131072, 32, 32): (1, 8, 32, 64, 5, 2),
(8192, 8192, 131072, 64, 64): (1, 8, 64, 64, 3, 2),
(8192, 8192, 131072, 128, 128): (4, 16, 128, 64, 2, 4),
(16384, 16384, 256, 16, 16): (5, 16, 16, 16, 1, 4),
(16384, 16384, 256, 32, 32): (4, 16, 32, 16, 4, 4),
(16384, 16384, 256, 64, 64): (4, 16, 64, 16, 3, 8),
(16384, 16384, 256, 128, 128): (4, 16, 128, 16, 1, 2),
- (16384, 16384, 512, 16, 16): (4, 4, 16, 64, 1, 4),
- (16384, 16384, 512, 32, 32): (4, 4, 32, 64, 1, 2),
+ (16384, 16384, 512, 16, 16): (2, 8, 16, 64, 1, 4),
+ (16384, 16384, 512, 32, 32): (1, 4, 32, 64, 5, 2),
(16384, 16384, 512, 64, 64): (4, 8, 64, 64, 1, 4),
(16384, 16384, 512, 128, 128): (3, 8, 128, 64, 2, 4),
(16384, 16384, 1024, 16, 16): (4, 2, 16, 128, 1, 2),
- (16384, 16384, 1024, 32, 32): (4, 8, 32, 64, 1, 2),
+ (16384, 16384, 1024, 32, 32): (4, 8, 32, 64, 5, 2),
(16384, 16384, 1024, 64, 64): (6, 16, 64, 64, 3, 2),
(16384, 16384, 1024, 128, 128): (3, 16, 128, 64, 2, 4),
(16384, 16384, 2048, 16, 16): (3, 2, 16, 128, 1, 2),
- (16384, 16384, 2048, 32, 32): (5, 8, 32, 64, 1, 2),
+ (16384, 16384, 2048, 32, 32): (1, 8, 32, 64, 5, 2),
(16384, 16384, 2048, 64, 64): (5, 16, 64, 64, 3, 2),
- (16384, 16384, 2048, 128, 128): (3, 32, 128, 64, 2, 4),
- (16384, 16384, 4096, 16, 16): (3, 2, 16, 128, 1, 2),
- (16384, 16384, 4096, 32, 32): (5, 4, 32, 64, 1, 2),
- (16384, 16384, 4096, 64, 64): (4, 16, 64, 64, 3, 2),
+ (16384, 16384, 2048, 128, 128): (2, 32, 128, 64, 2, 4),
+ (16384, 16384, 4096, 16, 16): (2, 2, 16, 128, 1, 2),
+ (16384, 16384, 4096, 32, 32): (1, 4, 32, 64, 3, 2),
+ (16384, 16384, 4096, 64, 64): (2, 8, 64, 64, 3, 2),
(16384, 16384, 4096, 128, 128): (3, 16, 128, 64, 2, 4),
- (16384, 16384, 8192, 16, 16): (4, 2, 16, 128, 1, 2),
- (16384, 16384, 8192, 32, 32): (4, 4, 32, 64, 1, 2),
+ (16384, 16384, 8192, 16, 16): (3, 2, 16, 128, 1, 2),
+ (16384, 16384, 8192, 32, 32): (2, 4, 32, 64, 5, 2),
(16384, 16384, 8192, 64, 64): (4, 8, 64, 64, 3, 2),
- (16384, 16384, 8192, 128, 128): (6, 32, 128, 64, 2, 4),
+ (16384, 16384, 8192, 128, 128): (8, 32, 128, 64, 2, 4),
(16384, 16384, 16384, 16, 16): (1, 2, 16, 256, 1, 4),
- (16384, 16384, 16384, 32, 32): (2, 4, 32, 128, 1, 4),
+ (16384, 16384, 16384, 32, 32): (1, 4, 32, 128, 3, 4),
(16384, 16384, 16384, 64, 64): (5, 4, 64, 64, 1, 4),
(16384, 16384, 16384, 128, 128): (4, 8, 128, 64, 2, 4),
(16384, 16384, 32768, 16, 16): (2, 2, 16, 128, 1, 2),
- (16384, 16384, 32768, 32, 32): (2, 4, 32, 64, 1, 2),
+ (16384, 16384, 32768, 32, 32): (1, 4, 32, 64, 3, 2),
(16384, 16384, 32768, 64, 64): (5, 4, 64, 64, 1, 4),
(16384, 16384, 32768, 128, 128): (5, 8, 128, 64, 2, 4),
- (16384, 16384, 65536, 16, 16): (5, 2, 16, 128, 1, 2),
- (16384, 16384, 65536, 32, 32): (4, 2, 32, 64, 1, 2),
+ (16384, 16384, 65536, 16, 16): (8, 2, 16, 128, 1, 2),
+ (16384, 16384, 65536, 32, 32): (6, 4, 32, 64, 5, 2),
(16384, 16384, 65536, 64, 64): (2, 4, 64, 64, 1, 4),
(16384, 16384, 65536, 128, 128): (4, 8, 128, 64, 2, 4),
- (16384, 16384, 131072, 16, 16): (3, 2, 16, 128, 1, 2),
- (16384, 16384, 131072, 32, 32): (3, 4, 32, 64, 1, 2),
+ (16384, 16384, 131072, 16, 16): (3, 1, 16, 128, 1, 2),
+ (16384, 16384, 131072, 32, 32): (1, 4, 32, 64, 3, 2),
(16384, 16384, 131072, 64, 64): (4, 4, 64, 64, 1, 4),
(16384, 16384, 131072, 128, 128): (1, 8, 128, 64, 2, 4),
(32768, 32768, 256, 16, 16): (4, 16, 16, 16, 1, 4),
@@ -1622,9 +2571,292 @@
(32768, 32768, 16384, 16, 16): (4, 4, 16, 64, 1, 1),
(32768, 32768, 32768, 16, 16): (5, 4, 16, 64, 1, 1),
},
+ ("scatter_mm", "NVIDIA A100-SXM4-80GB", (0, torch.float32, 0.5)): {
+ (256, 256, 256, 16, 16): (1, 1, 16, 16, 1, 8),
+ (256, 256, 256, 32, 32): (1, 1, 16, 16, 1, 4),
+ (256, 256, 256, 64, 64): (1, 1, 16, 16, 1, 4),
+ (256, 256, 256, 128, 128): (1, 1, 16, 16, 1, 1),
+ (256, 256, 512, 16, 16): (1, 1, 16, 16, 1, 4),
+ (256, 256, 512, 32, 32): (1, 16, 16, 16, 1, 1),
+ (256, 256, 512, 64, 64): (1, 1, 16, 16, 1, 1),
+ (256, 256, 512, 128, 128): (1, 1, 32, 32, 1, 4),
+ (256, 256, 1024, 16, 16): (1, 1, 16, 32, 1, 2),
+ (256, 256, 1024, 32, 32): (1, 4, 16, 16, 1, 1),
+ (256, 256, 1024, 64, 64): (1, 1, 32, 32, 1, 4),
+ (256, 256, 1024, 128, 128): (1, 1, 32, 32, 1, 4),
+ (256, 256, 2048, 16, 16): (1, 2, 16, 32, 1, 2),
+ (256, 256, 2048, 32, 32): (1, 1, 16, 32, 1, 2),
+ (256, 256, 2048, 64, 64): (2, 1, 16, 32, 1, 2),
+ (256, 256, 2048, 128, 128): (1, 1, 16, 16, 1, 1),
+ (256, 256, 4096, 16, 16): (1, 1, 16, 32, 1, 2),
+ (256, 256, 4096, 32, 32): (1, 1, 16, 32, 1, 2),
+ (256, 256, 4096, 64, 64): (1, 1, 32, 32, 1, 4),
+ (256, 256, 4096, 128, 128): (3, 1, 32, 64, 1, 4),
+ (256, 256, 8192, 16, 16): (1, 32, 16, 64, 1, 2),
+ (256, 256, 8192, 32, 32): (1, 1, 32, 64, 1, 4),
+ (256, 256, 8192, 64, 64): (1, 1, 32, 64, 1, 4),
+ (256, 256, 8192, 128, 128): (2, 1, 64, 32, 1, 4),
+ (256, 256, 16384, 16, 16): (1, 1, 16, 64, 1, 2),
+ (256, 256, 16384, 32, 32): (1, 1, 32, 64, 1, 4),
+ (256, 256, 16384, 64, 64): (1, 128, 64, 64, 1, 4),
+ (256, 256, 16384, 128, 128): (2, 1, 64, 32, 1, 4),
+ (256, 256, 32768, 16, 16): (2, 128, 16, 64, 1, 1),
+ (256, 256, 32768, 32, 32): (1, 1, 32, 64, 1, 4),
+ (256, 256, 32768, 64, 64): (1, 128, 64, 64, 1, 4),
+ (256, 256, 32768, 128, 128): (2, 1, 64, 64, 1, 4),
+ (256, 256, 65536, 16, 16): (1, 1, 16, 64, 1, 2),
+ (256, 256, 65536, 32, 32): (1, 1, 32, 64, 1, 4),
+ (256, 256, 65536, 64, 64): (2, 1, 64, 64, 1, 4),
+ (256, 256, 65536, 128, 128): (1, 1, 128, 32, 1, 4),
+ (256, 256, 131072, 16, 16): (3, 128, 16, 64, 1, 1),
+ (256, 256, 131072, 32, 32): (1, 1, 32, 64, 1, 4),
+ (256, 256, 131072, 64, 64): (2, 1, 64, 64, 1, 4),
+ (256, 256, 131072, 128, 128): (1, 8192, 64, 16, 1, 4),
+ (512, 512, 256, 16, 16): (1, 2, 16, 16, 1, 1),
+ (512, 512, 256, 32, 32): (1, 4, 16, 16, 1, 1),
+ (512, 512, 256, 64, 64): (1, 16, 16, 16, 1, 1),
+ (512, 512, 256, 128, 128): (1, 1, 16, 32, 1, 4),
+ (512, 512, 512, 16, 16): (1, 8, 16, 32, 1, 2),
+ (512, 512, 512, 32, 32): (1, 8, 16, 32, 1, 2),
+ (512, 512, 512, 64, 64): (1, 2, 16, 32, 1, 2),
+ (512, 512, 512, 128, 128): (1, 1, 32, 32, 1, 4),
+ (512, 512, 1024, 16, 16): (1, 1, 16, 32, 1, 2),
+ (512, 512, 1024, 32, 32): (1, 1, 16, 32, 1, 2),
+ (512, 512, 1024, 64, 64): (1, 1, 16, 32, 1, 2),
+ (512, 512, 1024, 128, 128): (1, 1, 64, 32, 1, 4),
+ (512, 512, 2048, 16, 16): (1, 16, 16, 64, 1, 2),
+ (512, 512, 2048, 32, 32): (1, 1, 32, 32, 1, 4),
+ (512, 512, 2048, 64, 64): (1, 1, 32, 32, 1, 4),
+ (512, 512, 2048, 128, 128): (2, 1, 32, 32, 1, 4),
+ (512, 512, 4096, 16, 16): (2, 64, 16, 64, 1, 1),
+ (512, 512, 4096, 32, 32): (1, 64, 32, 64, 1, 4),
+ (512, 512, 4096, 64, 64): (1, 1, 32, 32, 1, 4),
+ (512, 512, 4096, 128, 128): (1, 1, 64, 32, 1, 4),
+ (512, 512, 8192, 16, 16): (2, 64, 16, 64, 1, 1),
+ (512, 512, 8192, 32, 32): (1, 256, 32, 32, 1, 1),
+ (512, 512, 8192, 64, 64): (1, 64, 64, 64, 1, 4),
+ (512, 512, 8192, 128, 128): (2, 1, 64, 32, 1, 8),
+ (512, 512, 16384, 16, 16): (2, 64, 16, 64, 1, 1),
+ (512, 512, 16384, 32, 32): (1, 128, 32, 32, 1, 1),
+ (512, 512, 16384, 64, 64): (1, 64, 64, 64, 1, 4),
+ (512, 512, 16384, 128, 128): (3, 1, 64, 32, 1, 8),
+ (512, 512, 32768, 16, 16): (2, 64, 16, 64, 1, 1),
+ (512, 512, 32768, 32, 32): (1, 128, 32, 32, 1, 1),
+ (512, 512, 32768, 64, 64): (1, 64, 64, 64, 1, 4),
+ (512, 512, 32768, 128, 128): (2, 1, 64, 32, 1, 8),
+ (512, 512, 65536, 16, 16): (2, 32, 16, 64, 1, 1),
+ (512, 512, 65536, 32, 32): (1, 128, 32, 32, 1, 1),
+ (512, 512, 65536, 64, 64): (1, 64, 64, 64, 1, 4),
+ (512, 512, 65536, 128, 128): (2, 1, 64, 32, 1, 8),
+ (512, 512, 131072, 16, 16): (2, 32, 16, 64, 1, 1),
+ (512, 512, 131072, 32, 32): (1, 128, 32, 32, 1, 1),
+ (512, 512, 131072, 64, 64): (3, 64, 64, 64, 1, 4),
+ (512, 512, 131072, 128, 128): (1, 8192, 64, 16, 1, 4),
+ (1024, 1024, 256, 16, 16): (1, 4, 16, 32, 1, 2),
+ (1024, 1024, 256, 32, 32): (1, 4, 16, 32, 1, 2),
+ (1024, 1024, 256, 64, 64): (1, 1, 16, 32, 1, 2),
+ (1024, 1024, 256, 128, 128): (1, 1, 16, 16, 1, 1),
+ (1024, 1024, 512, 16, 16): (1, 8, 16, 32, 1, 2),
+ (1024, 1024, 512, 32, 32): (1, 8, 16, 32, 1, 1),
+ (1024, 1024, 512, 64, 64): (1, 8, 32, 32, 1, 4),
+ (1024, 1024, 512, 128, 128): (2, 1, 32, 32, 1, 4),
+ (1024, 1024, 1024, 16, 16): (1, 16, 16, 32, 1, 2),
+ (1024, 1024, 1024, 32, 32): (1, 16, 32, 64, 1, 4),
+ (1024, 1024, 1024, 64, 64): (1, 16, 32, 64, 1, 4),
+ (1024, 1024, 1024, 128, 128): (1, 1, 32, 32, 1, 4),
+ (1024, 1024, 2048, 16, 16): (2, 32, 16, 64, 1, 1),
+ (1024, 1024, 2048, 32, 32): (1, 32, 32, 64, 1, 4),
+ (1024, 1024, 2048, 64, 64): (1, 32, 64, 64, 1, 4),
+ (1024, 1024, 2048, 128, 128): (1, 1, 32, 64, 1, 4),
+ (1024, 1024, 4096, 16, 16): (2, 16, 16, 64, 1, 1),
+ (1024, 1024, 4096, 32, 32): (1, 64, 32, 32, 1, 1),
+ (1024, 1024, 4096, 64, 64): (1, 64, 64, 64, 1, 4),
+ (1024, 1024, 4096, 128, 128): (2, 64, 64, 32, 1, 8),
+ (1024, 1024, 8192, 16, 16): (2, 16, 16, 64, 1, 1),
+ (1024, 1024, 8192, 32, 32): (1, 64, 32, 32, 1, 1),
+ (1024, 1024, 8192, 64, 64): (1, 64, 64, 64, 1, 4),
+ (1024, 1024, 8192, 128, 128): (4, 1, 32, 64, 1, 4),
+ (1024, 1024, 16384, 16, 16): (2, 16, 16, 64, 1, 1),
+ (1024, 1024, 16384, 32, 32): (1, 64, 32, 32, 1, 1),
+ (1024, 1024, 16384, 64, 64): (1, 32, 64, 64, 1, 4),
+ (1024, 1024, 16384, 128, 128): (2, 64, 64, 32, 1, 4),
+ (1024, 1024, 32768, 16, 16): (2, 16, 16, 64, 1, 1),
+ (1024, 1024, 32768, 32, 32): (1, 64, 32, 32, 1, 1),
+ (1024, 1024, 32768, 64, 64): (1, 32, 64, 64, 1, 4),
+ (1024, 1024, 32768, 128, 128): (4, 1, 32, 64, 1, 4),
+ (1024, 1024, 65536, 16, 16): (2, 16, 16, 64, 1, 1),
+ (1024, 1024, 65536, 32, 32): (1, 32, 32, 32, 1, 1),
+ (1024, 1024, 65536, 64, 64): (2, 32, 64, 64, 1, 4),
+ (1024, 1024, 65536, 128, 128): (4, 1, 64, 32, 1, 4),
+ (1024, 1024, 131072, 16, 16): (2, 16, 16, 64, 1, 1),
+ (1024, 1024, 131072, 32, 32): (1, 32, 32, 32, 1, 1),
+ (1024, 1024, 131072, 64, 64): (1, 16, 64, 64, 1, 4),
+ (1024, 1024, 131072, 128, 128): (1, 8192, 64, 16, 1, 4),
+ (2048, 2048, 256, 16, 16): (1, 4, 16, 32, 1, 2),
+ (2048, 2048, 256, 32, 32): (1, 8, 16, 32, 1, 1),
+ (2048, 2048, 256, 64, 64): (1, 8, 32, 32, 1, 4),
+ (2048, 2048, 256, 128, 128): (1, 4, 64, 64, 1, 8),
+ (2048, 2048, 512, 16, 16): (2, 8, 16, 32, 1, 2),
+ (2048, 2048, 512, 32, 32): (2, 8, 32, 64, 1, 4),
+ (2048, 2048, 512, 64, 64): (2, 4, 64, 64, 1, 4),
+ (2048, 2048, 512, 128, 128): (1, 8, 32, 64, 1, 4),
+ (2048, 2048, 1024, 16, 16): (2, 16, 16, 64, 3, 1),
+ (2048, 2048, 1024, 32, 32): (1, 32, 32, 32, 1, 1),
+ (2048, 2048, 1024, 64, 64): (1, 16, 64, 64, 1, 4),
+ (2048, 2048, 1024, 128, 128): (2, 4, 64, 64, 1, 8),
+ (2048, 2048, 2048, 16, 16): (2, 16, 16, 64, 1, 1),
+ (2048, 2048, 2048, 32, 32): (1, 32, 32, 32, 1, 1),
+ (2048, 2048, 2048, 64, 64): (1, 16, 64, 64, 1, 4),
+ (2048, 2048, 2048, 128, 128): (2, 32, 32, 64, 1, 4),
+ (2048, 2048, 4096, 16, 16): (3, 2, 16, 64, 1, 1),
+ (2048, 2048, 4096, 32, 32): (3, 4, 32, 32, 1, 1),
+ (2048, 2048, 4096, 64, 64): (1, 16, 64, 64, 1, 4),
+ (2048, 2048, 4096, 128, 128): (2, 32, 64, 32, 1, 4),
+ (2048, 2048, 8192, 16, 16): (3, 4, 16, 64, 1, 1),
+ (2048, 2048, 8192, 32, 32): (2, 4, 32, 32, 1, 1),
+ (2048, 2048, 8192, 64, 64): (2, 32, 64, 32, 1, 2),
+ (2048, 2048, 8192, 128, 128): (4, 1, 32, 64, 1, 4),
+ (2048, 2048, 16384, 16, 16): (3, 4, 16, 64, 1, 1),
+ (2048, 2048, 16384, 32, 32): (1, 4, 32, 32, 1, 1),
+ (2048, 2048, 16384, 64, 64): (2, 8, 64, 32, 1, 2),
+ (2048, 2048, 16384, 128, 128): (2, 8, 64, 32, 1, 4),
+ (2048, 2048, 32768, 16, 16): (2, 4, 16, 64, 1, 1),
+ (2048, 2048, 32768, 32, 32): (2, 8, 32, 32, 1, 1),
+ (2048, 2048, 32768, 64, 64): (1, 16, 64, 32, 1, 2),
+ (2048, 2048, 32768, 128, 128): (4, 1, 32, 64, 1, 4),
+ (2048, 2048, 65536, 16, 16): (3, 4, 16, 64, 1, 1),
+ (2048, 2048, 65536, 32, 32): (1, 8, 32, 32, 1, 1),
+ (2048, 2048, 65536, 64, 64): (1, 8, 64, 32, 1, 2),
+ (2048, 2048, 65536, 128, 128): (4, 1, 64, 32, 1, 4),
+ (2048, 2048, 131072, 16, 16): (2, 4, 16, 64, 1, 1),
+ (2048, 2048, 131072, 32, 32): (1, 8, 32, 32, 1, 1),
+ (2048, 2048, 131072, 64, 64): (3, 1, 64, 32, 1, 2),
+ (2048, 2048, 131072, 128, 128): (1, 8192, 128, 16, 1, 8),
+ (4096, 4096, 256, 16, 16): (2, 4, 16, 32, 1, 2),
+ (4096, 4096, 256, 32, 32): (1, 4, 32, 64, 1, 4),
+ (4096, 4096, 256, 64, 64): (1, 4, 64, 64, 1, 4),
+ (4096, 4096, 256, 128, 128): (1, 4, 32, 64, 1, 4),
+ (4096, 4096, 512, 16, 16): (2, 8, 16, 64, 3, 1),
+ (4096, 4096, 512, 32, 32): (2, 16, 32, 32, 1, 1),
+ (4096, 4096, 512, 64, 64): (1, 8, 64, 64, 1, 4),
+ (4096, 4096, 512, 128, 128): (1, 8, 32, 64, 1, 4),
+ (4096, 4096, 1024, 16, 16): (1, 8, 16, 64, 3, 1),
+ (4096, 4096, 1024, 32, 32): (1, 16, 32, 32, 1, 1),
+ (4096, 4096, 1024, 64, 64): (1, 16, 64, 32, 1, 2),
+ (4096, 4096, 1024, 128, 128): (1, 16, 32, 64, 1, 4),
+ (4096, 4096, 2048, 16, 16): (1, 16, 16, 64, 3, 1),
+ (4096, 4096, 2048, 32, 32): (1, 16, 32, 32, 1, 1),
+ (4096, 4096, 2048, 64, 64): (3, 16, 64, 32, 1, 2),
+ (4096, 4096, 2048, 128, 128): (4, 8, 32, 64, 1, 4),
+ (4096, 4096, 4096, 16, 16): (1, 8, 16, 64, 3, 1),
+ (4096, 4096, 4096, 32, 32): (1, 1, 32, 32, 1, 1),
+ (4096, 4096, 4096, 64, 64): (2, 16, 64, 32, 1, 2),
+ (4096, 4096, 4096, 128, 128): (4, 8, 32, 64, 1, 4),
+ (4096, 4096, 8192, 16, 16): (1, 8, 16, 64, 3, 1),
+ (4096, 4096, 8192, 32, 32): (2, 1, 32, 32, 1, 1),
+ (4096, 4096, 8192, 64, 64): (1, 16, 64, 32, 1, 2),
+ (4096, 4096, 8192, 128, 128): (2, 1, 32, 64, 1, 4),
+ (4096, 4096, 16384, 16, 16): (1, 8, 16, 64, 3, 1),
+ (4096, 4096, 16384, 32, 32): (1, 1, 32, 32, 1, 1),
+ (4096, 4096, 16384, 64, 64): (2, 8, 64, 32, 1, 2),
+ (4096, 4096, 16384, 128, 128): (2, 1, 32, 64, 1, 4),
+ (4096, 4096, 32768, 16, 16): (1, 8, 16, 64, 3, 1),
+ (4096, 4096, 32768, 32, 32): (1, 1, 32, 32, 1, 1),
+ (4096, 4096, 32768, 64, 64): (1, 8, 64, 32, 1, 2),
+ (4096, 4096, 32768, 128, 128): (2, 1, 32, 64, 1, 4),
+ (4096, 4096, 65536, 16, 16): (1, 8, 16, 64, 3, 1),
+ (4096, 4096, 65536, 32, 32): (3, 1, 32, 32, 1, 1),
+ (4096, 4096, 65536, 64, 64): (3, 4, 64, 32, 1, 2),
+ (4096, 4096, 65536, 128, 128): (2, 1, 32, 64, 1, 4),
+ (4096, 4096, 131072, 16, 16): (1, 8, 16, 64, 3, 1),
+ (4096, 4096, 131072, 32, 32): (1, 1, 32, 32, 1, 1),
+ (4096, 4096, 131072, 64, 64): (2, 8, 64, 32, 1, 2),
+ (4096, 4096, 131072, 128, 128): (1, 8192, 128, 16, 1, 8),
+ (8192, 8192, 256, 16, 16): (2, 4, 16, 64, 3, 1),
+ (8192, 8192, 256, 32, 32): (1, 8, 32, 32, 1, 1),
+ (8192, 8192, 256, 64, 64): (1, 4, 64, 64, 1, 4),
+ (8192, 8192, 256, 128, 128): (1, 4, 32, 64, 1, 4),
+ (8192, 8192, 512, 16, 16): (1, 4, 16, 64, 3, 1),
+ (8192, 8192, 512, 32, 32): (1, 16, 32, 32, 1, 1),
+ (8192, 8192, 512, 64, 64): (2, 4, 64, 64, 1, 4),
+ (8192, 8192, 512, 128, 128): (2, 1, 32, 64, 1, 4),
+ (8192, 8192, 1024, 16, 16): (3, 8, 16, 64, 3, 1),
+ (8192, 8192, 1024, 32, 32): (1, 16, 32, 32, 1, 1),
+ (8192, 8192, 1024, 64, 64): (1, 8, 64, 32, 1, 2),
+ (8192, 8192, 1024, 128, 128): (2, 4, 32, 64, 1, 4),
+ (8192, 8192, 2048, 16, 16): (1, 8, 16, 64, 3, 1),
+ (8192, 8192, 2048, 32, 32): (1, 16, 32, 32, 1, 1),
+ (8192, 8192, 2048, 64, 64): (2, 8, 64, 32, 1, 2),
+ (8192, 8192, 2048, 128, 128): (4, 1, 32, 64, 1, 4),
+ (8192, 8192, 4096, 16, 16): (1, 8, 16, 64, 3, 1),
+ (8192, 8192, 4096, 32, 32): (1, 16, 32, 32, 1, 1),
+ (8192, 8192, 4096, 64, 64): (1, 4, 64, 32, 1, 2),
+ (8192, 8192, 4096, 128, 128): (3, 1, 32, 64, 1, 4),
+ (8192, 8192, 8192, 16, 16): (1, 8, 16, 64, 3, 1),
+ (8192, 8192, 8192, 32, 32): (1, 8, 32, 32, 1, 1),
+ (8192, 8192, 8192, 64, 64): (1, 8, 64, 32, 1, 2),
+ (8192, 8192, 8192, 128, 128): (4, 1, 32, 64, 1, 4),
+ (8192, 8192, 16384, 16, 16): (3, 4, 16, 64, 3, 1),
+ (8192, 8192, 16384, 32, 32): (1, 8, 32, 32, 1, 1),
+ (8192, 8192, 16384, 64, 64): (2, 2, 64, 32, 1, 2),
+ (8192, 8192, 16384, 128, 128): (7, 1, 32, 64, 1, 4),
+ (8192, 8192, 32768, 16, 16): (1, 4, 16, 64, 3, 1),
+ (8192, 8192, 32768, 32, 32): (1, 8, 32, 32, 1, 1),
+ (8192, 8192, 32768, 64, 64): (3, 2, 64, 32, 1, 2),
+ (8192, 8192, 32768, 128, 128): (6, 1, 32, 64, 1, 4),
+ (8192, 8192, 65536, 16, 16): (1, 4, 16, 64, 3, 1),
+ (8192, 8192, 65536, 32, 32): (4, 8, 32, 32, 1, 1),
+ (8192, 8192, 65536, 64, 64): (1, 2, 64, 32, 1, 2),
+ (8192, 8192, 65536, 128, 128): (4, 1, 32, 64, 1, 4),
+ (8192, 8192, 131072, 16, 16): (1, 4, 16, 64, 3, 1),
+ (8192, 8192, 131072, 32, 32): (1, 8, 32, 32, 1, 1),
+ (8192, 8192, 131072, 64, 64): (5, 4, 64, 32, 1, 2),
+ (8192, 8192, 131072, 128, 128): (1, 4096, 128, 16, 1, 8),
+ (16384, 16384, 256, 16, 16): (1, 4, 16, 64, 3, 1),
+ (16384, 16384, 256, 32, 32): (1, 8, 32, 32, 1, 1),
+ (16384, 16384, 256, 64, 64): (1, 4, 64, 32, 1, 2),
+ (16384, 16384, 256, 128, 128): (1, 4, 32, 64, 1, 4),
+ (16384, 16384, 512, 16, 16): (1, 8, 16, 64, 3, 1),
+ (16384, 16384, 512, 32, 32): (1, 16, 32, 32, 1, 1),
+ (16384, 16384, 512, 64, 64): (1, 4, 64, 32, 1, 2),
+ (16384, 16384, 512, 128, 128): (3, 1, 32, 64, 1, 4),
+ (16384, 16384, 1024, 16, 16): (1, 8, 16, 64, 3, 1),
+ (16384, 16384, 1024, 32, 32): (1, 16, 32, 32, 1, 1),
+ (16384, 16384, 1024, 64, 64): (2, 4, 64, 32, 1, 2),
+ (16384, 16384, 1024, 128, 128): (1, 2, 32, 64, 1, 4),
+ (16384, 16384, 2048, 16, 16): (1, 4, 16, 64, 3, 1),
+ (16384, 16384, 2048, 32, 32): (1, 16, 32, 32, 1, 1),
+ (16384, 16384, 2048, 64, 64): (3, 4, 64, 32, 1, 2),
+ (16384, 16384, 2048, 128, 128): (2, 1, 32, 64, 1, 4),
+ (16384, 16384, 4096, 16, 16): (4, 8, 16, 64, 3, 1),
+ (16384, 16384, 4096, 32, 32): (5, 16, 32, 32, 1, 1),
+ (16384, 16384, 4096, 64, 64): (3, 2, 64, 32, 1, 2),
+ (16384, 16384, 4096, 128, 128): (2, 1, 32, 64, 1, 4),
+ (16384, 16384, 8192, 16, 16): (1, 4, 16, 64, 3, 1),
+ (16384, 16384, 8192, 32, 32): (1, 4, 32, 32, 1, 1),
+ (16384, 16384, 8192, 64, 64): (1, 2, 64, 32, 1, 2),
+ (16384, 16384, 8192, 128, 128): (2, 1, 32, 64, 1, 4),
+ (16384, 16384, 16384, 16, 16): (1, 8, 16, 64, 3, 1),
+ (16384, 16384, 16384, 32, 32): (1, 4, 32, 32, 1, 1),
+ (16384, 16384, 16384, 64, 64): (1, 2, 64, 32, 1, 2),
+ (16384, 16384, 16384, 128, 128): (3, 1, 32, 64, 1, 4),
+ (16384, 16384, 32768, 16, 16): (1, 4, 16, 64, 3, 1),
+ (16384, 16384, 32768, 32, 32): (1, 2, 32, 32, 1, 1),
+ (16384, 16384, 32768, 64, 64): (3, 2, 64, 32, 1, 2),
+ (16384, 16384, 32768, 128, 128): (3, 1, 32, 64, 1, 4),
+ (16384, 16384, 65536, 16, 16): (1, 8, 16, 64, 3, 1),
+ (16384, 16384, 65536, 32, 32): (1, 4, 32, 32, 1, 1),
+ (16384, 16384, 65536, 64, 64): (4, 4, 64, 32, 1, 2),
+ (16384, 16384, 65536, 128, 128): (5, 1, 32, 64, 1, 4),
+ (16384, 16384, 131072, 16, 16): (1, 2, 16, 64, 3, 1),
+ (16384, 16384, 131072, 32, 32): (1, 4, 32, 32, 1, 1),
+ (16384, 16384, 131072, 64, 64): (1, 2, 64, 32, 1, 2),
+ (16384, 16384, 131072, 128, 128): (1, 4096, 128, 16, 1, 8),
+ },
# END GENERATED DATA
}
if __name__ == "__main__":
- for op in ["scatter_mm", "bsr_dense_mm"]:
- main(op=op, force=False)
+ for dtype in [torch.float16, torch.bfloat16, torch.float32]:
+ for op in ["scatter_mm", "bsr_dense_mm"]:
+ main(op=op, force=False, dtype=dtype)