| import functools |
| import inspect |
| import itertools |
| import types |
| from typing import Dict, List |
| |
| import torch |
| |
| from .. import variables |
| from ..bytecode_transformation import create_call_function, create_rot_n |
| from ..exc import unimplemented, Unsupported |
| from ..source import ( |
| AttrSource, |
| ConstantSource, |
| DefaultsSource, |
| GetItemSource, |
| GlobalSource, |
| ) |
| from ..utils import make_cell |
| from .base import typestr, VariableTracker |
| |
| |
def wrap_bound_arg(tx, val, options, source=None):
    """Wrap a bound-argument value as a VariableTracker.

    Values that already are VariableTrackers are returned untouched. Otherwise a
    ``VariableBuilder`` is used when ``source`` is truthy and a
    ``SourcelessBuilder`` when it is not. Source propagation is best effort,
    since not every object we encounter has a source to begin with.

    ``options`` is forwarded through ``add_options`` and must not contain a
    ``"source"`` key (the assertion message explains it is kept separate due to
    recursive calls for lists/dicts).
    """
    assert (
        "source" not in options
    ), "Source needs to be separate from options due to recursive calls for lists/dicts"

    if isinstance(val, VariableTracker):
        # Already symbolic: nothing to wrap.
        return val

    if source:
        from torch._dynamo.variables.builder import VariableBuilder

        builder = VariableBuilder(tx, source=source)
        return builder(val).add_options(options)

    from torch._dynamo.variables.builder import SourcelessBuilder

    wrapped = SourcelessBuilder()(tx, val)
    return wrapped.add_options(options)
| |
| |
def wrap_args_kwargs(tx, result, options):
    """Wrap every tuple/dict value in ``result`` in place via ``wrap_bound_arg``.

    Per the original intent these are the packed ``*args`` tuple and
    ``**kwargs`` dict produced by signature binding. Other values are left as-is.
    """
    # Snapshot the candidates before rewriting ``result`` in place.
    packed_entries = [
        (key, value)
        for key, value in result.items()
        if isinstance(value, (tuple, dict))
    ]
    for key, packed in packed_entries:
        result[key] = wrap_bound_arg(tx, packed, options)
| |
| |
def init_cellvars(parent, result, code):
    """Allocate a new tracked cell for each of ``code.co_cellvars``.

    If a bound argument in ``result`` has the same name as a cell variable, it is
    popped out of ``result`` and stored into that cell instead.

    Returns:
        dict mapping each cell-variable name to its tracked cell.
    """
    side_effects = parent.output.side_effects
    cells = {}

    for cellvar in code.co_cellvars:
        new_cell = side_effects.track_cell_new()
        cells[cellvar] = new_cell
        if cellvar in result:
            # The argument now lives in the cell rather than as a plain local.
            side_effects.store_cell(new_cell, result.pop(cellvar))

    return cells
| |
| |
| def _create_nested_fn( |
| code, f_globals, name, defaults, closure, kwdefaults, annotations |
| ): |
| from types import FunctionType |
| |
| func = FunctionType(code, f_globals, name, defaults, closure) |
| func.__kwdefaults__ = kwdefaults |
| |
| if isinstance(annotations, tuple): |
| from itertools import pairwise |
| |
| annotations = dict(pairwise(annotations)) |
| |
| # TypeError: __annotations__ must be set to a dict object |
| assert annotations is None or isinstance(annotations, dict) |
| func.__annotations__ = annotations |
| |
| return func |
| |
| |
class BaseUserFunctionVariable(VariableTracker):
    """Shared behavior for VariableTrackers that wrap user Python functions.

    Subclasses supply ``get_code``, ``get_function`` and ``self_args``; calling
    the variable inlines the function via ``tx.inline_user_function_return``.
    """

    def get_filename(self):
        code = self.get_code()
        return code.co_filename

    def get_name(self):
        code = self.get_code()
        return code.co_name

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Implicit leading arguments (e.g. a bound ``self``) come first.
        full_args = [*self.self_args(), *args]
        return tx.inline_user_function_return(self, full_args, kwargs)

    def num_parameters(self):
        signature = inspect.signature(self.get_function())
        return len(signature.parameters)

    def closure_vars(self, tx):
        return {}
| |
| |
class UserFunctionVariable(BaseUserFunctionVariable):
    """A user-defined Python function (or ``torch.jit.ScriptFunction``).

    Calling it inlines the function (via ``BaseUserFunctionVariable``), unless
    it is marked as a constant, in which case it is executed eagerly and its
    result registered via ``invoke_and_store_as_constant``.
    """

    def __init__(self, fn, is_constant=False, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): the ``is_constant`` parameter is never read; constness is
        # derived solely from ``fn._dynamo_marked_constant`` below. Confirm
        # whether callers expect the argument to take effect.
        if getattr(fn, "_dynamo_marked_constant", False):
            # This method should be treated as a constant for the purposes of compilation
            self.is_constant = True
        else:
            self.is_constant = False

        assert isinstance(
            fn, (types.FunctionType, torch.jit.ScriptFunction)
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def self_args(self):
        # Plain functions have no implicit leading arguments.
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        """Bind call arguments to ``self.fn``'s parameters for inlining.

        Defaults and keyword-only defaults are wrapped as VariableTrackers
        (with ``DefaultsSource`` sources when this variable has a source), then
        ``inspect.signature(...).bind`` + ``apply_defaults`` produce the
        name->value mapping. Free variables are resolved from the real closure.

        Returns:
            (result, closure_cells): ``result`` maps local/freevar names to
            VariableTrackers; ``closure_cells`` maps cell-variable names to
            tracked cells allocated by ``init_cellvars``.
        """
        assert not self.is_constant
        options = VariableTracker.propagate([self])
        tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=tx, options=options)

        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        # A stand-in function whose defaults are already VariableTrackers, so that
        # the signature binding below yields symbolic values for omitted args.
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                [
                    wrap(val=arg, source=source)
                    for arg, source in zip(defaults, defaults_sources)
                ]
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            kwdefaults_sources = {
                k: None
                if self.source is None
                else DefaultsSource(self.source, k, is_kw=True)
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(tx, result, options)
        closure_cells = init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        for idx, name, cell in zip(
            itertools.count(), self.fn.__code__.co_freevars, closure
        ):
            if name == "__class__":
                # The implicit __class__ cell (used by zero-arg super()).
                source = AttrSource(self.source, "__class__") if self.source else None
                result[name] = variables.UserDefinedClassVariable(
                    cell.cell_contents,
                    source=source,
                )
            else:
                var = tx.match_nested_cell(name, cell)
                if var is not None:
                    # optimization for cleaner codegen
                    result[name] = var
                elif self.source:
                    from .builder import VariableBuilder

                    side_effects = parent.output.side_effects
                    if cell in side_effects:
                        # Cell already tracked: reuse the existing tracker.
                        out = side_effects[cell]
                    else:
                        closure_cell = GetItemSource(
                            AttrSource(self.source, "__closure__"), idx
                        )
                        closure_cell_contents = AttrSource(
                            closure_cell, "cell_contents"
                        )
                        contents_var = VariableBuilder(parent, closure_cell_contents)(
                            cell.cell_contents
                        )

                        if (
                            closure_cell_contents.name()
                            not in tx.mutated_closure_cell_contents
                        ):
                            # Optimistically don't allocate the cell, to
                            # reduce the number of side effects. This is
                            # important for cond, as without it, any accesses
                            # to closures create side effects and cond doesn't
                            # support side effects. If we're wrong and this
                            # closure cell gets written to, we will restart
                            # the analysis with this cell's name in the
                            # mutated list here
                            result[name] = contents_var
                            continue

                        # cells are written to with "cell_contents",
                        # so the source should just be the closure_cell, not its contents
                        out = side_effects.track_cell_existing(closure_cell, cell)
                        side_effects.store_cell(
                            out,
                            contents_var,
                        )

                    result[name] = out

                else:
                    # No source to guard on: build the contents without one.
                    from .builder import SourcelessBuilder

                    result[name] = SourcelessBuilder()(
                        tx, cell.cell_contents
                    ).add_options(options)

        return result, closure_cells

    def export_freevars(self, parent, child):
        # Nothing to export for a real (non-nested) function.
        pass

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if self.is_constant:
            # Marked-constant functions run eagerly instead of being inlined.
            options = VariableTracker.propagate(self, args, kwargs.values())
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), options, args, kwargs
            )

        return super().call_function(tx, args, kwargs)
| |
| |
class UserMethodVariable(UserFunctionVariable):
    """A user-defined function bound to an object (``obj`` is passed as self)."""

    def __init__(self, fn, obj, **kwargs):
        super().__init__(fn=fn, **kwargs)
        self.obj = obj

    def __str__(self):
        return f"{self.__class__.__name__}({self.fn}, {self.obj})"

    def self_args(self):
        # The bound object becomes the first positional argument.
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Methods on nn.Module objects that come from torch.nn (or are marked
        # constant) are routed through NNModuleVariable.call_method, which can
        # e.g. put a `call_method` op for `forward` in the FX graph instead of
        # inlining. This also covers user modules that inherit torch.nn methods.
        #
        # Only the root tracer takes that route. When tracing a higher-order
        # op, Dynamo must step inside the module call so it can see and lift
        # the underlying parameters/buffers, so non-root tracers just inline.
        on_nn_module = tx.output.is_root_tracer() and isinstance(
            self.obj, variables.NNModuleVariable
        )
        if not on_nn_module:
            return super().call_function(tx, args, kwargs)

        fn_module = getattr(self.fn, "__module__", "")
        defined_in_torch_nn = fn_module is not None and fn_module.startswith(
            "torch.nn."
        )
        if not (defined_in_torch_nn or self.is_constant):
            return super().call_function(tx, args, kwargs)

        method_result = self.obj.call_method(
            tx, self.fn.__name__, args, kwargs, constant=self.is_constant
        )
        return method_result.add_options(self)

    def num_parameters(self):
        # Exclude the implicitly bound ``self``.
        total = super().num_parameters()
        return total - 1
| |
| |
class WrappedUserMethodVariable(UserMethodVariable):
    """A UserMethodVariable whose call is bracketed by ``context.enter``/``exit``."""

    def __init__(self, wrapped, context, **kwargs):
        # fn/obj always come from ``wrapped``; drop any caller-supplied copies.
        for overridden in ("fn", "obj"):
            kwargs.pop(overridden, None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        inner_result = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return inner_result
| |
| |
class WrappedUserFunctionVariable(UserFunctionVariable):
    """A UserFunctionVariable whose call is bracketed by ``context.enter``/``exit``."""

    def __init__(self, wrapped, context, **kwargs):
        # fn always comes from ``wrapped``; drop any caller-supplied fn/obj.
        for overridden in ("fn", "obj"):
            kwargs.pop(overridden, None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        inner_result = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return inner_result
| |
| |
def invoke_and_store_as_constant(tx, fn, name, options, args, kwargs):
    """Call ``fn`` eagerly on concrete values and register the result.

    Tensor arguments are replaced by their real values and everything else by
    ``as_python_constant()``. The return value of ``fn`` is registered through
    ``tx.output.register_attr_or_module`` under ``name`` with a
    ``ConstantSource(name)``.
    """

    def to_python(var):
        if isinstance(var, variables.TensorVariable):
            return var.get_real_value()
        return var.as_python_constant()

    concrete_args = list(map(to_python, args))
    concrete_kwargs = {key: to_python(value) for key, value in kwargs.items()}
    outcome = fn(*concrete_args, **concrete_kwargs)
    return tx.output.register_attr_or_module(
        outcome,
        name,
        source=ConstantSource(name),
        **options,
    )
| |
| |
class NestedUserFunctionVariable(BaseUserFunctionVariable):
    """A function created during tracing (e.g. a nested ``def``/lambda).

    Its pieces (name, code, defaults, kwdefaults, annotations, closure) are
    held as VariableTrackers; ``reconstruct`` emits bytecode that rebuilds the
    function at runtime via ``_create_nested_fn``.
    """

    _nonvar_fields = {
        "closure_scope",
        "f_globals",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope,
        wraps_source=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        if closure is None:
            # Without a closure there is no enclosing scope to remember.
            closure_scope = None
        self.closure_scope = closure_scope
        # When set, reconstruct() applies functools.wraps(<wraps_source>).
        self.wraps_source = wraps_source

    def self_args(self):
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def get_function(self):
        """Materialize a real function object for this nested function.

        Raises:
            NotImplementedError: if the function has a closure.
        """
        if self.closure:
            raise NotImplementedError()
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                # Flat (name, value, name, value, ...) tuple: pair even positions
                # with odd positions. itertools.pairwise would produce
                # overlapping pairs and insert bogus value->name entries.
                annotations = dict(zip(annotations[::2], annotations[1::2]))

            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        """Bind call arguments to this function's parameters for inlining.

        Returns:
            (result, closure_cells): ``result`` maps local names to
            VariableTrackers; ``closure_cells`` maps cell/free variable names to
            the cells the callee should use.
        """
        from .misc import InlinedClosureVariable

        code = self.get_code()
        # Placeholder cells only satisfy FunctionType's closure-length check;
        # this function is used just for signature binding.
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.items
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())
        wrap_args_kwargs(parent.output.root_tx, result, VariableTracker.propagate(self))
        closure_cells = init_cellvars(parent, result, code)

        for idx, name in enumerate(code.co_freevars):
            cell = self.closure.items[idx]
            # NOTE(review): this looks up the attribute *named after the freevar*
            # (defaulting to ``name``), so it is effectively always true.
            # ``getattr(cell, "name", name)`` may have been intended — confirm.
            assert getattr(cell, name, name) == name
            assert name not in result
            if isinstance(cell, InlinedClosureVariable):
                # InlinedClosureVariable's are created from LOAD_CLOSURE's from
                # InliningInstructionTranslators when the variable name is not found in closure_cells.
                # They should remain outside of closure_cells, so that our callee (the
                # InliningInstructionTranslator that traces `func`) handles
                # the cell correctly - that is, the cell's contents are treated as if they
                # are local variables, like in UserFunctionVariable's bind_args for freevars.
                cand = parent
                while cand and name not in cand.symbolic_locals:
                    cand = cand.parent
                if cand is None:
                    raise RuntimeError(
                        f"Couldn't find {name} in the symbolic_locals of the inline interpreter stack"
                    )
                result[name] = cand.symbolic_locals[name]
            else:
                closure_cells[name] = self.closure.items[idx]

        return result, closure_cells

    def export_freevars(self, parent, child):
        """Copy the child's values of this function's free variables back to parent."""
        code = self.get_code()
        for var in code.co_freevars:
            if var in child.symbolic_locals:
                parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        """Emit bytecode calling ``_create_nested_fn`` with this function's parts.

        Arguments are pushed in ``_create_nested_fn``'s parameter order (code,
        globals, name, defaults, closure, kwdefaults, annotations), with None for
        absent parts. If ``wraps_source`` is set, the result is additionally
        passed through ``functools.wraps(<wraps_source>)``.
        """
        codegen.load_import_from(__name__, "_create_nested_fn")
        codegen(self.code)
        codegen.extend_output([codegen._create_load_const(self.f_globals)])
        codegen(self.fn_name)

        if self.defaults:
            codegen(self.defaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.closure:
            codegen(self.closure)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.kwdefaults:
            codegen(self.kwdefaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.annotations:
            try:
                # Prefer loading annotations as a single constant when possible.
                if isinstance(self.annotations, variables.ConstDictVariable):
                    annotations = {
                        k: v.as_python_constant()
                        for k, v in self.annotations.items.items()
                    }
                else:
                    annotations = tuple(
                        [v.as_python_constant() for v in self.annotations.items]
                    )
                codegen.extend_output([codegen._create_load_const(annotations)])
            except NotImplementedError:
                codegen(self.annotations)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        codegen.extend_output(create_call_function(7, push_null=True))

        if self.wraps_source:
            codegen.load_import_from("functools", "wraps")
            codegen(self.wraps_source)
            codegen.extend_output(create_call_function(1, True))
            codegen.extend_output(create_rot_n(2))
            codegen.extend_output(create_call_function(1, True))

        return []
| |
| |
| def _traceable_collective_remaps(): |
| # We can't rely on importing from distributed, since its not always built |
| if torch.distributed.is_available(): |
| from torch.distributed._functional_collectives import ( |
| traceable_collective_remaps, |
| ) |
| |
| return traceable_collective_remaps |
| return {} |
| |
| |
def _traceable_collectives_source(fn):
    """Build a source for ``torch.distributed._functional_collectives.<fn>``.

    Only the in-place traceable collectives listed below are accepted.
    """
    assert torch.distributed.is_available(), "Illegal invocation."
    from torch.distributed._functional_collectives import (
        all_gather_tensor_inplace,
        reduce_scatter_tensor_inplace,
    )

    assert fn in {all_gather_tensor_inplace, reduce_scatter_tensor_inplace}

    # __import_torch -> .distributed -> ._functional_collectives -> .<fn name>
    torch_root = GlobalSource(global_name="__import_torch")
    distributed_pkg = AttrSource(base=torch_root, member="distributed")
    collectives_module = AttrSource(
        base=distributed_pkg, member="_functional_collectives"
    )
    return AttrSource(collectives_module, fn.__name__)
| |
| |
class CollectiveFunctionRewriteVariable(UserFunctionVariable):
    """
    Some of the torch.distributed.* collective APIs are possible to rewrite to 'traceable' collectives.

    This class provides both a way to check if a function is remappable, and perform the remapping.

    In the case that a function is 'remappable' but only for some combinations of call-time arguments,
    we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse
    than status-quo as we currently graph-break on all distributed.* collectives.
    """

    def __init__(self, fn, *, orig_fn, orig_source, **kwargs):
        # orig_fn lets us implement any fn-specific args/kwargs restrictions inside call_function
        self.orig_fn = orig_fn
        self.orig_source = orig_source

        # remapped_fn gets stuffed in self.fn and used in super().call_function
        super().__init__(fn, **kwargs)

    @staticmethod
    def can_rewrite(variable):
        """True if ``variable`` is a plain function with a registered remapping."""
        return (
            inspect.isfunction(variable) and variable in _traceable_collective_remaps()
        )

    @staticmethod
    def rewrite(fn):
        """Return ``(remapped_fn, source_for_remapped_fn)`` for ``fn``."""
        new_fn = _traceable_collective_remaps()[fn]
        return new_fn, _traceable_collectives_source(new_fn)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # call_function must check any unsupported arguments and graph-break.
        # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
        # since that's the contract for putting a mapping in `traceable_collective_remaps`
        async_op = kwargs.get("async_op")
        # kwargs values are VariableTrackers, and a VariableTracker object is
        # always truthy, so inspect the wrapped constant instead. A
        # non-constant value is treated conservatively as "maybe async".
        if async_op is not None and (
            not async_op.is_python_constant() or async_op.as_python_constant()
        ):
            # Put the old source back, this function will always graph break, but this ensures
            # we produce the correct guards.
            self.source = self.orig_source
            unimplemented(
                f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.orig_fn}"
            )
        return super().call_function(tx, args, kwargs)
| |
| |
class FunctoolsPartialVariable(VariableTracker):
    """A ``functools.partial``: wrapped callable plus bound args and keywords.

    Guards of the callable and of every bound argument are merged into this
    variable's guards at construction time.
    """

    def __init__(self, func, args, keywords, original=None, **kwargs):
        super().__init__(**kwargs)
        self.func = func
        assert isinstance(args, list)
        self.args = args
        assert isinstance(keywords, dict)
        self.keywords = keywords
        self.original = original

        for tracked in itertools.chain([func], args, keywords.values()):
            self.guards.update(VariableTracker.propagate(tracked)["guards"])

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        options = VariableTracker.propagate([self])
        # Call-time keywords override the bound ones, like functools.partial.
        combined_kwargs = dict(self.keywords)
        combined_kwargs.update(kwargs)
        outcome = self.func.call_function(tx, self.args + args, combined_kwargs)
        return outcome.add_options(options)

    def as_python_constant(self):
        if self.original:
            return self.original

        # Resolve the target callable first, before converting arguments.
        target = self.func.fn

        def unwrap(var):
            if isinstance(var, variables.UserDefinedObjectVariable):
                return var.value
            return var.as_python_constant()

        positional = [unwrap(arg) for arg in self.args]
        keyword = {key: unwrap(val) for key, val in self.keywords.items()}
        return functools.partial(target, *positional, **keyword)
| |
| |
class TritonKernelVariable(VariableTracker):
    """A user-defined Triton kernel (optionally wrapped in ``triton.autotune``).

    Calling it with a grid emits a ``triton_kernel_wrapper_mutation`` node into
    the FX graph and returns ``None``. The kernel object itself is registered in
    ``kernel_side_table`` and referenced by index (``kernel_idx``).
    """

    def __init__(self, kernel, kernel_idx, grid, **kwargs):
        from triton.runtime.autotuner import Autotuner

        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        super().__init__(**kwargs)

        assert kernel is not None

        self.kernel = kernel
        # Register the kernel and keep the index returned by the side table.
        self.kernel_idx = kernel_side_table.add_kernel(kernel)

        # A caller-provided index (e.g. when re-creating this variable) must
        # match what the side table returned for the same kernel.
        assert kernel_idx is None or self.kernel_idx == kernel_idx

        # None until a grid is supplied via __getitem__ or run(grid=...).
        self.grid = grid

        if isinstance(kernel, Autotuner):
            # We only support configs and keys arguments of triton.autotune
            # Make sure other arguments are defaulted
            defaults = inspect.signature(Autotuner).parameters
            if (
                defaults["warmup"].default != kernel.warmup
                or defaults["rep"].default != kernel.rep
                or defaults["prune_configs_by"].default != kernel.early_config_prune
            ):
                raise Unsupported(
                    "Only configs and keys are supported for triton.autotune"
                )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Launch the kernel symbolically.

        Evaluates the grid once per autotune config (or once if not autotuned),
        normalizes every grid to a 3-tuple, and records a
        ``triton_kernel_wrapper_mutation`` call in the graph.

        Raises:
            Unsupported: if no grid was provided, a grid is not a tuple, or a
                grid has more than 3 dimensions.
        """
        from triton.runtime.autotuner import Autotuner

        from .constant import ConstantVariable
        from .dicts import ConstDictVariable
        from .lists import BaseListVariable

        if self.grid is None:
            raise Unsupported("Triton kernels should always be called with a grid")

        # Both for grid's meta as well as for the kernel, we need combined
        # args and kwargs normalized
        normalized_args = {**dict(zip(self.kernel.arg_names, args)), **kwargs}

        # One entry per autotune config; a single empty config otherwise.
        configs = (
            [config.kwargs for config in self.kernel.configs]
            if isinstance(self.kernel, Autotuner)
            else [{}]
        )
        grids = []
        for config_args in configs:
            # If the grid is a function, then lets execute it and convert it to
            # a list
            grid = self.grid
            if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
                # Populate the special "meta" argument to call the grid function
                config_args = {
                    k: ConstantVariable.create(v) for k, v in config_args.items()
                }
                meta = ConstDictVariable({**normalized_args, **config_args}, dict)
                grid = grid.call_function(tx, [meta], {})

            # Now, the grid must be a list either originally or through above
            # modification
            if isinstance(grid, BaseListVariable):
                grids.append(grid.as_proxy())
            else:
                unimplemented(f"grid for the triton kernel is {type(grid)}")

        # Only tuple-valued grids are accepted below, even though the check
        # above admits any BaseListVariable.
        for i in range(len(grids)):
            if not isinstance(grids[i], tuple):
                raise Unsupported("Only tuple grids are supported")
            # inductor expects all grids to be 3-tuple so lets make it
            if len(grids[i]) == 1:
                grids[i] = (grids[i][0], 1, 1)
            elif len(grids[i]) == 2:
                grids[i] = (grids[i][0], grids[i][1], 1)
            elif len(grids[i]) > 3:
                raise Unsupported("Grid can have at most rank 3")

        assert len(grids) != 0
        if len(set(grids)) == 1:
            # If there's only one unique grid, lets simplify
            grids = [grids[0]]

        from torch._higher_order_ops.triton_kernel_wrap import (
            triton_kernel_wrapper_mutation,
        )

        # Combine args and kwargs and pass as a dict so that if user defined triton
        # kernel uses variables as 'grid' or 'kernel', it does not conflict with
        # parameters of the wrapper function
        meta = ConstDictVariable(normalized_args, dict)
        tx.output.create_proxy(
            "call_function",
            triton_kernel_wrapper_mutation,
            (),
            {
                "kernel_idx": self.kernel_idx,
                "grid": grids,
                "kwargs": meta.as_proxy(),
            },
        )

        # The kernel launch itself evaluates to None.
        return variables.ConstantVariable(
            None,
            **VariableTracker.propagate(self, args, kwargs.values()),
        )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``kernel[grid]`` and ``kernel.run(..., grid=...)``."""
        if name == "__getitem__":
            # __getitem__ should only be called if we don't already have a grid
            # Only grid needs to be passed
            if self.grid is not None or len(args) != 1:
                raise Unsupported(
                    "Triton kernels should be called with only a single grid"
                )

            return TritonKernelVariable(
                kernel=self.kernel,
                kernel_idx=self.kernel_idx,
                grid=args[0],
                **VariableTracker.propagate(self),
            )
        elif name == "run":
            if "grid" not in kwargs:
                raise Unsupported("Triton kernel requires to be called with a grid")
            # NOTE(review): pop() mutates the caller's kwargs dict in place.
            grid = kwargs.pop("grid")
            return self.clone(grid=grid).call_function(tx, args, kwargs)

        # Bail out to parent's implementation
        return super().call_method(tx, name, args, kwargs)