Revert "Add binary_cross_entropy and trace decomp - fixed _log_softmax/_softmax dtype promotion semantics"
This reverts commit 8a3e9255ea675a898c22ee479c16c60dbc87656b.
Reverted https://github.com/pytorch/pytorch/pull/76670 on behalf of https://github.com/mruberry
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index e530aff..56ce840 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -60,15 +60,6 @@
     return x
-def apply_loss_reduction(loss: Tensor, reduction: int):
-    if reduction == Reduction.MEAN.value:
-        return torch.mean(loss)
-    elif reduction == Reduction.SUM.value:
-        return torch.sum(loss)
-    else:
-        return loss
-
-
 @register_decomposition(aten.tanh_backward)
 @cast_for_opmath
 def tanh_backward(out_grad: Tensor, y: Tensor):
@@ -283,11 +274,6 @@
     )
-@register_decomposition(aten.trace)
-def trace(x):
-    return torch.sum(torch.diagonal(x))
-
-
 @register_decomposition(aten.prelu_backward)
 @cast_for_opmath
 def prelu_backward(
@@ -340,6 +326,15 @@
     # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
+def apply_loss_reduction(loss: Tensor, reduction: int):
+    if reduction == Reduction.MEAN.value:
+        return torch.mean(loss)
+    elif reduction == Reduction.SUM.value:
+        return torch.sum(loss)
+    else:
+        return loss
+
+
 def to_real_dtype(dtype: torch.dtype):
     if dtype == torch.complex32:
         return torch.float16
@@ -483,26 +478,6 @@
     return grad_input * grad_output
-@register_decomposition(aten.binary_cross_entropy)
-def binary_cross_entropy(
-    self: Tensor,
-    target: Tensor,
-    weight: Optional[Tensor] = None,
-    reduction: int = Reduction.MEAN.value,
-) -> Tensor:
-    # We cannot currently model this without introducing data-dependent control flow
-    # TORCH_CHECK(
-    #     (input_val >= 0) && (input_val <= 1),
-    #     "all elements of input should be between 0 and 1"
-    # )
-    loss = (target - 1) * torch.maximum(
-        torch.log(1 - self), self.new_full((), -100)
-    ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
-    if weight is not None:
-        loss = loss * weight
-    return apply_loss_reduction(loss, reduction)
-
-
 @register_decomposition(aten.binary_cross_entropy_backward)
 @cast_for_opmath
 def binary_cross_entropy_backward(
@@ -566,8 +541,7 @@
     grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
 ):
     new_grad = grad_output * output
-    grad_input = new_grad - output * torch.sum(new_grad, dim=dim, keepdim=True)
-    return aten.to(grad_input, dtype=input_dtype)
+    return new_grad - output * torch.sum(new_grad, dim=dim, keepdim=True)
 @register_decomposition(aten._log_softmax_backward_data)
@@ -578,7 +552,7 @@
     grad_input = grad_output - torch.exp(output) * torch.sum(
         grad_output, dim=dim, keepdim=True
     )
-    return aten.to(grad_input, dtype=input_dtype)
+    return grad_input
 # TODO: the type annotations on arguments are not quite right
@@ -1102,12 +1076,7 @@
     # Cudnn return running mean and variance when training is True
     if training:
         return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
-    return (
-        a,
-        input.new_zeros((0,)),
-        input.new_zeros((0,)),
-        input.new_zeros((0,), dtype=torch.uint8),
-    )
+    return (a, input.new_zeros((0,)), input.new_zeros((0,)), input.new_zeros((0,), dtype=torch.uint8))
 @register_decomposition(aten.cudnn_batch_norm_backward)