Fake Tensor refactors part 2 (#116345)

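This should reduce tracing time a bit.
This refactors `op_implementations`, a list of (check, impl) pairs that required O(n) checks per op, so that registrations keyed on an exact `OpOverload` live in a dict (`op_implementations_dict`) with O(1) lookup per op. Only predicate-based registrations remain in a list (`op_implementations_checks`).
The strided NestedTensor rejection moves out of the dispatch loop into a registered impl (`nested_tensors_unsupported`).
Note: the diff also changes the `masked_select` bound check from `self.numel() >= 2` to `self.numel() > 2`.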

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116345
Approved by: https://github.com/yanboliang
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index dd0136e..c8e5d95 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -414,25 +414,36 @@
         )
 
 
-op_implementations = []
+op_implementations_dict = {}
+op_implementations_checks = []
 
 
 def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
     def impl_decorator(op_impl):
-        global op_implementations
         if isinstance(run_impl_check, OpOverload):
-            op_implementations.append((lambda func: func == run_impl_check, op_impl))
+            assert (
+                run_impl_check not in op_implementations_dict
+            ), f"duplicate registration: {run_impl_check}"
+            op_implementations_dict[run_impl_check] = op_impl
+        elif isinstance(run_impl_check, (list, tuple)):
+            for op in run_impl_check:
+                register_op_impl(op)(op_impl)
         else:
-            op_implementations.append((run_impl_check, op_impl))
+            assert callable(run_impl_check)
+            op_implementations_checks.append((run_impl_check, op_impl))
 
         return op_impl
 
     return impl_decorator
 
 
-@register_op_impl(
-    lambda func: (_is_tensor_constructor(func) or func in _like_tensor_constructors)
-)
+@register_op_impl(op_implementations_dict.__contains__)
+def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
+    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
+
+
+@register_op_impl(_is_tensor_constructor)
+@register_op_impl([*_like_tensor_constructors])
 def constructors(fake_mode, func, *args, **kwargs):
     assert func not in _non_kwarg_device_constructors
     _, new_kwargs = normalize_function(
@@ -456,7 +467,8 @@
     return FakeTensor(fake_mode, r, out_device)
 
 
-@register_op_impl(lambda func: func in (aten.to.prim_Device, aten.to.device))
+@register_op_impl(aten.to.prim_Device)
+@register_op_impl(aten.to.device)
 def non_kwarg_to(fake_mode, func, *args, **kwargs):
     _, new_kwargs = normalize_function(
         func, args, kwargs, normalize_to_only_use_kwargs=True
@@ -533,7 +545,7 @@
     raise DynamicOutputShapeException(func)
 
 
-@register_op_impl(lambda func: func is aten.repeat_interleave.Tensor)
+@register_op_impl(aten.repeat_interleave.Tensor)
 def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
     if output_size is None:
         if (
@@ -552,7 +564,7 @@
     return repeats.new_empty(output_size)
 
 
-@register_op_impl(lambda func: func is torch.ops.aten._local_scalar_dense.default)
+@register_op_impl(torch.ops.aten._local_scalar_dense.default)
 def local_scalar_dense(fake_mode, func, arg):
     if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs:
         # Without symints/symfloats, cannot handle this
@@ -567,7 +579,7 @@
         raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
 
 
-@register_op_impl(lambda func: func is torch.ops.aten.nonzero.default)
+@register_op_impl(torch.ops.aten.nonzero.default)
 def nonzero(fake_mode, func, arg):
     if (
         fake_mode.shape_env is None
@@ -611,7 +623,7 @@
     return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64)
 
 
-@register_op_impl(lambda func: func is torch.ops.aten.masked_select.default)
+@register_op_impl(torch.ops.aten.masked_select.default)
 def masked_select(fake_mode, func, self, mask):
     if (
         fake_mode.shape_env is None
@@ -632,7 +644,7 @@
     )
 
     if not has_free_symbols(self.numel()):
-        if self.numel() >= 2:
+        if self.numel() > 2:
             maxval = int(self.numel())
 
     _constrain_range_for_size(nnz, max=maxval)
@@ -789,14 +801,33 @@
         return out
 
 
-@register_op_impl(lambda fn: fn in _device_not_kwarg_ops)
+@register_op_impl(aten._nested_tensor_from_tensor_list.default)
+@register_op_impl(aten._nested_tensor_from_tensor_list.out)
+def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
+    raise UnsupportedOperatorException(
+        "torch.compile does not support strided NestedTensor"
+    )
+
+
+@register_op_impl(
+    [
+        x
+        for x in _device_not_kwarg_ops
+        if x
+        not in (
+            # registered separately: non_kwarg_to, nested_tensors_unsupported
+            aten.to.device,
+            aten.to.prim_Device,
+            aten._nested_tensor_from_tensor_list.default,
+            aten._nested_tensor_from_tensor_list.out,
+        )
+    ]
+)
 def nyi(fake_mode, func, *args, **kwargs):
     assert func not in _device_not_kwarg_ops, f"NYI: {func}"
 
 
-@register_op_impl(
-    lambda func: func in (aten.convolution.default, aten.convolution_backward.default)
-)
+@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
 def conv(fake_mode, func, *args, **kwargs):
     _, kwargs = normalize_function(
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
@@ -1675,14 +1706,7 @@
         # special handling for funcs registered through `register_op_impl`,
         # e.g., manipulating args on constructor calls to construct meta tensors
         # and then afterwards wrapping them to a FakeTensor
-        for run_impl_check, op_impl in op_implementations:
-            if func in (
-                aten._nested_tensor_from_tensor_list.default,
-                aten._nested_tensor_from_tensor_list.out,
-            ):
-                raise UnsupportedOperatorException(
-                    "torch.compile does not support strided NestedTensor"
-                )
+        for run_impl_check, op_impl in op_implementations_checks:
             if run_impl_check(func):
                 op_impl_out = op_impl(self, func, *args, **kwargs)
                 if op_impl_out != NotImplemented: