Use register_meta for everything in meta_registrations (#84297)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84297
Approved by: https://github.com/Chillee
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index a67133b..8044221 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -891,7 +891,6 @@
symbolic_tensor_failures = {
# Needs complex-value support
xfail('polar'),
- xfail('complex'),
xfail('linalg.eig'),
xfail('__getitem__', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('__rmatmul__', ''), # aten.new_empty.default - couldn't find symbolic meta function/decomposition
@@ -1046,7 +1045,6 @@
xfail('linalg.tensorinv', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
- xfail('linalg.vecdot', ''), # Could not run 'aten::vdot' with arguments from the 'Meta' backend. This could be ...
xfail('linalg.vector_norm', ''), # TensorImpl do not have numel
xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
xfail('logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
@@ -1071,7 +1069,6 @@
xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition
xfail('msort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition
xfail('mv', ''), # aten.mv.default - couldn't find symbolic meta function/decomposition
- xfail('nanmean', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('native_layer_norm', ''), # Unexpected type <class 'torch.SymIntNode'> when computing elementwise type promot...
@@ -1243,7 +1240,6 @@
xfail('unfold', ''), # aten.unfold.default - couldn't find symbolic meta function/decomposition
xfail('var_mean', ''), # Unexpected type <class 'torch.SymIntNode'> when computing elementwise type promotion!
xfail('var', ''), # Unexpected type <class 'torch.SymIntNode'> when computing elementwise type promotion!
- xfail('vdot', ''), # aten.vdot.default - couldn't find symbolic meta function/decomposition
xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition
xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 269c0a1..e56456c 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -17,7 +17,7 @@
aten = torch.ops.aten
-meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
meta_table = {}
@@ -32,7 +32,7 @@
if op._overloadname != "default"
else op.overloadpacket.__name__
)
- meta_lib.impl(name, f)
+ _meta_lib_dont_use_me_use_register_meta.impl(name, f)
tree_map(add_func, op)
return f
@@ -195,7 +195,7 @@
return utils.compute_reduction_output_shape(self.shape, dims)
-@torch.library.impl(meta_lib, "bernoulli.out")
+@register_meta(aten.bernoulli.out)
def meta_bernoulli(self, *, generator=None, out):
torch._resize_output_(out, self.size(), self.device)
return out
@@ -380,8 +380,7 @@
return repeats.new_empty(output_size)
-@torch.library.impl(meta_lib, "complex")
-@torch.library.impl(meta_lib, "complex.out")
+@register_meta([aten.complex.default, aten.complex.out])
@out_wrapper()
def meta_complex(real, imag):
assert real.dtype.is_floating_point
@@ -390,7 +389,7 @@
return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
-@torch.library.impl(meta_lib, "vdot")
+@register_meta(aten.vdot.default)
def vdot(self, other):
if not self.is_complex:
return torch.dot(self, other)
@@ -539,7 +538,7 @@
return self.new_empty(self.size())
-@torch.library.impl(meta_lib, "_cdist_forward")
+@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
check(
x1.dim() >= 2,
@@ -575,7 +574,7 @@
return x1.new_empty(output_shape)
-@torch.library.impl(meta_lib, "_embedding_bag")
+@register_meta(aten._embedding_bag.default)
def meta_embedding_bag(
weight,
indices,
@@ -684,7 +683,7 @@
return self.new_empty((sz,))
-@torch.library.impl(meta_lib, "_embedding_bag_forward_only")
+@register_meta(aten._embedding_bag_forward_only.default)
def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
output, offset2bag, bag_size, max_indices = meta_embedding_bag(
weight, indices, offsets, *args
@@ -735,7 +734,7 @@
)
-@torch.library.impl(meta_lib, "logical_not_")
+@register_meta(aten.logical_not_.default)
def meta_logical_not_(self):
return self