Implement KLDivLoss double backwards.
diff --git a/test/common_nn.py b/test/common_nn.py
index 8729867..4e89d2b 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -280,7 +280,13 @@
module_name='KLDivLoss',
input=torch.rand(10, 10).log(),
target=torch.rand(10, 10),
- check_gradgrad=False,
+ ),
+ dict(
+ module_name='KLDivLoss',
+ constructor_args=(False,),
+ input=torch.rand(10, 10).log(),
+ target=torch.rand(10, 10),
+ desc='no_size_average',
),
dict(
module_name='MSELoss',
diff --git a/torch/nn/_functions/thnn/auto_double_backwards.py b/torch/nn/_functions/thnn/auto_double_backwards.py
index 5bee9cc..744c244 100644
--- a/torch/nn/_functions/thnn/auto_double_backwards.py
+++ b/torch/nn/_functions/thnn/auto_double_backwards.py
@@ -165,13 +165,13 @@
return gI, ggO, None, None, None
-def mseloss_double_backwards(ctx, ggI):
+def klddivloss_double_backwards(ctx, ggI):
size_average = ctx.additional_args[0]
input, target, gO = ctx.saved_variables
div_factor = input.nelement() if size_average else 1
- gI = ggI * (gO * 2. / div_factor).expand_as(input)
- ggO = (ggI * (input - target)).sum() * (2. / div_factor)
+ gI = None
+ ggO = (ggI * target).sum() / -div_factor
return gI, None, ggO, None, None
@@ -189,6 +189,17 @@
return gI, None, ggO, None, None
+def mseloss_double_backwards(ctx, ggI):
+ size_average = ctx.additional_args[0]
+ input, target, gO = ctx.saved_variables
+ div_factor = input.nelement() if size_average else 1
+
+ gI = ggI * (gO * 2. / div_factor).expand_as(input)
+ ggO = (ggI * (input - target)).sum() * (2. / div_factor)
+
+ return gI, None, ggO, None, None
+
+
def nllloss_double_backwards(ctx, ggI):
t = ctx.saved_variables
target = t[1]
@@ -237,7 +248,7 @@
large_error_neg_mask = (((input_sub_target <= 0) + large_error_mask) == 2).type_as(ggI)
small_error_mask = small_error_mask.type_as(ggI)
- gI = 1. / div_factor * small_error_mask * ggI * gO
+ gI = small_error_mask * ggI * gO / div_factor
ggO = (ggI * (input_sub_target * small_error_mask + large_error_pos_mask - large_error_neg_mask)).sum() / div_factor
return gI, None, ggO, None, None, None
@@ -254,6 +265,7 @@
'Softplus': softplus_double_backwards,
'Softshrink': softshrink_double_backwards,
'Threshold': threshold_double_backwards,
+ 'KLDivLoss': klddivloss_double_backwards,
'L1Loss': l1loss_double_backwards,
'MSELoss': mseloss_double_backwards,
'NLLLoss': nllloss_double_backwards,