Adds multilabel_soft_margin_loss opinfo
Per title
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75883
Approved by: https://github.com/ngimel
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 6f1ec78..e5c43c9 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2658,6 +2658,21 @@
return tuple(sample_inputs)
+# TODO: add reduction kwargs
+def sample_inputs_multilabel_soft_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
+ _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+ shapes = (
+ (S,),
+ (S, S),
+ )
+
+ for shape in shapes:
+ # Produce one with weight and one without.
+ yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),), kwargs={})
+ yield SampleInput(_make_tensor(shape), args=(_make_tensor(shape, requires_grad=False),),
+ kwargs={'weight': _make_tensor(shape, requires_grad=False)})
+
def sample_inputs_addr(op_info, device, dtype, requires_grad, **kwargs):
input1 = SampleInput(
make_tensor((S, M), dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad),
@@ -11809,6 +11824,36 @@
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
autodiff_nonfusible_nodes=["aten::leaky_relu"]),
+ OpInfo(
+ "nn.functional.multilabel_soft_margin_loss",
+ ref=_NOTHING,
+ supports_out=False,
+ dtypes=floating_types_and(torch.bfloat16),
+ dtypesIfCUDA=floating_types_and(torch.float16),
+ sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
+ decorators=(
+ DecorateInfo(
+ toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+ "TestJit",
+ "test_variant_consistency_jit",
+ ),
+ ),
+ skips=(
+ # target doesn't require grad
+ DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_floating_inputs_are_differentiable'),
+ # AssertionError: False is not true : Scalars failed to compare as equal! 0 != 4096
+ # __main__.TestJitCUDA.test_variant_consistency_jit_nn_functional_multilabel_soft_margin_loss_cuda_float32
+ # leaked 4096 bytes CUDA memory on device 0
+ DecorateInfo(
+ # Skip instead of expectedFailure because this fails
+ # locally for me but passes in CI.
+ unittest.skip("Skipped!"),
+ "TestJit",
+ "test_variant_consistency_jit",
+ device_type="cuda",
+ ),
+ ),
+ ),
OpInfo('nn.functional.avg_pool2d',
aten_name='avg_pool2d',
supports_autograd=True,