| # mypy: allow-untyped-defs |
| # Copyright (c) Meta Platforms, Inc. and affiliates |
| import collections |
| import logging |
| import weakref |
from typing import Any, cast, Deque, Dict, Iterator, List, Optional, Set, Tuple
| |
| import torch |
| from torch.autograd.graph import GradientEdge, Node |
| from torch.nn import Parameter |
| |
| from ._debug import map_debug_info |
| |
| |
| logger = logging.getLogger(__name__) |
| |
| |
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Optional[Node]:
| """ |
| Get the grad function or grad accumulator for a tensor. |
| |
| Accumulate grad nodes are lazily created, so we need to a |
| dummy view in order to trigger its creation. |
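
    Example (illustrative sketch; ``leaf`` is a hypothetical leaf tensor)::

        leaf = torch.randn(2, requires_grad=True)
        acc = _get_grad_fn_or_grad_acc(leaf)      # AccumulateGrad node
        fn = _get_grad_fn_or_grad_acc(leaf * 2)   # MulBackward0 node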
| """ |
| if t.requires_grad and t.grad_fn is None: |
| # if no grad function (leaf tensors) we use view |
| viewed_t = t.view_as(t) |
| grad_fn = viewed_t.grad_fn |
| if grad_fn is not None: |
| return grad_fn.next_functions[0][0] |
| else: |
            raise RuntimeError(
                "Attempted to get grad_fn, but got None. "
                "Is this being created in a no-grad context?"
            )
| else: |
| return t.grad_fn |
| |
| |
| def reverse_closure( |
| roots: List[Node], target_nodes: Set[Node] |
| ) -> Tuple[Set[Node], Set[Node]]: |
| """ |
| This function returns the reverse closure of the given roots, |
| i.e. the set of nodes that can be reached from the roots by following the |
| reverse edges of the graph. The target_nodes are the nodes that we want to |
| include in the closure. |
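
    Example (sketch; ``out_fn``, ``inp_fn``, and ``w_acc`` are hypothetical
    nodes of one autograd graph, mirroring how ``get_param_groups`` calls
    this function)::

        refs = construct_reverse_graph([out_fn])
        inputs_closure, _ = reverse_closure([inp_fn], set())
        _, hit_targets = reverse_closure([w_acc], inputs_closure)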
| """ |
    # BFS from the roots; stop whenever we reach a target node
| closure: Set[Node] = set() |
    visited_target_nodes: Set[Node] = set()
| q: Deque[Node] = collections.deque() |
| for node in roots: |
| if node is not None and node not in closure: |
| closure.add(node) |
| q.append(node) |
| while q: |
| node = q.popleft() |
| metadata = cast(Dict[str, List], node.metadata) |
| reverse_edges = metadata.get("reverse_edges", []) |
| for holder_ref, idx in reverse_edges: |
| ref = holder_ref() |
| if ref is None: |
                # The Holder for this edge has been collected; the reverse
                # graph is no longer alive, so skip the stale edge.
                continue
| fn = ref.node |
| if fn in closure or fn is None: |
| continue |
| if fn in target_nodes: |
| visited_target_nodes.add(fn) |
| continue |
| closure.add(fn) |
| q.append(fn) |
| return closure, visited_target_nodes |
| |
| |
# Wrapper that lets us hand out weak references to a Node; dropping the
# strong Holder references (see construct_reverse_graph) invalidates the
# reverse graph.
| class Holder: |
| def __init__(self, node: Node): |
| self.node = node |
| |
| |
| def construct_reverse_graph(roots: List[Node]) -> List[Holder]: |
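    """
    Annotate the autograd graph reachable from ``roots`` with reverse edges,
    stored in each node's metadata under the "reverse_edges" key.

    Returns the Holder objects that keep those edges alive; once the returned
    list is dropped, the reverse graph is considered dead and
    ``reverse_closure`` will skip its edges.
    """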
| q: Deque[Node] = collections.deque() |
| root_seen: Set[Node] = set() |
| reverse_graph_refs: List[Holder] = [] |
| for node in roots: |
| if node is not None and node not in root_seen: |
| q.append(node) |
| root_seen.add(node) |
| while q: |
| node = q.popleft() |
| for fn, idx in node.next_functions: |
| if fn is not None: |
                # The reverse edges don't necessarily need to be stored on
                # the graph; node metadata is just a convenient place for them.
| metadata = cast(Dict[str, List], fn.metadata) |
| reverse_edges = metadata.get("reverse_edges", []) |
| if len(reverse_edges) == 0: |
| q.append(fn) |
| holder = Holder(node) |
| holder_ref = weakref.ref(holder) |
| reverse_graph_refs.append(holder) |
| reverse_edges.append((holder_ref, idx)) |
| metadata["reverse_edges"] = reverse_edges |
| return reverse_graph_refs |
| |
| |
| def get_param_groups(inputs: List[Node], params: List[Node]) -> List[Dict[str, Any]]: |
| """ |
| Given a list of inputs and a list of parameters, return a list of parameter |
| groups, where each group contains the parameters and the intermediates that |
| are connected to the parameters. |
| |
| The returned list of parameter groups is a list of dictionaries, where each |
| dictionary contains the following keys: |
| - "params": a set of parameters |
| - "intermediates": a set of intermediates |
| |
| The returned list of parameter groups is a list of dictionaries, |
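
    Example (sketch; ``inp``, ``w``, and ``out`` are hypothetical
    grad-requiring tensors from one stage forward pass)::

        refs = construct_reverse_graph([_get_grad_fn_or_grad_acc(out)])
        groups = get_param_groups(
            [_get_grad_fn_or_grad_acc(inp)], [_get_grad_fn_or_grad_acc(w)]
        )
        del refs  # drop the reverse edges once the groups are built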
| """ |
    # Reverse closure that starts at the inputs and goes up to the dOutput or
    # the loss, omitting weights and any subgraphs connecting weights to it
| inputs_closure, _ = reverse_closure(inputs, set()) |
| param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates |
    for param in params:
| closure, intersected = reverse_closure([param], inputs_closure) |
| param_group: Dict[str, Set] = { |
| "params": {param}, |
| "intermediates": intersected, |
| } |
        for intermediate_node in intersected:
            existing = param_groups.get(intermediate_node, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[intermediate_node] = param_group
| |
| # Sanity check: union of all param_groups params should be equal to all params |
| union_params: Set[Node] = set() |
| seen_ids: Set[int] = set() |
| unique_param_groups = [] |
| for param_group in param_groups.values(): |
| if id(param_group) not in seen_ids: |
| seen_ids.add(id(param_group)) |
| unique_param_groups.append(param_group) |
| union_params = union_params.union(param_group["params"]) |
| |
    # Note: this assertion only holds when every input tensor requires grad;
    # otherwise the autograd graph is missing the first layer of inputs and
    # some params never intersect the inputs closure.
    # assert union_params == set(params)
| return unique_param_groups |
| |
| |
| def stage_backward_input( |
| stage_outputs: List[torch.Tensor], |
| output_grads: Optional[List[torch.Tensor]], |
| input_values: List[torch.Tensor], |
| weights: Iterator[Parameter], |
| ): |
| """ |
| compute the gradients for only the stage inputs with respect to the stage outputs |
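
    Example (sketch; ``layer`` and ``inp`` are a hypothetical stage module
    and grad-requiring input; passing ``None`` grads treats the outputs as
    a loss)::

        out = layer(inp)
        dinputs, param_groups = stage_backward_input(
            [out], None, [inp], layer.parameters()
        )
        stage_backward_weight(layer.parameters(), param_groups)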
| """ |
| stage_output_grad_fns: List[Node] = list( |
| filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs)) |
| ) |
| stage_input_grad_fns: List[Node] = list( |
| filter(None, map(_get_grad_fn_or_grad_acc, input_values)) |
| ) |
| weight_grad_fns: List[Node] = list( |
| filter(None, map(_get_grad_fn_or_grad_acc, weights)) |
| ) |
| |
| reverse_graph_refs = construct_reverse_graph(stage_output_grad_fns) |
| param_groups = get_param_groups(stage_input_grad_fns, weight_grad_fns) |
| del reverse_graph_refs |
| |
| for param_group in param_groups: |
| for i, intermediate in enumerate(param_group["intermediates"]): |
| |
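            # Bind (param_group, i) through a factory so that each hook
            # captures its own values rather than the loop variables.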
| def get_hook(param_group, i): |
| def hook(grad_inputs): |
| if param_group.get("grads", None) is None: |
| param_group["grads"] = [None] * len( |
| param_group["intermediates"] |
| ) |
| param_group["grads"][i] = grad_inputs |
| |
| return hook |
| |
| # These are always "split" nodes that we need to recompute, so |
| # save their inputs. |
| intermediate.register_prehook(get_hook(param_group, i)) |
| |
    # Inputs to stage 0 typically do not require grad; if any input does not
    # require grad, skip computing input gradients entirely.
| if all(tensor.requires_grad for tensor in input_values): |
| if output_grads is None: |
| # In case this is the loss and there are no output_grads, then we just use 1s |
| output_grads = [ |
| torch.ones_like(stage_output) for stage_output in stage_outputs |
| ] |
| |
| dinputs = torch.autograd.grad( |
| stage_outputs, |
| inputs=input_values, |
| grad_outputs=output_grads, |
| retain_graph=True, |
| ) |
| |
| # update the gradients for inputs |
| for i, inp in enumerate(input_values): |
| if inp.grad is None: |
| inp.grad = dinputs[i] |
| else: |
| inp.grad += dinputs[i] |
| else: |
| dinputs = None |
| return dinputs, param_groups |
| |
| |
| def stage_backward_weight( |
| weights: Iterator[Parameter], param_groups: List[Dict[str, Any]] |
| ): |
    # Map each weight's grad accumulator node back to (weight, index)
| grad_acc_to_weight = {} |
| weight_grads = [] |
| for index, weight in enumerate(weights): |
| grad_acc = _get_grad_fn_or_grad_acc(weight) |
| grad_acc_to_weight[grad_acc] = weight, index |
| weight_grads.append(weight.grad) |
| |
| for param_group in param_groups: |
| # TODO: Handle case where intermediate can have multiple outputs |
| intermediate_edges = tuple( |
| GradientEdge(i, 0) for i in param_group["intermediates"] |
| ) |
| weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"]) |
| |
| assert all(len(g) == 1 for g in param_group["grads"]) |
        # GradientEdge objects can be passed to autograd.grad as outputs.
        # retain_graph is not needed here, presumably because the per-group
        # subgraphs do not overlap.
| dweights = torch.autograd.grad( |
| intermediate_edges, |
| weights_edges, |
| grad_outputs=sum(param_group["grads"], tuple()), |
| ) |
        for grad_acc, dw in zip(param_group["params"], dweights):
            weight, index = grad_acc_to_weight[grad_acc]
            if weight.grad is None:
                weight.grad = dw
            else:
                weight.grad += dw
            weight_grads[index] = weight.grad
    # Return grads in the original order the weights were provided in
    return weight_grads
| |
| |
| def stage_backward( |
| stage_output, |
| output_grads, |
| input_values, |
| outputs_with_grads_idxs: Optional[List[int]] = None, # deprecated, not used |
| ): |
| """ |
| This is a helper function to: |
| 1. compute the gradients for the stage inputs, and |
| 2. accumulate gradients for the stage module's parameters. |
| |
    Given the input value(s) and the corresponding gradient for the output
    value(s), compute and accumulate gradients for all parameter values
    (leaves in the autograd trace) as well as return a list of the gradients
    for the input values.
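
    Example (sketch; ``out``, ``dout``, and ``inp`` are hypothetical tensors
    from one stage's forward pass)::

        grad_inputs = stage_backward(
            stage_output=[out],
            output_grads=[dout],
            input_values=[inp],
        )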
| """ |
| if outputs_with_grads_idxs is not None: |
| # Deprecated, not used in runtime calls, only exists in compiler |
| stage_output = [stage_output[i] for i in outputs_with_grads_idxs] |
| output_grads = [output_grads[i] for i in outputs_with_grads_idxs] |
| |
| try: |
| # stage_output may be a composite datatype like dict. Extract all individual |
| # tensor values here |
| stage_output_tensors = [] |
| output_grad_tensors = [] |
| |
| def extract_tensors_with_grads(output_val, grad_val): |
| if isinstance(output_val, torch.Tensor): |
| if not output_val.requires_grad and output_val.grad_fn is None: |
| return |
| assert isinstance( |
| grad_val, (torch.Tensor, type(None)) |
| ), f"Expected Tensor or None gradient but got {type(grad_val)}" |
| stage_output_tensors.append(output_val) |
| output_grad_tensors.append(grad_val) |
| elif isinstance(output_val, (tuple, list)): |
| if grad_val is None: |
| return |
| assert isinstance( |
| grad_val, (tuple, list) |
| ), f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" |
| assert len(output_val) == len(grad_val) |
| for ov, gv in zip(output_val, grad_val): |
| extract_tensors_with_grads(ov, gv) |
| elif isinstance(output_val, dict): |
| if grad_val is None: |
| return |
| assert isinstance(grad_val, dict) |
| assert set(output_val.keys()) == set(grad_val.keys()) |
| for k in output_val.keys(): |
| extract_tensors_with_grads(output_val[k], grad_val[k]) |
| else: |
| # Output is a non-tensor type; just ignore it |
| pass |
| |
| extract_tensors_with_grads(stage_output, output_grads) |
| |
| torch.autograd.backward( |
| stage_output_tensors, grad_tensors=output_grad_tensors # type: ignore[arg-type] |
| ) |
| |
| # Extract gradients wrt the input values |
| grad_inputs = [] |
| for val in input_values: |
| if isinstance(val, torch.Tensor): |
| grad_inputs.append(val.grad) |
| else: |
| grad_inputs.append(None) |
| |
| # Alternative impl: `torch.autograd.grad`. |
| # Note that `torch.autograd.grad` will not accumulate gradients into the |
| # model's parameters. |
| """ |
| inputs_with_grad = [] |
| for val in input_values: |
| if isinstance(val, torch.Tensor) and val.requires_grad: |
| inputs_with_grad.append(val) |
| |
| grad_inputs = torch.autograd.grad( |
| stage_output_tensors, inputs_with_grad, output_grad_tensors, # type: ignore[arg-type] |
| ) |
| """ |
| |
| except Exception as e: |
| exc_msg = f""" |
| Failed to run stage backward: |
| Stage output: {map_debug_info(stage_output)} |
| Output gradient: {map_debug_info(output_grads)} |
| Input: {map_debug_info(input_values)} |
| """ |
| raise RuntimeError(exc_msg) from e |
| |
| return grad_inputs |
| |
| |
# TODO: Handle requires_grad=False dynamically. Can we analyze this during
# initial IR emission?
| def _null_coalesce_accumulate(lhs, rhs): |
| """ |
| Coalesce two values, even if one of them is null, returning the non-null |
| value. |
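
    Example (``g``, ``g1``, and ``g2`` are hypothetical tensors)::

        _null_coalesce_accumulate(None, g)  # returns g
        _null_coalesce_accumulate(g1, g2)   # returns g1 + g2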
| """ |
| if lhs is None: |
| return rhs |
| elif rhs is None: |
| return lhs |
| else: |
| return torch.add(lhs, rhs) |