make vector_norm backward call norm_backward (#59135)
Summary:
Per title. Remove duplicated code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59135
Reviewed By: mruberry
Differential Revision: D28775716
Pulled By: ngimel
fbshipit-source-id: 50dc77590db15976453fc41c3657a77198749849
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 110f81e..c2db17f 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -198,7 +198,7 @@
scale_v = grad / norm;
} else if (std::isinf(p)) {
Tensor is_eq_max = (self.abs() == norm).logical_or_(self.isnan().logical_and_(norm.isnan())).type_as(self);
- self_scaled = self.sign() * is_eq_max;
+ self_scaled = self.sgn() * is_eq_max;
Tensor nb_max = is_eq_max.count_nonzero(dim);
if (self.dim() != 0) {
nb_max = unsqueeze_multiple(nb_max, dim, ndim);
@@ -217,50 +217,8 @@
}
Tensor linalg_vector_norm_backward(Tensor grad, const Tensor& self, const Scalar& scalar_ord, Tensor norm, const optional<IntArrayRef>& opt_dim, bool keepdim) {
- size_t ndim = self.sizes().size();
- auto ord = scalar_ord.toDouble();
auto dim = opt_dim.value_or(IntArrayRef({}));
- Tensor self_scaled;
- Tensor scale_v;
-
- if (!keepdim && self.dim() != 0) {
- grad = unsqueeze_multiple(grad, dim, ndim);
- norm = unsqueeze_multiple(norm, dim, ndim);
- }
-
- if (ord == 0.0) {
- return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- } else if (ord == 1.0) {
- return self.sgn() * grad;
- } else if (ord == 2.0) {
- self_scaled = self;
- scale_v = grad / norm;
- } else if (std::isinf(ord)) {
- // Find the elements from `self` that equal the norm result
- Tensor is_equal_to_norm;
-
- is_equal_to_norm = (self.abs() == norm);
-
- // Need to explicitly check for nan in the input and output since `nan ==
- // nan` is false
- is_equal_to_norm = is_equal_to_norm.logical_or_(self.isnan().logical_and_(norm.isnan())).type_as(self);
-
- self_scaled = self.sgn() * is_equal_to_norm;
- Tensor nb_max = is_equal_to_norm.count_nonzero(dim);
- if (self.dim() != 0) {
- nb_max = unsqueeze_multiple(nb_max, dim, ndim);
- }
- scale_v = grad / nb_max;
- } else if (ord < 2.0) {
- self_scaled = self.sgn() * self.abs().pow(ord - 1);
- scale_v = grad / norm.pow(ord - 1);
- } else {
- self_scaled = self * self.abs().pow(ord - 2);
- scale_v = grad / norm.pow(ord - 1);
- }
- // handle case at 0 where we return a subgradient containing 0
- scale_v.masked_fill_(norm == 0, 0);
- return self_scaled * scale_v;
+ return norm_backward(grad, self, scalar_ord, norm, dim, keepdim);
}
Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) {