| import os |
| import sys |
| |
| import torch |
| |
# Emit extra debug printouts.
debug = False

# Insert checks for inf and NaN values.
debug_check_inf_and_nan = False

# When True, the progress bar for autotuning is disabled.
disable_progress = True

# Print the source code for each future.
verbose_progress = False

# Use the FX AOT graph codegen cache (opt in with TORCHINDUCTOR_FX_GRAPH_CACHE=1).
fx_graph_cache = os.getenv("TORCHINDUCTOR_FX_GRAPH_CACHE") == "1"

# Generate a C++ wrapper instead of the Python wrapper.
cpp_wrapper = False
| |
# Run dead code elimination.
dce = False

# Assume weight tensors have a fixed size.
static_weight_shapes = True

# Emit correctness assertions in the generated code (on unless
# TORCHINDUCTOR_SIZE_ASSERTS is set to something other than "1").
size_asserts = os.getenv("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"

# Reorder loops based on the order of the inputs.
pick_loop_orders = True

# Allow a kernel input to be reused as the output.
inplace_buffers = True

# Allow a buffer to be reused for an unrelated purpose.
allow_buffer_reuse = True

# Include a benchmark harness in the generated code.
benchmark_harness = True

# Fuse pointwise ops into templates.
epilogue_fusion = True

# Perform epilogue fusions ahead of the other fusions.
epilogue_fusion_first = False

# Enable pattern match + replace optimizations.
pattern_matcher = True

# Hooks for registering custom graph optimization passes. At present the
# pre/post hooks only run before/after pattern_matcher in post_grad_passes.
# Example:
#
#   def my_custom_pre_pass(graph: torch.fx.graph.Graph):
#       ...  # custom graph optimization
#
#   def my_custom_post_pass(graph: torch.fx.graph.Graph):
#       ...  # custom graph optimization
#
#   torch._inductor.config.post_grad_custom_pre_pass = my_custom_pre_pass
#   torch._inductor.config.post_grad_custom_post_pass = my_custom_post_pass
post_grad_custom_pre_pass = None
post_grad_custom_post_pass = None
| |
# Optimize away split/cat patterns (experimental).
split_cat_fx_passes = True

# Optimize conv-batchnorm when the batchnorm is in eval mode. This slightly
# reduces numerical stability.
efficient_conv_bn_eval_fx_passes = False

# Pattern matching with group fusion (uses fbgemm).
group_fusion = False

# Pattern matching with batch fusion (uses torch ops).
batch_fusion = True

# Run the reordering pass.
reordering = True

# Scale RBLOCK down for better occupancy (on unless
# TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK is set to something other than "1").
dynamic_scale_rblock = os.getenv("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1"

# For torch.mm(a, b.to(dtype)) on CUDA tensors, enable the fused
# torch._inductor.kernel.mm.tuned_mixed_mm kernel. Autotuning compares its
# performance against the plain cast-then-mm option.
use_mixed_mm = False

# For torch.mm(a, b.to(dtype)) on CUDA tensors, always use the fused
# torch._inductor.kernel.mm.tuned_mixed_mm kernel, without an autotuning
# comparison against cast-then-mm. When this is True, use_mixed_mm is ignored.
force_mixed_mm = False

# TODO: record whether the graph came from export.
from_export = False
| |
# Turn on the slow autotuning passes that select algorithms.
max_autotune = os.getenv("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"

# Slow autotuning limited to pointwise/reduction algorithm selection.
max_autotune_pointwise = os.getenv("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"

# Slow autotuning limited to gemm algorithm selection.
max_autotune_gemm = os.getenv("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"

# Candidate backends for gemm autotuning, as a combination of:
#   ATen    - the default PyTorch ATen kernels
#   Triton  - Triton templates defined in torch inductor
#   CUTLASS - CUTLASS templates and kernels
# The environment value is upper-cased; default "ATEN,TRITON".
max_autotune_gemm_backends = os.getenv(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON"
).upper()

# Search the global and local autotuning caches even without max_autotune.
search_autotune_cache = os.getenv("TORCHINDUCTOR_SEARCH_AUTOTUNE_CACHE") == "1"

save_args = os.getenv("TORCHINDUCTOR_SAVE_ARGS") == "1"

# Autotune in a subprocess; when this is False no subprocess is created.
autotune_in_subproc = os.getenv("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"

# When autotuning in a subprocess, whether to use multiple devices.
autotune_multi_device = os.getenv("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"

# Coordinate descent tuning knobs (all read from the environment).
coordinate_descent_tuning = os.getenv("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
coordinate_descent_check_all_directions = (
    os.getenv("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1"
)
coordinate_descent_search_radius = int(
    os.getenv("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1")
)
| |
# Layout optimization (on unless TORCHINDUCTOR_LAYOUT_OPTIMIZATION is set to
# something other than "1").
layout_optimization = os.getenv("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", "1") == "1"

# Keep output strides identical to eager after layout optimization.
keep_output_stride = os.getenv("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"

# Warn when a generated Triton kernel has inputs with mixed layouts. Helpful for
# perf debugging, since such a kernel may run much slower than one whose inputs
# share a uniform layout.
warn_mix_layout = os.getenv("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"

# Thresholds for the store-vs-recompute heuristic. Rematerializing fanouts can
# blow up exponentially, so the thresholds are kept small.
realize_reads_threshold = 4
realize_bytes_threshold = 2000

# Limits excessive accumulation of ops in one buffer during lowering.
realize_acc_reads_threshold = 8
| |
# Fall back to eager for random/dropout: slow, but useful for debugging.
fallback_random = False

# Create fallbacks automatically when an unhandled op is encountered.
implicit_fallbacks = True

# Fuse even when there are no common reads.
aggressive_fusion = False

# Annotate each fused kernel in the wrapper with the nodes fused into it;
# useful for debugging fusion.
debug_fusion = os.getenv("TORCHINDUCTOR_DEBUG_FUSION") == "1"

# Maximum number of nodes allowed in a single fusion.
max_fusion_size = 64

# Threshold for replacing small reductions with pointwise ops; set to 1 to disable.
unroll_reductions_threshold = 8

# Add extra comments to the output code (this causes compile cache misses).
comment_origin = False

# Turn 1x1 convolutions into matmuls.
conv_1x1_as_mm = False

# Split reductions over a large dimension for better utilization.
split_reductions = True

benchmark_kernel = os.getenv("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"

# Fold constants and index_expr.
constant_and_index_propagation = True

# Always add constants into graph.constants, skipping any constant-inlining
# optimization.
always_keep_tensor_constants = False
| |
| |
def is_fbcode():
    """Return True when running on an fbcode build of PyTorch.

    Such a build is recognized by ``torch.version`` having no ``git_version``
    attribute.
    """
    has_git_version = hasattr(torch.version, "git_version")
    return not has_git_version
| |
| |
# Run constant folding on the joint graph.
joint_graph_constant_folding = True

# Add indirect_indexing asserts for decompositions and lowerings.
debug_index_asserts = False

# Warnings aimed at PyTorch developers: on for fbcode builds and for nightly or
# source builds ("dev" or "git" in torch.__version__); disable for point releases.
is_nightly_or_source = any(tag in torch.__version__ for tag in ("dev", "git"))
developer_warnings = is_fbcode() or is_nightly_or_source

# Multiprocessing start method for inductor workers in the codecache.
# TODO: fork is unsafe in a multithreaded environment; evaluate switching the
# default to spawn.
worker_start_method = "fork"
| |
| |
def decide_compile_threads():
    """
    Decide how many threads to use for compilation.

    Precedence:
    1. TORCHINDUCTOR_COMPILE_THREADS, when set, overrides everything. Setting it
       to 1 disables async compiling, which keeps pdb happy.
    2. 1 on the win32 platform or in an fbcode build.
    3. Otherwise, the number of CPU cores available to this process, capped at
       32. If that number cannot be determined, 1 is used.

    Returns:
        int: the number of compile threads.

    Raises:
        ValueError: if TORCHINDUCTOR_COMPILE_THREADS is not an integer string.
    """
    if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
        return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
    elif sys.platform == "win32" or is_fbcode():
        return 1
    else:
        # Prefer the CPU affinity set (cores this process may actually use)
        # where the platform provides it.
        cpu_count = (
            len(os.sched_getaffinity(0))
            if hasattr(os, "sched_getaffinity")
            else os.cpu_count()
        )
        # os.cpu_count() returns None when the count is undeterminable. The old
        # `assert cpu_count` aborted importing this config module in that case,
        # or was stripped under `python -O` and then crashed in min(); fall
        # back to a single compile thread instead.
        return min(32, cpu_count or 1)
| |
| |
compile_threads = decide_compile_threads()

# Global cache directory for gemm autotuning; only resolved in fbcode builds.
global_cache_dir = None
if is_fbcode():
    from libfb.py import parutil  # type: ignore[import]

    try:
        # Build the path from the package name when __package__ is set,
        # otherwise use the bare "fb/cache" path.
        global_cache_dir = parutil.get_dir_path(
            os.path.join(__package__.replace(".", os.sep), "fb/cache")
            if __package__
            else "fb/cache"
        )
    except ValueError:
        # get_dir_path raised ValueError: leave the cache directory unset.
        global_cache_dir = None
| |
# A fused kernel's name is generated from its origin node op names; this caps
# how many op names are used for larger kernels.
kernel_name_max_ops = 10

# Pad matmul/bmm/addmm input tensors to leverage Tensor Cores in NVIDIA GPUs.
shape_padding = os.getenv("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"

# FX-based vertical fusion of linear/matmul/bmm with permute/transpose.
permute_fusion = os.getenv("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"

# Mark the wrapper call in the PyTorch profiler.
profiler_mark_wrapper_call = False

# Generate calls to torch._inductor.hooks.run_intermediate_hooks for every
# intermediate that can be correlated with one from the original FX graph.
generate_intermediate_hooks = False

# Populate the traceback field on IRNode. Good for debugging why origin_node is
# not populated, or for finding where an IRNode was constructed.
debug_ir_traceback = False

# Testing aid for checking that the config is set properly.
_raise_error_for_testing = False

# TORCHINDUCTOR_PROFILE: empty/unset disables bandwidth profiling; "1" enables
# it with an empty regex; any other value enables it and is used as the regex.
_profile_var = os.getenv("TORCHINDUCTOR_PROFILE", "")
profile_bandwidth = bool(_profile_var)
profile_bandwidth_regex = _profile_var if _profile_var != "1" else ""

# TODO: remove later
disable_cpp_codegen = False


# Freezing tries to inline weights as constants during optimization and runs
# constant folding and other optimizations on them. Once frozen, weights can no
# longer be updated.
freezing: bool = os.getenv("TORCHINDUCTOR_FREEZING", "0") == "1"

# Make freezing invalidate the eager Parameters of nn modules, avoiding the
# memory overhead of potentially keeping multiple copies of the weights.
freezing_discard_parameters: bool = False
| |
| |
# Configuration for the C++ backend (codegen/cpp.py).
class cpp:
    # Number of threads; -1 here, noted as "set to torch.get_num_threads()".
    threads = -1

    # Skip generating loops whose condition can never hold, e.g.
    #     for(long i0=4096; i0<4096; i0+=1)
    no_redundant_loops = True

    # Treat the number of threads as dynamic instead of specializing on it, so
    # kernels are not recompiled when the thread count changes. Turning this on
    # costs a slight performance degradation for single-threaded workloads.
    dynamic_threads = False

    simdlen = None
    min_chunk_size = 4096

    # Candidate C++ compilers, in tuple order: None (download gcc12 from
    # conda-forge if conda is installed), then $CXX, defaulting to "g++".
    # Alternatives previously listed here: "g++-12", "g++-11", "g++-10",
    # "clang++", "g++.par".
    cxx = (None, os.getenv("CXX", "g++"))

    # Allow kernel performance profiling via the PyTorch profiler.
    enable_kernel_profile = False

    # Prepack weights for better performance; may lead to a large memory footprint.
    weight_prepack = True

    # Inject a bug into our relu implementation; useful for testing repro
    # extraction and minification.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY = None
    # log1p counterpart of the flag above (presumably the same values - confirm).
    inject_log1p_bug_TESTING_ONLY = None

    # None autodetects whether AVX512/AVX2 can be used; any other value forces
    # usage as specified, without testing.
    vec_isa_ok = None

    # Same meaning as config.triton.descriptive_names.
    descriptive_names = "original_aten"

    # Maximum number of nodes allowed in a single horizontal fusion.
    max_horizontal_fusion_size = 16

    # Fall back for scatter_reduce when reduce is sum, avoiding the performance
    # regression from using atomic_add.
    fallback_scatter_reduce_sum = True
| |
| |
# Configuration for the Triton backend (codegen/triton.py).
class triton:
    # Use CUDA graphs on the output code.
    cudagraphs = False

    # Use cudagraph trees for memory pooling when `cudagraphs` is True.
    cudagraph_trees = True

    # Assertions off the fast path (steady state).
    slow_path_cudagraph_asserts = True

    # TODO - need to debug why this prevents cleanup
    cudagraph_trees_history_recording = False

    # Assertions on the fast path.
    fast_path_cudagraph_asserts = False

    # Skip warmup for cudagraph trees.
    skip_cudagraph_warmup = False

    # Synchronize before and after every compiled graph.
    debug_sync_graph = False

    # Synchronize after every kernel launch, to help pinpoint bugs.
    debug_sync_kernel = False

    # Always load full blocks rather than broadcasting inside the block.
    dense_indexing = False

    # Limit on tiling dimensions.
    max_tiles = 2

    # Use triton.autotune for pointwise ops with complex layouts; only disable
    # this for debugging/testing.
    autotune_pointwise = True

    # Max-autotune gemms with cublasLt.
    autotune_cublasLt = True

    # Whether to stop a fusion in order to allow better tiling.
    tiling_prevents_pointwise_fusion = True
    tiling_prevents_reduction_fusion = True

    # Assert that indirect indexing does not read/write out of bounds.
    assert_indirect_indexing = True

    # Give kernels unique names (opt in with TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1).
    # This is orthogonal to descriptive_names: it decides whether all Triton
    # kernel names are just `triton_` (to maximize caching) or unique.
    unique_kernel_names = os.getenv("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES") == "1"

    # Whether, and how, op names appear in kernel names:
    #   False           - no special names (just triton__1, triton__2, etc.)
    #   "torch"         - the fx op in the Dynamo graph (module name, method name, etc.)
    #   "original_aten" - the highest-level aten op (i.e. pre-decompositions)
    #   "inductor_node" - the node name in the FX graph passed to Inductor
    descriptive_names = "original_aten"

    # Alternate codegen for smaller reductions (on unless
    # TORCHINDUCTOR_PERSISTENT_REDUCTIONS is set to something other than "1").
    persistent_reductions = os.getenv("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"

    # Hint to Triton when arguments are divisible by 16.
    divisible_by_16 = True

    # These limits are not enforced, but asserts in triton_heuristics.py use them.
    # NOTE: mobilevit_s in timm_models required X to be raised to 2048.
    max_block = dict(X=2048, Y=1024, Z=1024)

    # Store the generated cubin files so cpp wrapper code can load them.
    store_cubin = False

    # Maximum number of register spills allowed for a config we benchmark.
    # 0 skips any config that spills even a single register; a larger value
    # lets configs that spill a small number of registers be benchmarked.
    #
    # NOTE: Triton always reports >0 register spills for kernels using sin/cos
    # (see https://github.com/openai/triton/issues/1756 ). So far a fixed 8
    # spilled registers has been seen for such kernels, so the threshold is
    # raised to 16 to be safe. Revisit once the source of register spills is
    # better understood.
    spill_threshold: int = 16

    # Inject a bug into our relu implementation; useful for testing repro
    # extraction and minification.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY = None
| |
| |
class aot_inductor:
    # AOTInductor output path:
    #   absolute path -> the generated lib files are stored under that directory
    #   relative path -> used as a subdirectory under the default caching path
    #   not specified -> a temp directory is created under the default caching path
    output_path = ""

    # Whether to codegen an ABI-compatible model.so (defaults to is_fbcode()).
    abi_compatible = is_fbcode()
| |
| |
class cuda:
    # CUDA arch for compiling CUDA template kernels, e.g. "70", "75", "80", "90".
    # When None, Inductor uses torch.cuda.get_device_capability(0).
    arch = None

    # CUDA version for compiling CUDA template kernels, e.g. "11.4", "12.1".
    # When None, Inductor uses torch.version.cuda.
    version = None

    # Optimization level for the host compiler.
    compile_opt_level = "-O1"

    # Enable device LTO (link-time optimization).
    enable_cuda_lto = False

    # Described as: whether to keep intermediate files during compilation.
    # NOTE(review): the flag name suggests it enables ptxas info output instead;
    # confirm which description is accurate.
    enable_ptxas_info = False

    # Enable debug info, e.g. line numbers and CUTLASS debug info.
    enable_debug_info = False

    # Use fast math.
    use_fast_math = False

    # Path to the CUTLASS repo root directory (override with
    # TORCHINDUCTOR_CUTLASS_DIR). The default only works in a local PyTorch
    # development environment.
    cutlass_dir = os.getenv(
        "TORCHINDUCTOR_CUTLASS_DIR",
        os.path.abspath(
            os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
        ),
    )

    # Maximum number of CUTLASS configs to profile in max_autotune. None (the
    # default) tunes all CUTLASS configs; mainly used to reduce CI test time.
    cutlass_max_profiling_configs = None

    # Path to CUDA NVCC. Search order:
    #   1) cuda_cxx set in this config
    #   2) the CUDACXX environment variable
    #   3) the CUDA_HOME environment variable
    #   4) the default system search PATH
    cuda_cxx = None
| |
| |
# Settings for creating a directory that holds lots of debug information.
class trace:
    # Master switch for all debugging flags below (TORCH_COMPILE_DEBUG=1).
    enabled = os.getenv("TORCH_COMPILE_DEBUG", "0") == "1"

    # Save python logger calls >= logging.DEBUG.
    debug_log = False

    # Save python logger calls >= logging.INFO.
    info_log = False

    # Save the input FX graph (post decomps, pre optimization).
    fx_graph = True

    # Save the FX graph after transformations.
    fx_graph_transformed = True

    # Save TorchInductor IR before the fusion pass.
    ir_pre_fusion = True

    # Save TorchInductor IR after the fusion pass.
    ir_post_fusion = True

    # Copy the generated code to the trace dir.
    output_code = True

    # SVG figure showing the post-fusion graph.
    graph_diagram = os.getenv("INDUCTOR_POST_FUSION_SVG", "0") == "1"

    # SVG figure showing the fx graph with fusion.
    draw_orig_fx_graph = os.getenv("INDUCTOR_ORIG_FX_SVG", "0") == "1"

    # Store a cProfile (view it with snakeviz).
    compile_profile = False

    # Upload the .tar.gz file. Needs to be overridden to suit the specific
    # environment.
    upload_tar = None
| |
| |
# Config keys to ignore when saving the config.
# Workaround for: "Can't pickle <function ...>".
_save_config_ignore = {"trace.upload_tar"}
| |
| |
# NOTE(review): kept as the last statements of the module, presumably because
# install_config_module needs every config value above to already be defined;
# confirm before moving anything below this point.
from .._dynamo.config_utils import install_config_module

# Install config-module support on this module: adds patch, save_config, etc.
install_config_module(sys.modules[__name__])