Use register_meta for everything in meta_registrations (#84297)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84297
Approved by: https://github.com/Chillee
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index a67133b..8044221 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -891,7 +891,6 @@
 symbolic_tensor_failures = {
     # Needs complex-value support
     xfail('polar'),
-    xfail('complex'),
     xfail('linalg.eig'),
     xfail('__getitem__', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('__rmatmul__', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decomposition
@@ -1046,7 +1045,6 @@
     xfail('linalg.tensorinv', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.tensorsolve', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('linalg.vecdot', ''),  # Could not run 'aten::vdot' with arguments from the 'Meta' backend. This could be ...
     xfail('linalg.vector_norm', ''),  # TensorImpl do not have numel
     xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
     xfail('logaddexp', ''),  # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
@@ -1071,7 +1069,6 @@
     xfail('mode', ''),  # aten.mode.default - couldn't find symbolic meta function/decomposition
     xfail('msort', ''),  # aten.sort.default - couldn't find symbolic meta function/decomposition
     xfail('mv', ''),  # aten.mv.default - couldn't find symbolic meta function/decomposition
-    xfail('nanmean', ''),  # The underlying op of 'aten.stride' has no overload name '_schema'
     xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
     xfail('narrow', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('native_layer_norm', ''),  # Unexpected type <class 'torch.SymIntNode'> when computing elementwise type promot...
@@ -1243,7 +1240,6 @@
     xfail('unfold', ''),  # aten.unfold.default - couldn't find symbolic meta function/decomposition
     xfail('var_mean', ''),  # Unexpected type <class 'torch.SymIntNode'> when computing elementwise type promotion!
     xfail('var', ''),  # Unexpected type <class 'torch.SymIntNode'> when computing elementwise type promotion!
-    xfail('vdot', ''),  # aten.vdot.default - couldn't find symbolic meta function/decomposition
     xfail('view_as_complex', ''),  # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition
     xfail('view_as', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('vsplit', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 269c0a1..e56456c 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -17,7 +17,7 @@
 
 aten = torch.ops.aten
 
-meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
 
 meta_table = {}
 
@@ -32,7 +32,7 @@
                     if op._overloadname != "default"
                     else op.overloadpacket.__name__
                 )
-                meta_lib.impl(name, f)
+                _meta_lib_dont_use_me_use_register_meta.impl(name, f)
 
         tree_map(add_func, op)
         return f
@@ -195,7 +195,7 @@
     return utils.compute_reduction_output_shape(self.shape, dims)
 
 
-@torch.library.impl(meta_lib, "bernoulli.out")
+@register_meta(aten.bernoulli.out)
 def meta_bernoulli(self, *, generator=None, out):
     torch._resize_output_(out, self.size(), self.device)
     return out
@@ -380,8 +380,7 @@
     return repeats.new_empty(output_size)
 
 
-@torch.library.impl(meta_lib, "complex")
-@torch.library.impl(meta_lib, "complex.out")
+@register_meta([aten.complex.default, aten.complex.out])
 @out_wrapper()
 def meta_complex(real, imag):
     assert real.dtype.is_floating_point
@@ -390,7 +389,7 @@
     return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))
 
 
-@torch.library.impl(meta_lib, "vdot")
+@register_meta(aten.vdot.default)
 def vdot(self, other):
     if not self.is_complex:
         return torch.dot(self, other)
@@ -539,7 +538,7 @@
     return self.new_empty(self.size())
 
 
-@torch.library.impl(meta_lib, "_cdist_forward")
+@register_meta(aten._cdist_forward.default)
 def meta_cdist_forward(x1, x2, p, compute_mode):
     check(
         x1.dim() >= 2,
@@ -575,7 +574,7 @@
     return x1.new_empty(output_shape)
 
 
-@torch.library.impl(meta_lib, "_embedding_bag")
+@register_meta(aten._embedding_bag.default)
 def meta_embedding_bag(
     weight,
     indices,
@@ -684,7 +683,7 @@
     return self.new_empty((sz,))
 
 
-@torch.library.impl(meta_lib, "_embedding_bag_forward_only")
+@register_meta(aten._embedding_bag_forward_only.default)
 def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
     output, offset2bag, bag_size, max_indices = meta_embedding_bag(
         weight, indices, offsets, *args
@@ -735,7 +734,7 @@
     )
 
 
-@torch.library.impl(meta_lib, "logical_not_")
+@register_meta(aten.logical_not_.default)
 def meta_logical_not_(self):
     return self