| import itertools |
| import operator |
| |
| import torch |
| |
| from torch._subclasses import FakeTensorMode # noqa: F401 |
| from torch.fx.node import map_aggregate |
| from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp |
| from torch.multiprocessing.reductions import StorageWeakRef |
| |
| |
| class ShapeAliasingAndMutationProp(ShapeProp): |
| def __init__(self, *args, **kwargs): |
| super(ShapeAliasingAndMutationProp, self).__init__(*args, **kwargs) |
| self.input_alias_groups = set() |
| self.storage_to_alias_group = dict() |
| self.make_alias_group = itertools.count(1) |
| self.name = "ShapeAliasingAndMutation" |
| |
| def tensor_alias_group(self, value: torch.Tensor): |
| """Assign a unique identifier to the storage of a given tensor""" |
| storage = StorageWeakRef(value._typed_storage()) |
| alias_group = self.storage_to_alias_group.get(storage) |
| if alias_group is None: |
| alias_group = next(self.make_alias_group) |
| self.storage_to_alias_group[storage] = alias_group |
| return alias_group |
| |
| def placeholder(self, target, args, kwargs): |
| value = super().placeholder(target, args, kwargs) |
| assert isinstance(value, torch.Tensor) |
| self.input_alias_groups.add(self.tensor_alias_group(value)) |
| return value |
| |
| def run_node(self, n: torch.fx.Node): |
| args, kwargs = self.fetch_args_kwargs_from_env(n) |
| tensor_args = self.extract_tensors((args, kwargs)) |
| |
| input_versions1 = [obj._version for obj in tensor_args] |
| result = getattr(self, n.op)(n.target, args, kwargs) |
| input_versions2 = [obj._version for obj in tensor_args] |
| |
| n.meta["type"] = type(result) |
| n.meta["alias_groups"] = { |
| self.tensor_alias_group(obj) for obj in self.extract_tensors(result) |
| } |
| |
| if ( |
| not n.meta["alias_groups"] |
| and n.op == "call_function" |
| and n.target == operator.setitem |
| ): |
| n.meta["alias_groups"] = {self.tensor_alias_group(tensor_args[0])} |
| |
| n.meta["mutates_alias_groups"] = { |
| self.tensor_alias_group(tensor) |
| for tensor, v1, v2 in zip(tensor_args, input_versions1, input_versions2) |
| if v1 != v2 |
| } |
| # Partial mutation refers to the mutation caused by getitem that can |
| # potentially result in changing only a slice of the original tensor |
| n.meta["partial_mutation"] = False |
| |
| def visit_arg(arg: torch.fx.Node): |
| if ( |
| arg.op == "call_function" and arg.target == operator.getitem |
| ) or arg.meta["partial_mutation"]: |
| if bool(n.meta["mutates_alias_groups"] & arg.meta["alias_groups"]): |
| n.meta["partial_mutation"] = True |
| |
| torch.fx.map_arg((n.args, n.kwargs), visit_arg) |
| n.meta["is_input_alias"] = bool( |
| self.input_alias_groups & n.meta["alias_groups"] |
| ) |
| n.meta["is_input_mutation"] = bool( |
| self.input_alias_groups & n.meta["mutates_alias_groups"] |
| ) |
| n.meta["is_mutation"] = bool(n.meta["mutates_alias_groups"]) |
| n.meta["tensor_metas"] = [ |
| _extract_tensor_metadata(obj) for obj in self.extract_tensors(result) |
| ] |
| tensors = self.extract_tensors(result) |
| if tensors: |
| n.meta["device"] = tensors[0].device |
| n.meta["dtype"] = tensors[0].dtype |
| |
| return result |
| |
| @staticmethod |
| def extract_tensors(result): |
| """Return a flat list of tensors found in some nested data structure""" |
| seen = set() |
| tensors = [] |
| |
| def visit(obj): |
| if isinstance(obj, torch.Tensor) and id(obj) not in seen: |
| seen.add(id(obj)) |
| tensors.append(obj) |
| |
| map_aggregate(result, visit) |
| return tensors |
| |
| def run(self, *args): |
| try: |
| super().run(*args) |
| finally: |
| # cleanup |
| self.env.clear() |