[primTorch] Add ref for `triplet_margin_loss`, improve `triplet_margin_with_distance_loss` (#85614)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85614
Approved by: https://github.com/lezcano, https://github.com/mruberry
diff --git a/aten/src/ATen/native/Loss.cpp b/aten/src/ATen/native/Loss.cpp
index 027af18..52569ba 100644
--- a/aten/src/ATen/native/Loss.cpp
+++ b/aten/src/ATen/native/Loss.cpp
@@ -157,15 +157,17 @@
auto n_dim = negative.dim();
TORCH_CHECK(
a_dim == p_dim && p_dim == n_dim,
- "All inputs should have same dimension but got ",
- a_dim,
- "D, ",
- p_dim,
- "D and ",
- n_dim,
- "D inputs.")
+ "The anchor, positive, and negative tensors are expected to have "
+ "the same number of dimensions, but got: anchor ", a_dim, "D, "
+ "positive ", p_dim, "D, and negative ", n_dim, "D inputs")
+
auto dist_pos = at::pairwise_distance(anchor, positive, p, eps);
auto dist_neg = at::pairwise_distance(anchor, negative, p, eps);
+ // The distance swap is described in the paper "Learning shallow
+ // convolutional feature descriptors with triplet losses" by V. Balntas, E.
+ // Riba et al. If True, and if the positive example is closer to the
+ // negative example than the anchor is, swaps the positive example and the
+ // anchor in the loss computation.
if (swap) {
auto dist_swap = at::pairwise_distance(positive, negative, p, eps);
dist_neg = at::min(dist_neg, dist_swap);
diff --git a/test/test_nn.py b/test/test_nn.py
index 084d593..681efab 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -9202,21 +9202,6 @@
self.assertEqual(F.triplet_margin_loss(input1, input2, input3, swap=True, reduction='none'),
loss_reference_fns['TripletMarginLoss'](input1, input2, input3, swap=True, reduction='none'))
- def test_triplet_margin_loss_invalid(self):
- input1 = torch.randn(5, 10, requires_grad=True)
- input2 = torch.randn(5, 10, requires_grad=True)
- input3 = torch.randn(5, 10, requires_grad=True)
- input_1d = torch.randn(10, requires_grad=True)
-
- with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
- F.triplet_margin_loss(input1, input2, input_1d)
-
- with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
- F.triplet_margin_loss(input1, input_1d, input3)
-
- with self.assertRaisesRegex(RuntimeError, "All inputs should have same dimension"):
- F.triplet_margin_loss(input_1d, input2, input3)
-
def test_pointwise_loss_target_grad_none_reduction(self):
i = torch.randn(5, 10)
t = torch.randn(5, 10, requires_grad=True)
diff --git a/test/test_ops.py b/test/test_ops.py
index 4ecbc59..1f6dbd1 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1623,6 +1623,7 @@
'_refs.broadcast_shapes',
'_refs.broadcast_tensors',
'_refs.nn.functional.tanhshrink',
+ '_refs.nn.functional.triplet_margin_loss',
'_refs.rfloordiv',
'_refs.rtruediv',
'_refs.rpow',
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 9a0d391..b6ddad3 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -114,6 +114,7 @@
"bitwise_or",
"bitwise_right_shift",
"bitwise_xor",
+ "clamp_min",
# "complex",
"copysign",
"div",
diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py
index b3f698f..bd146e9 100644
--- a/torch/_refs/nn/functional/__init__.py
+++ b/torch/_refs/nn/functional/__init__.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Callable, Optional, Union
import torch
@@ -46,6 +46,7 @@
"softshrink",
"tanhshrink",
"threshold",
+ "triplet_margin_loss",
"glu",
"pairwise_distance",
"pdist",
@@ -362,7 +363,8 @@
Reference implementation of torch.nn.functional.l1_loss
"""
if size_average is not None or reduce is not None:
- # TODO: raise exception instead of converting value
+ # TODO: Raise exception instead of converting value. This is only for
+ # primTorch since it can drop support for deprecated arguments.
# msg = "size_average and reduce args are deprecated, please use reduction argument."
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
_check_reduction_value(reduction)
@@ -406,7 +408,8 @@
reduction: str = "mean",
) -> TensorLikeType:
if size_average is not None or reduce is not None:
- # TODO: raise exception instead of converting value
+ # TODO: Raise exception instead of converting value. This is only for
+ # primTorch since it can drop support for deprecated arguments.
# msg = "size_average and reduce args are deprecated, please use reduction argument."
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
_check_reduction_value(reduction)
@@ -501,6 +504,84 @@
return torch.where(a <= threshold, value, a)
+# CompositeImplicitAutograd - don't register decomp
+# No elementwise type promotion - core op doesn't explicitly type promote
+def triplet_margin_loss(
+ anchor: TensorLikeType,
+ positive: TensorLikeType,
+ negative: TensorLikeType,
+ margin: float = 1.0,
+ p: float = 2,
+ eps: float = 1e-6,
+ swap: bool = False,
+ size_average: Optional[bool] = None,
+ reduce: Optional[bool] = None,
+ reduction: str = "mean",
+) -> TensorLikeType:
+ if size_average is not None or reduce is not None:
+ # TODO: Raise exception instead of converting value. This is only for
+ # primTorch since it can drop support for deprecated arguments.
+ # msg = "size_average and reduce args are deprecated, please use reduction argument."
+ reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
+
+ # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined
+ # since it's a pure Python implementation. Use this helper instead.
+ return _triplet_margin_with_distance_loss(
+ anchor=anchor,
+ positive=positive,
+ negative=negative,
+ distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps),
+ margin=margin,
+ swap=swap,
+ reduction=reduction,
+ )
+
+
+# Pure Python impl - don't register decomp and don't add a ref. Defined as a
+# helper here since triplet_margin_loss can be nicely implemented with it.
+def _triplet_margin_with_distance_loss(
+ anchor: TensorLikeType,
+ positive: TensorLikeType,
+ negative: TensorLikeType,
+ *,
+ distance_function: Optional[
+ Callable[[TensorLikeType, TensorLikeType], TensorLikeType]
+ ] = None,
+ margin: float = 1.0,
+ swap: bool = False,
+ reduction: str = "mean",
+) -> TensorLikeType:
+ _check_reduction_value(reduction)
+
+ a_dim = anchor.ndim
+ p_dim = positive.ndim
+ n_dim = negative.ndim
+ check(
+ a_dim == p_dim and p_dim == n_dim,
+ lambda: (
+ f"The anchor, positive, and negative tensors are expected to have "
+ f"the same number of dimensions, but got: anchor {a_dim}D, "
+ f"positive {p_dim}D, and negative {n_dim}D inputs"
+ ),
+ )
+
+ if distance_function is None:
+ distance_function = torch.pairwise_distance
+
+ dist_pos = distance_function(anchor, positive)
+ dist_neg = distance_function(anchor, negative)
+ # The distance swap is described in the paper "Learning shallow
+ # convolutional feature descriptors with triplet losses" by V. Balntas, E.
+ # Riba et al. If True, and if the positive example is closer to the
+ # negative example than the anchor is, swaps the positive example and the
+ # anchor in the loss computation.
+ if swap:
+ dist_swap = distance_function(positive, negative)
+ dist_neg = torch.minimum(dist_neg, dist_swap)
+ loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
+ return _apply_loss_reduction(loss, reduction)
+
+
@register_decomposition(torch.ops.aten.hardtanh)
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
@@ -582,7 +663,8 @@
Reference implementation of torch.nn.functional.poisson_nll_loss
"""
if size_average is not None or reduce is not None:
- # TODO: raise exception instead of converting value
+ # TODO: Raise exception instead of converting value. This is only for
+ # primTorch since it can drop support for deprecated arguments.
# msg = "size_average and reduce args are deprecated, please use reduction argument."
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
_check_reduction_value(reduction)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index a7de8b2..f069765 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4583,24 +4583,43 @@
reduction=reduction,
)
- distance_function = distance_function if distance_function is not None else pairwise_distance
+ # Check validity of reduction mode
+ if reduction not in ("mean", "sum", "none"):
+ raise ValueError(f"{reduction} is not a valid value for reduction")
- positive_dist = distance_function(anchor, positive)
- negative_dist = distance_function(anchor, negative)
+ # Check dimensions
+ a_dim = anchor.ndim
+ p_dim = positive.ndim
+ n_dim = negative.ndim
+ if not (a_dim == p_dim and p_dim == n_dim):
+ raise RuntimeError(
+ (f"The anchor, positive, and negative tensors are expected to have "
+ f"the same number of dimensions, but got: anchor {a_dim}D, "
+ f"positive {p_dim}D, and negative {n_dim}D inputs"))
+ # Calculate loss
+ if distance_function is None:
+ distance_function = torch.pairwise_distance
+
+ dist_pos = distance_function(anchor, positive)
+ dist_neg = distance_function(anchor, negative)
+ # The distance swap is described in the paper "Learning shallow
+ # convolutional feature descriptors with triplet losses" by V. Balntas, E.
+ # Riba et al. If True, and if the positive example is closer to the
+ # negative example than the anchor is, swaps the positive example and the
+ # anchor in the loss computation.
if swap:
- swap_dist = distance_function(positive, negative)
- negative_dist = torch.min(negative_dist, swap_dist)
+ dist_swap = distance_function(positive, negative)
+ dist_neg = torch.minimum(dist_neg, dist_swap)
+ loss = torch.clamp_min(margin + dist_pos - dist_neg, 0)
- output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)
-
- reduction_enum = _Reduction.get_enum(reduction)
- if reduction_enum == 1:
- return output.mean()
- elif reduction_enum == 2:
- return output.sum()
- else:
- return output
+ # Apply reduction
+ if reduction == "sum":
+ return torch.sum(loss)
+ elif reduction == "mean":
+ return torch.mean(loss)
+ else: # reduction == "none"
+ return loss
def normalize(input: Tensor, p: float = 2.0, dim: int = 1, eps: float = 1e-12, out: Optional[Tensor] = None) -> Tensor:
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4cb8e80..71a9dab 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7386,6 +7386,58 @@
kwargs["distance_function"] = torch.nn.PairwiseDistance()
yield SampleInput(input, args=args, kwargs=kwargs)
+def error_inputs_triplet_margin_loss(op_info, device, **kwargs):
+ make_input = partial(make_tensor, device=device, dtype=torch.float32)
+
+ samples = (
+ # input, args, kwargs, error_type, error_regex
+ # invalid reduction
+ (make_input(3, 4), (make_input(3, 4), make_input(3, 4)),
+ dict(reduction="abc"),
+ ValueError, "abc is not a valid value for reduction"),
+
+ # shape mismatch
+ (make_input(3, 5), (make_input(3, 4), make_input(3, 4)),
+ dict(),
+ RuntimeError,
+ (r"The size of tensor a \(5\) must match the size of tensor b \(4\) "
+ r"at non-singleton dimension 1")),
+ (make_input(3, 4), (make_input(3, 5), make_input(3, 4)),
+ dict(),
+ RuntimeError,
+ (r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
+ r"at non-singleton dimension 1")),
+ (make_input(3, 4), (make_input(3, 4), make_input(3, 5)),
+ dict(),
+ RuntimeError,
+ (r"The size of tensor a \(4\) must match the size of tensor b \(5\) "
+ r"at non-singleton dimension 1")),
+
+ # different dimensions
+ (make_input(3,), (make_input(3, 4), make_input(3, 4)),
+ dict(),
+ RuntimeError,
+ (r"The anchor, positive, and negative tensors are expected to have "
+ r"the same number of dimensions, but got: anchor 1D, positive 2D, "
+ r"and negative 2D inputs")),
+ (make_input(3, 4), (make_input(3,), make_input(3, 4)),
+ dict(),
+ RuntimeError,
+ (r"The anchor, positive, and negative tensors are expected to have "
+ r"the same number of dimensions, but got: anchor 2D, positive 1D, "
+ r"and negative 2D inputs")),
+ (make_input(3, 4), (make_input(3, 4), make_input(3,)),
+ dict(),
+ RuntimeError,
+ (r"The anchor, positive, and negative tensors are expected to have "
+ r"the same number of dimensions, but got: anchor 2D, positive 2D, "
+ r"and negative 1D inputs")),
+ )
+
+ for input, args, kwargs, error_type, error_regex in samples:
+ yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
+ error_type=error_type, error_regex=error_regex)
+
def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -12101,6 +12153,7 @@
OpInfo(
"nn.functional.triplet_margin_loss",
sample_inputs_func=sample_inputs_triplet_margin_loss,
+ error_inputs_func=error_inputs_triplet_margin_loss,
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
supports_out=False,
@@ -12110,6 +12163,7 @@
OpInfo(
"nn.functional.triplet_margin_with_distance_loss",
sample_inputs_func=partial(sample_inputs_triplet_margin_loss, with_distance=True),
+ error_inputs_func=error_inputs_triplet_margin_loss,
dtypes=all_types_and_complex_and(torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
supports_out=False,
@@ -17572,6 +17626,20 @@
torch_opinfo_name="clamp",
supports_nvfuser=False,
),
+ PythonRefInfo(
+ "_refs.nn.functional.triplet_margin_loss",
+ torch_opinfo_name="nn.functional.triplet_margin_loss",
+ supports_out=False,
+ # TODO: Uses minimum and clamp, which don't support nvfuser.
+ supports_nvfuser=False,
+ skips=(
+ # AssertionError: Tensor-likes are not close!
+ # Greatest absolute difference: 6.103515625e-05 at index (4,) (up to 1e-05 allowed)
+ # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed)
+ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref',
+ dtypes=(torch.uint8,), device_type="cpu"),
+ )
+ ),
#
# Data Conversion & Data Movement Opinfos
#