import os
import sys
from functools import lru_cache
# add some debug printouts
debug = False
# Whether to disable a progress bar for autotuning
disable_progress = True
# Whether to enable printing the source code for each future
verbose_progress = False
# use cpp wrapper instead of python wrapper
cpp_wrapper = False
# dead code elimination
dce = False
# treat input tensor shapes as dynamic when torchdynamo dynamic shapes is enabled
dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"
# assume weight tensors are fixed size
static_weight_shapes = True
# put correctness assertions in generated code
size_asserts = True
# enable loop reordering based on input orders
pick_loop_orders = True
# generate inplace computations
inplace_buffers = True
# codegen benchmark harness
benchmark_harness = True
# control the store vs. recompute heuristic
# For fanouts, rematerialization can lead to exponential blowup, so keep the
# threshold small
realize_reads_threshold = 4
realize_bytes_threshold = 2000
# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold = 8
# fall back to eager for random/dropout; slow but useful for debugging
fallback_random = False
# automatically create fallbacks when encountering an unhandled op
implicit_fallbacks = True
# Enables a fusion pass that groups nodes together before the scheduler
prefuse_nodes = True
# benchmark to decide the best layout, currently only for aten.conv
tune_layout = False
# fuse even in cases without common reads
aggressive_fusion = False
# how many nodes to allow into a single fusion
max_fusion_size = 64
# replace small reductions with pointwise ops; set to 1 to disable
unroll_reductions_threshold = 8
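# annotate generated code with comments naming the origin ops (inferred from the option name)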
comment_origin = False
@lru_cache(1)
def is_fbcode():
try:
import torch.fb # noqa: F401
except ImportError:
return False
return True
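# number of parallel compile workers: 1 on Windows or in fbcode, otherwise min(32, available CPUs)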
compile_threads = (
1
if sys.platform == "win32" or is_fbcode()
else min(
32,
len(os.sched_getaffinity(0))
if hasattr(os, "sched_getaffinity")
else os.cpu_count(),
)
)
# If a kernel is fused, its name is generated from the origin node op names;
# for larger kernels, limit how many op names are used
kernel_name_max_ops = 10
# How to import torchinductor, either torchinductor or torch.inductor
inductor_import = __name__.replace(".config", "")
# How to import torchdynamo, either torchdynamo or torch.dynamo
dynamo_import = inductor_import.replace("inductor", "dynamo")
# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "0") == "1"
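# element alignment used when padding matmul/bmm/addmm inputs (see shape_padding above)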
alignment_size = 4
# Fx-based linear/matmul/bmm + permute/transpose vertical fusion
permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"
# Mark the wrapper call in PyTorch profiler
profiler_mark_wrapper_call = False
# config specific to codegen/cpp.py
class cpp:
    # number of threads used in generated cpp kernels; -1 means use torch.get_num_threads()
    threads = -1
    # Assume the number of threads is dynamic and don't specialize on the thread count.
    # With this flag on, kernels don't recompile on thread-count changes.
    # For single-threaded workloads, turning it on incurs a slight
    # performance degradation.
dynamic_threads = False
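    # SIMD vector width hint for cpp loop vectorization; None lets the backend choose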
simdlen = None
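    # minimum amount of work per thread before a cpp loop is parallelized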
min_chunk_size = 4096
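    # candidate C++ compilers, tried in order until one is available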
cxx = (
None, # download gcc12 from conda-forge if conda is installed
"g++-12",
"g++-11",
"g++-10",
"clang++",
"g++",
"g++.par",
)
# Allow kernel performance profiling via PyTorch profiler
enable_kernel_profile = False
# config specific to codegen/triton.py
class triton:
# Use cudagraphs on output code
cudagraphs = True
# choose conv backend, "aten" or "triton" or "autotune"
convolution = "aten"
# choose mm backend, "aten" or "triton" or "autotune"
mm = "aten"
    # Always load full blocks (rather than broadcasting inside the block).
    # Setting this to True works around a `map::at` error in triton when
    # loading from a 1-dim tensor using a 2-dim pointer offset
    # (https://triton-lang.slack.com/archives/C01L1FLTX70/p1656023403343639);
    # it can be left False once triton fixes the bug.
    dense_indexing = False
# limit tiling dimensions
max_tiles = 2
# use triton.autotune?
autotune = True
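    # use a triton kernel for batched matmul (bmm)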
use_bmm = False
# should we stop a fusion to allow better tiling?
tiling_prevents_pointwise_fusion = True
tiling_prevents_reduction_fusion = True
# should we give different names to kernels
ordered_kernel_names = False
# should we put op names in kernel names
descriptive_kernel_names = True
# create a directory containing lots of debug information
class trace:
# master switch for all debugging flags below
enabled = os.environ.get("TORCHINDUCTOR_TRACE", "0") == "1"
    # Save python logger calls >= logging.DEBUG
    debug_log = True
    # Save python logger calls >= logging.INFO
    info_log = False
# Save input FX graph (post decomps)
fx_graph = True
# Save TorchInductor IR before fusion pass
ir_pre_fusion = True
# Save TorchInductor IR after fusion pass
ir_post_fusion = True
# Copy generated code to trace dir
output_code = True
# SVG figure showing post-fusion graph
graph_diagram = False
# Store cProfile (see snakeviz to view)
compile_profile = False
    # Upload the .tar.gz file
    # Needs to be overridden based on specific environment needs
upload_tar = None
class InductorConfigContext:
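    """
    Snapshot of a subset of the inductor config.

    Can be constructed from a mode name ("default", "reduce-overhead",
    "max-autotune") or a dict of per-option overrides; applies its values on
    __enter__ and restores the previous values on __exit__.
    """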
static_memory: bool
matmul_tune: str
matmul_padding: bool
triton_autotune: bool
triton_bmm: bool
triton_mm: str
triton_convolution: str
rematerialize_threshold: int
rematerialize_acc_threshold: int
def _save(self):
self.static_memory = triton.cudagraphs
self.matmul_tune = triton.mm
self.matmul_padding = shape_padding
self.triton_autotune = triton.autotune
self.triton_bmm = triton.use_bmm
self.triton_mm = triton.mm
self.triton_convolution = triton.convolution
self.rematerialize_threshold = realize_reads_threshold
self.rematerialize_acc_threshold = realize_acc_reads_threshold
    def _apply(self):
        # shape_padding and the realize_* thresholds are module-level globals and
        # must be declared global here, otherwise the assignments below would only
        # create locals and the config would never change
        global shape_padding, realize_reads_threshold, realize_acc_reads_threshold
        triton.cudagraphs = self.static_memory
        triton.mm = self.matmul_tune
        shape_padding = self.matmul_padding
        triton.autotune = self.triton_autotune
        triton.use_bmm = self.triton_bmm
        triton.mm = self.triton_mm  # takes precedence over matmul_tune above
        triton.convolution = self.triton_convolution
        realize_reads_threshold = self.rematerialize_threshold
        realize_acc_reads_threshold = self.rematerialize_acc_threshold
def __init__(self, arg=None):
self._save()
if arg is None:
return
# Handle mode
        if isinstance(arg, str):
            def default():
                self.static_memory = False
            def reduce_overhead():
                self.static_memory = True
            def max_autotune():
                self.static_memory = False
                self.matmul_padding = True
                self.triton_convolution = "autotune"
                self.triton_mm = "autotune"
modes = {
x.__name__.replace("_", "-"): x
for x in [default, reduce_overhead, max_autotune]
}
if arg not in modes:
raise RuntimeError(
f"Unrecognized mode {arg}, should be one of {', '.join(modes.keys())}"
)
modes[arg]()
return
# Handle passes
for (name, val) in arg.items():
attr_name = name.replace("-", "_")
if not hasattr(self, attr_name):
known_passes = ", ".join(
[x.replace("_", "-") for x in dir(self) if not x.startswith("_")]
)
raise RuntimeError(
f"Unexpected optimization pass {name}, known passes are {known_passes}"
)
if type(val) != type(getattr(self, attr_name)):
val_type_str = type(val).__name__
expected_type_str = type(getattr(self, attr_name)).__name__
raise RuntimeError(
f"Unexpected type of attr {name}, got {val_type_str} should be {expected_type_str}"
)
setattr(self, attr_name, val)
    def __enter__(self):
        self._prev = InductorConfigContext()
        self._apply()
        return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._prev._apply()
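# Example usage (a minimal sketch; the exact import path depends on how this
# package is installed, see inductor_import above):
#
#   from torchinductor.config import InductorConfigContext
#
#   # apply a named mode, restoring the previous config on exit
#   with InductorConfigContext("reduce-overhead"):
#       ...  # run compiled code with triton.cudagraphs enabled
#
#   # or override individual options by name
#   with InductorConfigContext({"matmul-padding": True}):
#       ...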