[ONNX] Export KLDivLoss (#41858)

Summary:
Enable ONNX export of torch.nn.KLDivLoss for opset 9 and higher. The symbolic decomposes the loss into elementwise ONNX ops, covering both the log_target=True and log_target=False cases and the 'none', 'mean', and 'sum' reductions.
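
For context, a minimal sketch of the export this change enables (not part of this patch; the file name, input shapes, and opset choice are illustrative):

```python
import torch

class KLDivModel(torch.nn.Module):
    def __init__(self):
        super(KLDivModel, self).__init__()
        self.loss = torch.nn.KLDivLoss(reduction='mean', log_target=False)

    def forward(self, input, target):
        return self.loss(input, target)

# input is expected in log space; target holds probabilities when log_target=False
x = torch.log_softmax(torch.randn(2, 3, 5), dim=2)
y = torch.softmax(torch.randn(2, 3, 5), dim=2)
torch.onnx.export(KLDivModel(), (x, y), "kl_div.onnx", opset_version=9)
```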

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41858

Reviewed By: mrshenli

Differential Revision: D22918004

Pulled By: bzinodev

fbshipit-source-id: e3debf77a4cf0eae0df6ed5a72ee91c43e482b62
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 2b49358..b5cf848 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -3670,6 +3670,72 @@
 
         self.run_test(CrossEntropyLossMeanWeightIgnoreIndex(), input=(x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_kldiv_loss(self):
+
+        x = torch.randn(5)
+        y = torch.randn(5)
+        self._kldiv_loss(x, y)
+
+        x = torch.randn(2, 3, 5)
+        y = torch.randn(2, 3, 5)
+        self._kldiv_loss(x, y)
+
+        x = torch.randn(2, 3, 5, 7)
+        y = torch.randn(2, 3, 5, 7)
+        self._kldiv_loss(x, y)
+
+    def _kldiv_loss(self, x, y):
+        class KLDivLossNone(torch.nn.Module):
+            def __init__(self):
+                super(KLDivLossNone, self).__init__()
+                self.loss = torch.nn.KLDivLoss(reduction='none', log_target=True)
+
+            def forward(self, input, target):
+                return self.loss(input, target)
+
+        self.run_test(KLDivLossNone(), input=(x, y))
+
+        class KLDivLossMean(torch.nn.Module):
+            def __init__(self):
+                super(KLDivLossMean, self).__init__()
+                self.loss = torch.nn.KLDivLoss(reduction='mean', log_target=False)
+
+            def forward(self, input, target):
+                return self.loss(input, target)
+
+        self.run_test(KLDivLossMean(), input=(x, y))
+
+        class KLDivLossSum(torch.nn.Module):
+            def __init__(self):
+                super(KLDivLossSum, self).__init__()
+                self.loss = torch.nn.KLDivLoss(reduction='sum', log_target=True)
+
+            def forward(self, input, target):
+                return self.loss(input, target)
+
+        self.run_test(KLDivLossSum(), input=(x, y))
+
+        class KLDivLossBatchMean(torch.nn.Module):
+            def __init__(self):
+                super(KLDivLossBatchMean, self).__init__()
+                self.loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=False)
+
+            def forward(self, input, target):
+                return self.loss(input, target)
+
+        self.run_test(KLDivLossBatchMean(), input=(x, y))
+
+        class KLDivLossMiniBatchMean(torch.nn.Module):
+            def __init__(self):
+                super(KLDivLossMiniBatchMean, self).__init__()
+                self.loss = torch.nn.KLDivLoss(reduction='batchmean', size_average=False, log_target=True)  # legacy size_average=False overrides reduction to 'sum'
+
+            def forward(self, input, target):
+                return self.loss(input, target)
+
+        self.run_test(KLDivLossMiniBatchMean(), input=(x, y))
+
     @skipIfUnsupportedMinOpsetVersion(12)
     def test_nllloss(self):
         class NLLModel(torch.nn.Module):
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 0c43d14..86f51c5 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -2530,6 +2530,42 @@
     out = reshape_as(g, out, index)
     return out
 
+
+def _kl_div_log_target_impl(g, input, target):
+    diff_ = sub(g, target, input)  # target - input, both already in log space
+    exp_ = exp(g, target)          # recover probabilities from the log target
+    output = mul(g, exp_, diff_)   # pointwise exp(t) * (t - x)
+    return output
+
+
+def _kl_div_non_log_target_impl(g, input, target):
+    log_ = log(g, target)
+    diff_ = sub(g, log_, input)
+    output_pos = mul(g, target, diff_)  # pointwise t * (log(t) - x)
+    zeros_ = zeros_like(g, output_pos)
+    mask_ = gt(g, target, g.op("Constant", value_t=torch.tensor(0)))
+    output = where(g, mask_, output_pos, zeros_)  # zero out entries where t <= 0, matching ATen's convention
+    return output
+
+
+@parse_args('v', 'v', 'i', 'b')
+def kl_div(g, input, target, reduction, log_target):
+    if log_target:
+        output = _kl_div_log_target_impl(g, input, target)
+    else:
+        output = _kl_div_non_log_target_impl(g, input, target)
+
+    if reduction == 0:    # 'none'
+        return output
+    elif reduction == 1:  # 'mean'
+        return g.op("ReduceMean", output, keepdims_i=0)
+    elif reduction == 2:  # 'sum'
+        return g.op("ReduceSum", output, keepdims_i=0)
+    else:
+        return sym_help._onnx_unsupported("kl_div with reduction other than none, mean, or sum. Please open a bug to "
+                                          "request ONNX export support for the missing reduction type.")
+
+
 @parse_args('v', 'v', 'is', 'i')
 def as_strided(g, self, sizes, strides, offset=None):
     sizes = sym_help._maybe_get_const(sizes, 'is')
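
For reference, a small numeric check (not part of the patch; shapes and names are illustrative) that the decomposition used by _kl_div_log_target_impl and _kl_div_non_log_target_impl above matches PyTorch's pointwise kl_div for both log_target settings:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3)                     # input, expected in log space
t = torch.softmax(torch.randn(2, 3), -1)  # target as probabilities

# log_target=False: t * (log(t) - x), with entries forced to 0 where t <= 0
manual = torch.where(t > 0, t * (t.log() - x), torch.zeros_like(t))
assert torch.allclose(manual, F.kl_div(x, t, reduction='none'))

# log_target=True: exp(t_log) * (t_log - x)
t_log = t.log()
manual_log = t_log.exp() * (t_log - x)
assert torch.allclose(manual_log, F.kl_div(x, t_log, reduction='none', log_target=True))
```

Note that reduction='batchmean' never reaches the symbolic as a distinct enum value: torch.nn.functional.kl_div lowers it to a 'sum' reduction followed by a division by the batch size, so the symbolic only needs to handle 'none' (0), 'mean' (1), and 'sum' (2). That is also why the batchmean test cases above export successfully.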