Add meta registrations for some foreach ops (#102225)
as title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102225
Approved by: https://github.com/ngimel
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index e265439..11ab86d 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -1871,6 +1871,23 @@
return [torch.empty_like(e) for e in exponent]
+@register_meta([aten._foreach_addcdiv_.Tensor, aten._foreach_addcmul_.Tensor])
+def meta__foreach_addcop_tensor(self, tensor1, tensor2, scalars):
+ check(
+ all(isinstance(l, List) for l in [self, tensor1, tensor2])
+ and isinstance(scalars, torch.Tensor),
+ lambda: (
+ "_foreach_addc*_ op expects arguments of type: List[Tensor], List[Tensor], List[Tensor], tensor, "
+ f"but got: {type(self)}, {type(tensor1)}, {type(tensor2)}, and {type(scalars)}"
+ ),
+ )
+ check(len(self) > 0, lambda: "input tensor list must not be empty.")
+ check(
+ len(self) == len(tensor1) and len(self) == len(tensor2),
+ lambda: "All input tensor lists must have the same length",
+ )
+
+
@register_meta([aten._fused_adam_.default])
def meta__fused_adam_(
self,