[ONNX] Export KLDivLoss (#41858)
Summary:
Enable export for KLDivLoss
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41858
Reviewed By: mrshenli
Differential Revision: D22918004
Pulled By: bzinodev
fbshipit-source-id: e3debf77a4cf0eae0df6ed5a72ee91c43e482b62
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 2b49358..b5cf848 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -3670,6 +3670,72 @@
self.run_test(CrossEntropyLossMeanWeightIgnoreIndex(), input=(x, y))
+ @skipIfUnsupportedMinOpsetVersion(9)
+ def test_kldiv_loss(self):
+
+ x = torch.randn(5)
+ y = torch.randn(5)
+ self._kldiv_loss(x, y)
+
+ x = torch.randn(2, 3, 5)
+ y = torch.randn(2, 3, 5)
+ self._kldiv_loss(x, y)
+
+ x = torch.randn(2, 3, 5, 7)
+ y = torch.randn(2, 3, 5, 7)
+ self._kldiv_loss(x, y)
+
+ def _kldiv_loss(self, x, y):
+ class KLDivLossNone(torch.nn.Module):
+ def __init__(self):
+ super(KLDivLossNone, self).__init__()
+ self.loss = torch.nn.KLDivLoss(reduction='none', log_target=True)
+
+ def forward(self, input, target):
+ return self.loss(input, target)
+
+ self.run_test(KLDivLossNone(), input=(x, y))
+
+ class KLDivLossMean(torch.nn.Module):
+ def __init__(self):
+ super(KLDivLossMean, self).__init__()
+ self.loss = torch.nn.KLDivLoss(reduction='mean', log_target=False)
+
+ def forward(self, input, target):
+ return self.loss(input, target)
+
+ self.run_test(KLDivLossMean(), input=(x, y))
+
+ class KLDivLossSum(torch.nn.Module):
+ def __init__(self):
+ super(KLDivLossSum, self).__init__()
+ self.loss = torch.nn.KLDivLoss(reduction='sum', log_target=True)
+
+ def forward(self, input, target):
+ return self.loss(input, target)
+
+ self.run_test(KLDivLossSum(), input=(x, y))
+
+ class KLDivLossBatchMean(torch.nn.Module):
+ def __init__(self):
+ super(KLDivLossBatchMean, self).__init__()
+ self.loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=False)
+
+ def forward(self, input, target):
+ return self.loss(input, target)
+
+ self.run_test(KLDivLossBatchMean(), input=(x, y))
+
+ class KLDivLossMiniBatchMean(torch.nn.Module):
+ def __init__(self):
+ super(KLDivLossMiniBatchMean, self).__init__()
+ self.loss = torch.nn.KLDivLoss(reduction='batchmean', size_average=False, log_target=True)
+
+ def forward(self, input, target):
+ return self.loss(input, target)
+
+ self.run_test(KLDivLossMiniBatchMean(), input=(x, y))
+
@skipIfUnsupportedMinOpsetVersion(12)
def test_nllloss(self):
class NLLModel(torch.nn.Module):
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 0c43d14..86f51c5 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -2530,6 +2530,42 @@
out = reshape_as(g, out, index)
return out
+
+def _kl_div_log_target_impl(g, input, target):
+ diff_ = sub(g, target, input)
+ exp_ = exp(g, target)
+ output = mul(g, exp_, diff_)
+ return output
+
+
+def _kl_div_non_log_target_impl(g, input, target):
+ log_ = log(g, target)
+ diff_ = sub(g, log_, input)
+ output_pos = mul(g, target, diff_)
+ zeros_ = zeros_like(g, output_pos)
+ mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0)))
+ output = where(g, mask_, output_pos, zeros_)
+ return output
+
+
+@parse_args('v', 'v', 'i', 'b')
+def kl_div(g, input, target, reduction, log_target):
+ if log_target:
+ output = _kl_div_log_target_impl(g, input, target)
+ else:
+ output = _kl_div_non_log_target_impl(g, input, target)
+
+ if reduction == 0:
+ return output
+ elif reduction == 1:
+ return g.op("ReduceMean", output, keepdims_i=0)
+ elif reduction == 2:
+ return g.op("ReduceSum", output, keepdims_i=0)
+ else:
+ return sym_help._onnx_unsupported("kl_div with reduction other than none, mean, or sum. Please open a bug to "
+ "request ONNX export support for the missing reduction type.")
+
+
@parse_args('v', 'v', 'is', 'i')
def as_strided(g, self, sizes, strides, offset=None):
sizes = sym_help._maybe_get_const(sizes, 'is')