Fake Tensor refactors part 2 (#116345)
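This should reduce tracing time a bit.
This refactors `op_implementations`, a list of (check, impl) pairs that required O(n) checks per op, so that ops registered by an exact `OpOverload` (or a list/tuple of them) are stored in `op_implementations_dict` with O(1) lookup cost per op. Registrations that use a callable check remain in the `op_implementations_checks` list, which also holds one entry that dispatches into the dict. The strided NestedTensor rejection moves out of the dispatch loop into its own registered impl. Note that the diff also changes one threshold from `self.numel() >= 2` to `self.numel() > 2` (in the hunk following `masked_select`), which is a behavior change rather than a refactor.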
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116345
Approved by: https://github.com/yanboliang
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index dd0136e..c8e5d95 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -414,25 +414,36 @@
)
-op_implementations = []
+op_implementations_dict = {}
+op_implementations_checks = []
def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
def impl_decorator(op_impl):
- global op_implementations
if isinstance(run_impl_check, OpOverload):
- op_implementations.append((lambda func: func == run_impl_check, op_impl))
+ assert (
+ run_impl_check not in op_implementations_dict
+ ), f"duplicate registration: {run_impl_check}"
+ op_implementations_dict[run_impl_check] = op_impl
+ elif isinstance(run_impl_check, (list, tuple)):
+ for op in run_impl_check:
+ register_op_impl(op)(op_impl)
else:
- op_implementations.append((run_impl_check, op_impl))
+ assert callable(run_impl_check)
+ op_implementations_checks.append((run_impl_check, op_impl))
return op_impl
return impl_decorator
-@register_op_impl(
- lambda func: (_is_tensor_constructor(func) or func in _like_tensor_constructors)
-)
+@register_op_impl(op_implementations_dict.__contains__)
+def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
+ return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
+
+
+@register_op_impl(_is_tensor_constructor)
+@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
assert func not in _non_kwarg_device_constructors
_, new_kwargs = normalize_function(
@@ -456,7 +467,8 @@
return FakeTensor(fake_mode, r, out_device)
-@register_op_impl(lambda func: func in (aten.to.prim_Device, aten.to.device))
+@register_op_impl(aten.to.prim_Device)
+@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
_, new_kwargs = normalize_function(
func, args, kwargs, normalize_to_only_use_kwargs=True
@@ -533,7 +545,7 @@
raise DynamicOutputShapeException(func)
-@register_op_impl(lambda func: func is aten.repeat_interleave.Tensor)
+@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
if output_size is None:
if (
@@ -552,7 +564,7 @@
return repeats.new_empty(output_size)
-@register_op_impl(lambda func: func is torch.ops.aten._local_scalar_dense.default)
+@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs:
# Without symints/symfloats, cannot handle this
@@ -567,7 +579,7 @@
raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
-@register_op_impl(lambda func: func is torch.ops.aten.nonzero.default)
+@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
if (
fake_mode.shape_env is None
@@ -611,7 +623,7 @@
return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64)
-@register_op_impl(lambda func: func is torch.ops.aten.masked_select.default)
+@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
if (
fake_mode.shape_env is None
@@ -632,7 +644,7 @@
)
if not has_free_symbols(self.numel()):
- if self.numel() >= 2:
+ if self.numel() > 2:
maxval = int(self.numel())
_constrain_range_for_size(nnz, max=maxval)
@@ -789,14 +801,33 @@
return out
-@register_op_impl(lambda fn: fn in _device_not_kwarg_ops)
+@register_op_impl(aten._nested_tensor_from_tensor_list.default)
+@register_op_impl(aten._nested_tensor_from_tensor_list.out)
+def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
+ raise UnsupportedOperatorException(
+ "torch.compile does not support strided NestedTensor"
+ )
+
+
+@register_op_impl(
+ [
+ x
+ for x in _device_not_kwarg_ops
+ if x
+ not in (
+ # these are already registered elsewhere
+ aten.to.device,
+ aten.to.prim_Device,
+ aten._nested_tensor_from_tensor_list.default,
+ aten._nested_tensor_from_tensor_list.out,
+ )
+ ]
+)
def nyi(fake_mode, func, *args, **kwargs):
assert func not in _device_not_kwarg_ops, f"NYI: {func}"
-@register_op_impl(
- lambda func: func in (aten.convolution.default, aten.convolution_backward.default)
-)
+@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
_, kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
@@ -1675,14 +1706,7 @@
# special handling for funcs registered through `register_op_impl`,
# e.g., manipulating args on constructor calls to construct meta tensors
# and then afterwards wrapping them to a FakeTensor
- for run_impl_check, op_impl in op_implementations:
- if func in (
- aten._nested_tensor_from_tensor_list.default,
- aten._nested_tensor_from_tensor_list.out,
- ):
- raise UnsupportedOperatorException(
- "torch.compile does not support strided NestedTensor"
- )
+ for run_impl_check, op_impl in op_implementations_checks:
if run_impl_check(func):
op_impl_out = op_impl(self, func, *args, **kwargs)
if op_impl_out != NotImplemented: