[primTorch] Add ref for `triplet_margin_loss`, improve `triplet_margin_with_distance_loss` (#85614)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85614
Approved by: https://github.com/lezcano, https://github.com/mruberry
diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp
index 027af18..52569ba 100644
--- a/aten/src/ATen/native/Loss.cpp
+++ b/aten/src/ATen/native/Loss.cpp
@@ -157,15 +157,17 @@
   auto n_dim = negative.dim();
   TORCH_CHECK(
       a_dim == p_dim && p_dim == n_dim,
-      "All inputs should have same dimension but got ",
-      a_dim,
-      "D, ",
-      p_dim,
-      "D and ",
-      n_dim,
-      "D inputs.")
+      "The anchor, positive, and negative tensors are expected to have "
+      "the same number of dimensions, but got: anchor ", a_dim, "D, "
+      "positive ", p_dim, "D, and negative ", n_dim, "D inputs")
+
   auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
   auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
+  // The distance swap is described in the paper "Learning shallow
+  // convolutional feature descriptors with triplet losses" by V. Balntas, E.
+  // Riba et al.  If True, and if the positive example is closer to the
+  // negative example than the anchor is, swaps the positive example and the
+  // anchor in the loss computation.
   if (swap) {
     auto dist_swap = at::pairwise_distance(positive, negative, p, eps);
     dist_neg = at::min(dist_neg, dist_swap);
diff --git a/test/test_nn.py b/test/test_nn.py
index 084d593..681efab 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -9202,21 +9202,6 @@
         self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
                          loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))
 
-    def test_triplet_margin_loss_invalid(self):
-        input1 = torch.randn(5, 10, requires_grad=True)
-        input2 = torch.randn(5, 10, requires_grad=True)
-        input3 = torch.randn(5, 10, requires_grad=True)
-        input_1d = torch.randn(10, requires_grad=True)
-
-        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
-            F.triplet_margin_loss(input1, input2, input_1d)
-
-        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
-            F.triplet_margin_loss(input1, input_1d, input3)
-
-        with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
-            F.triplet_margin_loss(input_1d, input2, input3)
-
     def test_pointwise_loss_target_grad_none_reduction(self):
         i = torch.randn(5, 10)
         t = torch.randn(5, 10, requires_grad=True)
diff --git a/test/test_ops.py b/test/test_ops.py
index 4ecbc59..1f6dbd1 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1623,6 +1623,7 @@
         '_refs.broadcast_shapes',
         '_refs.broadcast_tensors',
         '_refs.nn.functional.tanhshrink',
+        '_refs.nn.functional.triplet_margin_loss',
         '_refs.rfloordiv',
         '_refs.rtruediv',
         '_refs.rpow',
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 9a0d391..b6ddad3 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -114,6 +114,7 @@
     "bitwise_or",
     "bitwise_right_shift",
     "bitwise_xor",
+    "clamp_min",
     # "complex",
     "copysign",
     "div",
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index b3f698f..bd146e9 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 
@@ -46,6 +46,7 @@
     "softshrink",
     "tanhshrink",
     "threshold",
+    "triplet_margin_loss",
     "glu",
     "pairwise_distance",
     "pdist",
@@ -362,7 +363,8 @@
     Reference implementation of torch.nn.functional.l1_loss
     """
     if size_average is not None or reduce is not None:
-        # TODO: raise exception instead of converting value
+        # TODO: Raise exception instead of converting value.  This is only for
+        # primTorch since it can drop support for deprecated arguments.
         # msg = "size_average and reduce args are deprecated, please use reduction argument."
         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
     _check_reduction_value(reduction)
@@ -406,7 +408,8 @@
     reduction: str = "mean",
 ) -> TensorLikeType:
     if size_average is not None or reduce is not None:
-        # TODO: raise exception instead of converting value
+        # TODO: Raise exception instead of converting value.  This is only for
+        # primTorch since it can drop support for deprecated arguments.
         # msg = "size_average and reduce args are deprecated, please use reduction argument."
         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
     _check_reduction_value(reduction)
@@ -501,6 +504,84 @@
     return torch.where(a <= threshold, value, a)
 
 
+# CompositeImplicitAutograd - don't register decomp
+# No elementwise type promotion - core op doesn't explicitly type promote
+def triplet_margin_loss(
+    anchor: TensorLikeType,
+    positive: TensorLikeType,
+    negative: TensorLikeType,
+    margin: float = 1.0,
+    p: float = 2,
+    eps: float = 1e-6,
+    swap: bool = False,
+    size_average: Optional[bool] = None,
+    reduce: Optional[bool] = None,
+    reduction: str = "mean",
+) -> TensorLikeType:
+    if size_average is not None or reduce is not None:
+        # TODO: Raise exception instead of converting value.  This is only for
+        # primTorch since it can drop support for deprecated arguments.
+        # msg = "size_average and reduce args are deprecated, please use reduction argument."
+        reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
+
+    # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined
+    # since it's a pure Python implementation.  Use this helper instead.
+    return _triplet_margin_with_distance_loss(
+        anchor=anchor,
+        positive=positive,
+        negative=negative,
+        distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps),
+        margin=margin,
+        swap=swap,
+        reduction=reduction,
+    )
+
+
+# Pure Python impl - don't register decomp and don't add a ref.  Defined as a
+# helper here since triplet_margin_loss can be nicely implemented with it.
+def _triplet_margin_with_distance_loss(
+    anchor: TensorLikeType,
+    positive: TensorLikeType,
+    negative: TensorLikeType,
+    *,
+    distance_function: Optional[
+        Callable[[TensorLikeType, TensorLikeType], TensorLikeType]
+    ] = None,
+    margin: float = 1.0,
+    swap: bool = False,
+    reduction: str = "mean",
+) -> TensorLikeType:
+    _check_reduction_value(reduction)
+
+    a_dim = anchor.ndim
+    p_dim = positive.ndim
+    n_dim = negative.ndim
+    check(
+        a_dim == p_dim and p_dim == n_dim,
+        lambda: (
+            f"The anchor, positive, and negative tensors are expected to have "
+            f"the same number of dimensions, but got: anchor {a_dim}D, "
+            f"positive {p_dim}D, and negative {n_dim}D inputs"
+        ),
+    )
+
+    if distance_function is None:
+        distance_function = torch.pairwise_distance
+
+    dist_pos = distance_function(anchor, positive)
+    dist_neg = distance_function(anchor, negative)
+    # The distance swap is described in the paper "Learning shallow
+    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
+    # Riba et al.  If True, and if the positive example is closer to the
+    # negative example than the anchor is, swaps the positive example and the
+    # anchor in the loss computation.
+    if swap:
+        dist_swap = distance_function(positive, negative)
+        dist_neg = torch.minimum(dist_neg, dist_swap)
+    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
+    return _apply_loss_reduction(loss, reduction)
+
+
 @register_decomposition(torch.ops.aten.hardtanh)
 @elementwise_unary_scalar_wrapper
 @elementwise_type_promotion_wrapper(
@@ -582,7 +663,8 @@
     Reference implementation of torch.nn.functional.poisson_nll_loss
     """
     if size_average is not None or reduce is not None:
-        # TODO: raise exception instead of converting value
+        # TODO: Raise exception instead of converting value.  This is only for
+        # primTorch since it can drop support for deprecated arguments.
         # msg = "size_average and reduce args are deprecated, please use reduction argument."
         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
     _check_reduction_value(reduction)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index a7de8b2..f069765 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4583,24 +4583,43 @@
             reduction=reduction,
         )
 
-    distance_function = distance_function if distance_function is not None else pairwise_distance
+    # Check validity of reduction mode
+    if reduction not in ("mean", "sum", "none"):
+        raise ValueError(f"{reduction} is not a valid value for reduction")
 
-    positive_dist = distance_function(anchor, positive)
-    negative_dist = distance_function(anchor, negative)
+    # Check dimensions
+    a_dim = anchor.ndim
+    p_dim = positive.ndim
+    n_dim = negative.ndim
+    if not (a_dim == p_dim and p_dim == n_dim):
+        raise RuntimeError(
+            (f"The anchor, positive, and negative tensors are expected to have "
+             f"the same number of dimensions, but got: anchor {a_dim}D, "
+             f"positive {p_dim}D, and negative {n_dim}D inputs"))
 
+    # Calculate loss
+    if distance_function is None:
+        distance_function = torch.pairwise_distance
+
+    dist_pos = distance_function(anchor, positive)
+    dist_neg = distance_function(anchor, negative)
+    # The distance swap is described in the paper "Learning shallow
+    # convolutional feature descriptors with triplet losses" by V. Balntas, E.
+    # Riba et al.  If True, and if the positive example is closer to the
+    # negative example than the anchor is, swaps the positive example and the
+    # anchor in the loss computation.
     if swap:
-        swap_dist = distance_function(positive, negative)
-        negative_dist = torch.min(negative_dist, swap_dist)
+        dist_swap = distance_function(positive, negative)
+        dist_neg = torch.minimum(dist_neg, dist_swap)
+    loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
 
-    output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)
-
-    reduction_enum = _Reduction.get_enum(reduction)
-    if reduction_enum == 1:
-        return output.mean()
-    elif reduction_enum == 2:
-        return output.sum()
-    else:
-        return output
+    # Apply reduction
+    if reduction == "sum":
+        return torch.sum(loss)
+    elif reduction == "mean":
+        return torch.mean(loss)
+    else:  # reduction == "none"
+        return loss
 
 
 def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor:
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4cb8e80..71a9dab 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7386,6 +7386,58 @@
             kwargs["distance_function"] = torch.nn.PairwiseDistance()
         yield SampleInput(input, args=args, kwargs=kwargs)
 
+def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=torch.float32)
+
+    samples = (
+        # input, args, kwargs, error_type, error_regex
+        # invalid reduction
+        (make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
+         dict(reduction="abc"),
+         ValueError, "abc is not a valid value for reduction"),
+
+        # shape mismatch
+        (make_input(3, 5), (make_input(3, 4), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The size of tensor a \(5\) must match the size of tensor b \(4\) "
+          r"at non-singleton dimension 1")),
+        (make_input(3, 4), (make_input(3, 5), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
+          r"at non-singleton dimension 1")),
+        (make_input(3, 4), (make_input(3, 4), make_input(3, 5)),
+         dict(),
+         RuntimeError,
+         (r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
+          r"at non-singleton dimension 1")),
+
+        # different dimensions
+        (make_input(3,), (make_input(3, 4), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 1D, positive 2D, "
+          r"and negative 2D inputs")),
+        (make_input(3, 4), (make_input(3,), make_input(3, 4)),
+         dict(),
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 2D, positive 1D, "
+          r"and negative 2D inputs")),
+        (make_input(3, 4), (make_input(3, 4), make_input(3,)),
+         dict(),
+         RuntimeError,
+         (r"The anchor, positive, and negative tensors are expected to have "
+          r"the same number of dimensions, but got: anchor 2D, positive 2D, "
+          r"and negative 1D inputs")),
+    )
+
+    for input, args, kwargs, error_type, error_regex in samples:
+        yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
+                         error_type=error_type, error_regex=error_regex)
+
 
 def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -12101,6 +12153,7 @@
     OpInfo(
         "nn.functional.triplet_margin_loss",
         sample_inputs_func=sample_inputs_triplet_margin_loss,
+        error_inputs_func=error_inputs_triplet_margin_loss,
         dtypes=all_types_and_complex_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
         supports_out=False,
@@ -12110,6 +12163,7 @@
     OpInfo(
         "nn.functional.triplet_margin_with_distance_loss",
         sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
+        error_inputs_func=error_inputs_triplet_margin_loss,
         dtypes=all_types_and_complex_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
         supports_out=False,
@@ -17572,6 +17626,20 @@
         torch_opinfo_name="clamp",
         supports_nvfuser=False,
     ),
+    PythonRefInfo(
+        "_refs.nn.functional.triplet_margin_loss",
+        torch_opinfo_name="nn.functional.triplet_margin_loss",
+        supports_out=False,
+        # TODO: Uses minimum and clamp, which don't support nvfuser.
+        supports_nvfuser=False,
+        skips=(
+            # AssertionError: Tensor-likes are not close!
+            # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed)
+            # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed)
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
+                         dtypes=(torch.uint8,), device_type="cpu"),
+        )
+    ),
     #
     # Data Conversion & Data Movement Opinfos
     #