Adds opinfo for pdist
Per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75878
Approved by: https://github.com/ngimel
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 825c4f8..e034a00 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8118,6 +8118,22 @@
)
return sample_inputs
+def sample_inputs_pdist(op_info, device, dtype, requires_grad, **kwargs):
+ make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+ yield from (SampleInput(make_input((n, m))) for n, m in itertools.product((1, S), repeat=2))
+ yield from (SampleInput(make_input((S, S)), kwargs=dict(p=p)) for p in (0.0, 1.0, 2.0, 10.0, float("inf")))
+
+def reference_pdist(input, p=2):
+ pdist = scipy.spatial.distance.pdist
+ if p == 0:
+ output = pdist(input, "hamming") * input.shape[1]
+ elif p == float("inf"):
+ output = pdist(input, lambda x, y: np.abs(x - y).max())
+ else:
+ output = pdist(input, "minkowski", p=p)
+ return output.astype(input.dtype)
+
def sample_inputs_diagflat(op_info, device, dtype, requires_grad, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -16752,6 +16768,13 @@
)
),
OpInfo(
+ "nn.functional.pdist",
+ ref=reference_pdist,
+ sample_inputs_func=sample_inputs_pdist,
+ dtypes=floating_types(),
+ supports_out=False,
+ supports_gradgrad=False),
+ OpInfo(
"nn.functional.poisson_nll_loss",
ref=_NOTHING,
dtypes=all_types_and(torch.bfloat16),