Revert "Use register_meta for everything in meta_registrations (#84297)"
This reverts commit 8cd296f6804727899b39198d1641055b64f99056.
Reverted https://github.com/pytorch/pytorch/pull/84297 on behalf of https://github.com/suo because it broke test_proxy_tensor on master
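
For context: the reverted PR had moved every Meta-dispatch-key registration in torch/_meta_registrations.py onto the register_meta decorator, which takes OpOverload objects (e.g. aten.vdot.default) and derives the registration name from the overload; this revert restores the explicit torch.library.impl(meta_lib, ...) style for the affected ops. A minimal sketch of the restored pattern (meta_logical_not_ is taken from the diff below; the surrounding scaffolding is assumed, not verbatim source):

    import torch

    # Library handle targeting the "Meta" dispatch key of the aten namespace,
    # as restored at the top of torch/_meta_registrations.py.
    meta_lib = torch.library.Library("aten", "IMPL", "Meta")

    # A meta kernel computes only output metadata (shape/dtype/device).
    # In-place logical_not_ leaves metadata unchanged, so it returns self.
    @torch.library.impl(meta_lib, "logical_not_")
    def meta_logical_not_(self):
        return self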
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index e42dc0f..d835050 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -891,6 +891,7 @@
 symbolic_tensor_failures = {
     # Needs complex-value support
     xfail('polar'),
+    xfail('complex'),
     xfail('linalg.eig'),
     xfail('__getitem__', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('__rmatmul__', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decomposition
@@ -1045,6 +1046,7 @@
     xfail('linalg.tensorinv', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.tensorsolve', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
+    xfail('linalg.vecdot', ''),  # Could not run 'aten::vdot' with arguments from the 'Meta' backend. This could be ...
     xfail('linalg.vector_norm', ''),  # TensorImpl do not have numel
     xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
     xfail('logaddexp', ''),  # aten.logaddexp.default - couldn't find symbolic meta function/decomposition
@@ -1241,6 +1243,7 @@
     xfail('unfold', ''),  # aten.unfold.default - couldn't find symbolic meta function/decomposition
     xfail('var_mean', ''),  # Unexpected type <class 'torch.SymIntNode'> when computing elementwise type promotion!
     xfail('var', ''),  # Unexpected type <class 'torch.SymIntNode'> when computing elementwise type promotion!
+    xfail('vdot', ''),  # aten.vdot.default - couldn't find symbolic meta function/decomposition
     xfail('view_as_complex', ''),  # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition
     xfail('view_as', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('vsplit', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index e56456c..269c0a1 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -17,7 +17,7 @@
 aten = torch.ops.aten

-_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
+meta_lib = torch.library.Library("aten", "IMPL", "Meta")

 meta_table = {}
@@ -32,7 +32,7 @@
                 if op._overloadname != "default"
                 else op.overloadpacket.__name__
             )
-            _meta_lib_dont_use_me_use_register_meta.impl(name, f)
+            meta_lib.impl(name, f)

         tree_map(add_func, op)
         return f
@@ -195,7 +195,7 @@
     return utils.compute_reduction_output_shape(self.shape, dims)


-@register_meta(aten.bernoulli.out)
+@torch.library.impl(meta_lib, "bernoulli.out")
 def meta_bernoulli(self, *, generator=None, out):
     torch._resize_output_(out, self.size(), self.device)
     return out
@@ -380,7 +380,8 @@
     return repeats.new_empty(output_size)


-@register_meta([aten.complex.default, aten.complex.out])
+@torch.library.impl(meta_lib, "complex")
+@torch.library.impl(meta_lib, "complex.out")
 @out_wrapper()
 def meta_complex(real, imag):
     assert real.dtype.is_floating_point
@@ -389,7 +390,7 @@
     return real.new_empty(out_shape, dtype=corresponding_complex_dtype(real.dtype))


-@register_meta(aten.vdot.default)
+@torch.library.impl(meta_lib, "vdot")
 def vdot(self, other):
     if not self.is_complex:
         return torch.dot(self, other)
@@ -538,7 +539,7 @@
     return self.new_empty(self.size())


-@register_meta(aten._cdist_forward.default)
+@torch.library.impl(meta_lib, "_cdist_forward")
 def meta_cdist_forward(x1, x2, p, compute_mode):
     check(
         x1.dim() >= 2,
@@ -574,7 +575,7 @@
     return x1.new_empty(output_shape)


-@register_meta(aten._embedding_bag.default)
+@torch.library.impl(meta_lib, "_embedding_bag")
 def meta_embedding_bag(
     weight,
     indices,
@@ -683,7 +684,7 @@
     return self.new_empty((sz,))


-@register_meta(aten._embedding_bag_forward_only.default)
+@torch.library.impl(meta_lib, "_embedding_bag_forward_only")
 def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
     output, offset2bag, bag_size, max_indices = meta_embedding_bag(
         weight, indices, offsets, *args
@@ -734,7 +735,7 @@
     )


-@register_meta(aten.logical_not_.default)
+@torch.library.impl(meta_lib, "logical_not_")
 def meta_logical_not_(self):
     return self
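
As a quick, hedged illustration of what these Meta-backend kernels provide (not part of this change): an op called on device="meta" tensors dispatches to the Meta key and propagates shapes and dtypes without allocating storage or doing real compute, which is the machinery test_proxy_tensor exercises during tracing.

    import torch

    # Meta tensors carry only metadata; the Meta kernel for mm infers the
    # (3, 5) output shape without performing an actual matmul.
    x = torch.empty(3, 4, device="meta")
    y = torch.empty(4, 5, device="meta")
    out = torch.mm(x, y)
    print(out.shape, out.device)  # torch.Size([3, 5]) meta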