Experimental MetaTensorTracer

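Adds `torch.fx.experimental.meta_tracer.MetaTracer`, an experimental FX tracer that propagates meta tensors alongside symbolic tracing. Placeholders are seeded from a user-provided `meta_args` dict, and each traced call is additionally executed on meta tensors (or a manual override) so that proxies carry concrete shape/dtype metadata. This lets tracing resolve shape-dependent control flow that a plain symbolic trace cannot. A minimal usage sketch, mirroring the test added in this PR:

```python
import torch
import torch.fx
from torch.fx.experimental.meta_tracer import MetaTracer

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
        self.layernorm = torch.nn.LayerNorm(16)

    def forward(self, x):
        y = self.layernorm(self.emb(x))
        # Shape-dependent branch, resolved via the propagated meta tensors
        return torch.relu(y) if y.shape[0] < 30 else torch.sigmoid(y)

m = M()
x = torch.zeros(15, dtype=torch.long).random_(42)
# Seed the placeholder 'x' with a meta tensor of the same shape/dtype
graph = MetaTracer().trace(m, meta_args={'x': x.to(device='meta')})
gm = torch.fx.GraphModule(m, graph)
torch.testing.assert_close(gm(x), m(x))
```
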
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76003

Approved by: https://github.com/jansel
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 33e6abf..2ab9ae3 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -26,6 +26,7 @@
 )
 from torch.fx.experimental.rewriter import RewritingTracer
 from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
+from torch.fx.experimental.meta_tracer import MetaTracer
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import Node
 from torch.fx.operator_schemas import (
@@ -667,6 +668,28 @@
         # Confirm that the output is correct
         self.assertEqual(traced(3, 3), m(3, 3))
 
+    def test_meta_tracer(self):
+        mt = MetaTracer()
+
+        class MetaTracerTestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
+                self.layernorm = torch.nn.LayerNorm(16)
+
+            def forward(self, x):
+                emb = self.emb(x)
+                lol = self.layernorm(emb)
+                return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)
+
+        mttm = MetaTracerTestModule()
+        for BS in [15, 35]:
+            x = torch.zeros(BS, dtype=torch.long).random_(42)
+            graph = mt.trace(mttm, meta_args={'x' : x.to(device='meta')})
+            gm = torch.fx.GraphModule(mttm, graph)
+            torch.testing.assert_close(gm(x), mttm(x))
+
+
     def test_call_to_assert_with_msg(self):
         class M(torch.nn.Module):
             def forward(self, a, b):
diff --git a/torch/fx/experimental/meta_tracer.py b/torch/fx/experimental/meta_tracer.py
new file mode 100644
index 0000000..0180392
--- /dev/null
+++ b/torch/fx/experimental/meta_tracer.py
@@ -0,0 +1,256 @@
+import torch
+import torch.fx
+import warnings
+import functools
+import builtins
+
+from typing import Callable, Dict
+
+def embedding_override(self, input):
+    return torch.empty(*input.shape, self.weight.shape[-1], device='meta')
+
+
+def nn_layernorm_override(self, input):
+    return input
+
+
+def torch_relu_override(x):
+    return x
+
+
+def torch_nn_relu_override(self, x):
+    return x
+
+
+def functional_relu_override(x, inplace=False):
+    assert not inplace, 'inplace=True is not supported for meta tensor analysis of functional.relu'
+    return x
+
+
+def torch_where_override(condition, x, y):
+    # torch.where returns a tensor broadcast from condition, x, and y, so we
+    # approximate its output metadata by adding the three meta tensors, which
+    # broadcasts the same way.
+    return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')
+
+
+def torch_abs_override(input, *, out=None):
+    assert out is None, 'the out= variant of torch.abs is not supported for meta tensor analysis'
+    return input
+
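+# Hand-written "meta" implementations used in place of the real op/module during
+# meta-tensor propagation (see MetaTracer.create_proxy below).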
+manual_meta_overrides : Dict[Callable, Callable] = {
+    torch.nn.Embedding: embedding_override,
+    torch.nn.LayerNorm: nn_layernorm_override,
+    torch.relu: torch_relu_override,
+    torch.nn.functional.relu: functional_relu_override,
+    torch.nn.ReLU: torch_nn_relu_override,
+    torch.where: torch_where_override,
+    torch.abs: torch_abs_override,
+}
+
+def gen_constructor_wrapper(target):
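+    """Wrap a tensor constructor (e.g. torch.zeros) so that calls receiving Proxy
+    arguments are recorded as call_function nodes instead of executing eagerly."""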
+    @functools.wraps(target)
+    def wrapper(*args, **kwargs):
+        proxy = None
+
+        def check_has_proxy(v):
+            if isinstance(v, torch.fx.Proxy):
+                nonlocal proxy
+                proxy = v
+        torch.fx.node.map_aggregate(args, check_has_proxy)
+        torch.fx.node.map_aggregate(kwargs, check_has_proxy)
+
+        if proxy is not None:
+            return proxy.tracer.create_proxy('call_function', target, args, kwargs)
+        else:
+            return target(*args, **kwargs)
+    return wrapper, target
+
+class MetaProxy(torch.fx.Proxy):
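+    """Proxy that carries a meta tensor (``_tensor_meta``) so that shape, dtype,
+    dim() and size() queries can be answered concretely during tracing."""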
+    def install_tensor_meta(self, tensor_meta):
+        self._tensor_meta = tensor_meta
+
+    def size(self, dim=None):
+        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
+            return self._tensor_meta.size(*[dim] if dim is not None else [])
+        return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim is not None else (self,), {})
+
+    def dim(self):
+        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
+            return self._tensor_meta.dim()
+        return self.tracer.create_proxy('call_method', 'dim', (self,), {})
+
+    @property
+    def shape(self):
+        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
+            return self._tensor_meta.shape
+        return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {})
+
+    @property
+    def dtype(self):
+        if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
+            return self._tensor_meta.dtype
+        return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {})
+
+    @property
+    def device(self):
+        # Hack so we can track when devices are used; during meta-tensor
+        # propagation, these values are replaced with the constant 'meta'.
+        return MetaDeviceAttribute(self, 'device')
+
+    def __getattr__(self, k):
+        if k == '_tensor_meta':
+            return self.__getattribute__(k)
+        # Note: the attribute access is not added to the graph yet; if it turns
+        # out to be a method call, we peephole-optimize it into the method invocation.
+        return MetaAttribute(self, k)
+
+class MetaAttribute(MetaProxy):
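+    """Proxy for an attribute access; node creation is deferred because most
+    attribute accesses are immediately consumed as method calls."""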
+    def __init__(self, root, attr: str):
+        self.root = root
+        self.attr = attr
+        self.tracer = root.tracer
+        self._node = None
+
+    @property
+    def node(self):
+        # The node for an attribute access is added lazily, since most accesses
+        # are just method calls and never need the getattr node itself.
+        if self._node is None:
+            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+        return self._node
+
+    def __call__(self, *args, **kwargs):
+        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+
+class MetaDeviceAttribute(MetaAttribute):
+    pass
+
+def proxys_to_metas(v):
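+    """Map MetaDeviceAttribute to the string 'meta', MetaProxy to its stored meta
+    tensor, and leave all other values unchanged."""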
+    if isinstance(v, MetaDeviceAttribute):
+        return 'meta'
+    if isinstance(v, torch.fx.Proxy):
+        assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}'
+        assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta'
+        return v._tensor_meta
+    return v
+
+class MetaTracer(torch.fx.Tracer):
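+    """fx.Tracer that performs meta-tensor propagation alongside symbolic tracing,
+    attaching tensor metadata to each MetaProxy so shape- and dtype-dependent
+    control flow can be resolved against the user-provided ``meta_args``."""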
+    allow_insert_stateless_mods : bool = True
+
+    _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye']
+
+    def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
+        rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
+
+        if kind == 'placeholder' and target in self.meta_args:
+            rv.install_tensor_meta(self.meta_args[target])
+            return rv
+
+        if target in self.orig_fns:
+            # NOTE: tensor constructors in PyTorch define the `device` argument as
+            # *kwargs-only*, so rewriting `kwargs` here is sufficient. If you add
+            # methods to _TORCH_METHODS_TO_PATCH that accept `device` positionally,
+            # this will break and you will likely see issues where we cannot infer
+            # the size of the output.
+            if 'device' in kwargs:
+                kwargs['device'] = 'meta'
+
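+        # Execute the target on meta tensors (or a manual override) so the
+        # resulting metadata can be attached to the returned proxy.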
+        try:
+            args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
+            kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)
+
+            if kind == 'call_function':
+                meta_target = manual_meta_overrides.get(target, target)
+                meta_out = meta_target(*args_metas, **kwargs_metas)
+            elif kind == 'call_method':
+                meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)
+            elif kind == 'call_module':
+                assert hasattr(self, 'orig_forward')
+                self._disable_module_getattr = True
+                try:
+                    mod = self.root.get_submodule(target)
+                    mod_type = type(mod)
+                    if mod_type in manual_meta_overrides:
+                        meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)
+                    else:
+                        meta_out = self.orig_forward(*args_metas, **kwargs_metas)
+                finally:
+                    self._disable_module_getattr = False
+            elif kind == 'get_attr':
+                self._disable_module_getattr = True
+                try:
+                    attr_itr = self.root
+                    atoms = target.split('.')
+                    for atom in atoms:
+                        attr_itr = getattr(attr_itr, atom)
+                    assert isinstance(attr_itr, torch.Tensor)
+                    meta_out = attr_itr.to(device='meta')
+                finally:
+                    self._disable_module_getattr = False
+            else:
+                return rv
+
+            # TODO: support composite (e.g. tuple/list) outputs
+            assert isinstance(rv, torch.fx.Proxy), 'Composite outputs are not yet supported'
+            rv.install_tensor_meta(meta_out)
+        except Exception as e:
+            warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}')
+
+        return rv
+
+    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
+        if getattr(self, '_disable_module_getattr', False):
+            return attr_val
+        else:
+            return super()._module_getattr(attr, attr_val, parameter_proxy_cache)
+
+    def call_module(self, m, forward, args, kwargs):
+        self.orig_forward = forward
+        return super().call_module(m, forward, args, kwargs)
+
+    def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
+        """
+        Helper method that registers a module that was not declared as a submodule, under a generated name.
+        """
+        idx = 0
+        mod_name = mod.__class__.__name__.lower()
+        path = f"{mod_name}_{idx}"
+        while hasattr(self.root, path):
+            idx += 1
+            path = f"{mod_name}_{idx}"
+
+        self.root.add_module(path, mod)
+        return path
+
+    def path_of_module(self, mod: torch.nn.Module) -> str:
+        try:
+            return super().path_of_module(mod)
+        except NameError:
+            if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
+                path = self._insert_module_as_submodule(mod)
+                self.prev_module = path
+                return path
+            raise
+
+    def proxy(self, node):
+        return MetaProxy(node, self)
+
+    def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):
+        assert isinstance(meta_args, dict)
+        self.meta_args = meta_args
+
+        self.patched_torch_methods = {
+            target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
+        }
+        self.orig_fns = set()
+
+        for name, (wrapper, orig) in self.patched_torch_methods.items():
+            setattr(torch, name, wrapper)
+            self.orig_fns.add(orig)
+
+        try:
+            return super().trace(root, concrete_args)
+        finally:
+            for name, (_, orig) in self.patched_torch_methods.items():
+                setattr(torch, name, orig)