Fix L2 normalization when handling a zero vector (#9594)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9594

When the input is a zero vector, the previous GPU code produced NaN in the backward pass. We fix this by clamping the L2 norm to a small epsilon (1e-12) in both the CPU and CUDA kernels, and update the Python reference in the test accordingly.

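For reference, a minimal NumPy sketch (not part of the patch) of the backward formula the kernels compute; it shows how a zero input turns into NaN without the clamp and stays finite with the 1e-12 epsilon this diff introduces. The function and variable names below are illustrative only.

```python
import numpy as np

def l2_normalize_grad(x, g_out, eps=1e-12):
    # Gradient of y = x / ||x||_2 w.r.t. x, given upstream gradient g_out.
    # Mirrors the kernel's formula: gIn = gOut / norm - x * dot(x, gOut) / norm^3.
    norm = np.sqrt((x ** 2).sum())
    norm = max(norm, eps)  # the clamp added by this diff
    row_sum = np.dot(x, g_out)
    return g_out / norm - x * row_sum / norm ** 3

x = np.zeros(4, dtype=np.float32)
g_out = np.ones(4, dtype=np.float32)
print(l2_normalize_grad(x, g_out))           # large but finite, no NaN
print(l2_normalize_grad(x, g_out, eps=0.0))  # division by zero -> NaN (previous GPU behavior)
```
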
Reviewed By: pjh5

Differential Revision: D8849732

fbshipit-source-id: 87b1fb1ee05dfdb0d43bcbe67e36f15896fe1706
diff --git a/caffe2/operators/normalize_op.cc b/caffe2/operators/normalize_op.cc
index 1a7d720..73a8820 100644
--- a/caffe2/operators/normalize_op.cc
+++ b/caffe2/operators/normalize_op.cc
@@ -12,6 +12,7 @@
     const int m,
     const int n,
     const int sf) {
+  const T kEps = 1e-12f;
   using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
   using StridedVec =
       Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
@@ -22,10 +23,9 @@
     auto base = (i / sf) * sf * m + (i % sf);
     ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf));
     auto norm = xVec.template lpNorm<2>();
-    if (norm != 0) {
-      StridedVec yVec(yData + base, 1, m, InnerStride(sf));
-      yVec = xVec / norm;
-    }
+    norm = std::max(norm, kEps);
+    StridedVec yVec(yData + base, 1, m, InnerStride(sf));
+    yVec = xVec / norm;
   }
 };
 
@@ -37,6 +37,7 @@
     const int m,
     const int n,
     const int sf) {
+  const T kEps = 1e-12f;
   using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
   using StridedVec =
       Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
@@ -50,11 +51,10 @@
 
     auto row_sum = xVec.dot(gOutVec);
     auto row_norm = xVec.template lpNorm<2>();
+    row_norm = std::max(row_norm, kEps);
     auto row_norm_3 = pow(row_norm, 3);
-    if (row_norm != 0) {
-      StridedVec gInVec(gInData + base, 1, m, InnerStride(sf));
-      gInVec = (gOutVec / row_norm) - ((xVec / row_norm_3) * row_sum);
-    }
+    StridedVec gInVec(gInData + base, 1, m, InnerStride(sf));
+    gInVec = (gOutVec / row_norm) - ((xVec / row_norm_3) * row_sum);
   }
 };
 
diff --git a/caffe2/operators/normalize_ops.cu b/caffe2/operators/normalize_ops.cu
index e3dc7b2..dcffe02 100644
--- a/caffe2/operators/normalize_ops.cu
+++ b/caffe2/operators/normalize_ops.cu
@@ -12,6 +12,7 @@
     const int sf,
     const float* xData,
     float* yData) {
+  const float kEps = 1e-12f;
   typedef cub::BlockReduce<float, CAFFE_CUDA_NUM_THREADS> BlockReduce;
   __shared__ BlockReduce::TempStorage temp_storage;
 
@@ -28,13 +29,12 @@
 
     if (threadIdx.x == 0) {
       norm = sqrtf(reduce_result);
+      norm = fmaxf(norm, kEps);
     }
     __syncthreads();
-    if (norm != 0) {
-      for (int j = threadIdx.x; j < m; j += blockDim.x) {
-        const auto index = base + j * sf;
-        yData[index] = xData[index] / norm;
-      }
+    for (int j = threadIdx.x; j < m; j += blockDim.x) {
+      const auto index = base + j * sf;
+      yData[index] = xData[index] / norm;
     }
   }
 }
@@ -46,6 +46,7 @@
     const float* in_mat,
     const float* grad_out_mat,
     float* grad_mat) {
+  const float kEps = 1e-12f;
   typedef cub::BlockReduce<float, CAFFE_CUDA_NUM_THREADS> BlockReduce;
   __shared__ BlockReduce::TempStorage temp_storage_sum;
   __shared__ BlockReduce::TempStorage temp_storage_norm;
@@ -67,6 +68,7 @@
     if (threadIdx.x == 0) {
       row_sum = reduce_result;
       row_norm = sqrtf(reduce_norm);
+      row_norm = fmaxf(row_norm, kEps);
       row_norm_3 = powf(row_norm, 3);
     }
     __syncthreads();
diff --git a/caffe2/python/operator_test/normalize_op_test.py b/caffe2/python/operator_test/normalize_op_test.py
index 965bbe7..933d78f 100644
--- a/caffe2/python/operator_test/normalize_op_test.py
+++ b/caffe2/python/operator_test/normalize_op_test.py
@@ -9,45 +9,45 @@
 import hypothesis.strategies as st
 from caffe2.python import core
 import caffe2.python.hypothesis_test_util as hu
+import copy
 
 
 class TestNormalizeOp(hu.HypothesisTestCase):
-
-    @given(X=hu.tensor(min_dim=1,
-                       max_dim=5,
-                       elements=st.floats(min_value=0.5, max_value=1.0)),
-           **hu.gcs)
+    @given(
+        X=hu.tensor(
+            min_dim=1, max_dim=5, elements=st.floats(min_value=0.5, max_value=1.0)
+        ),
+        **hu.gcs
+    )
     def test_normalize(self, X, gc, dc):
         def ref_normalize(X, axis):
-            x_normed = X / (
-                np.sqrt((X**2).sum(axis=axis, keepdims=True)) + np.finfo(X.dtype).tiny)
+            x_normed = X / np.maximum(
+                np.sqrt((X ** 2).sum(axis=axis, keepdims=True)), 1e-12
+            )
             return (x_normed,)
 
         for axis in range(-X.ndim, X.ndim):
+            x = copy.copy(X)
             op = core.CreateOperator("Normalize", "X", "Y", axis=axis)
             self.assertReferenceChecks(
-                gc,
-                op,
-                [X],
-                functools.partial(ref_normalize, axis=axis))
-            self.assertDeviceChecks(dc, op, [X], [0])
-            self.assertGradientChecks(gc, op, [X], 0, [0])
+                gc, op, [x], functools.partial(ref_normalize, axis=axis)
+            )
+            self.assertDeviceChecks(dc, op, [x], [0])
+            self.assertGradientChecks(gc, op, [x], 0, [0])
 
-    @given(X=hu.tensor(min_dim=1,
-                       max_dim=5,
-                       elements=st.floats(min_value=0.5, max_value=1.0)),
-           **hu.gcs)
+    @given(
+        X=hu.tensor(
+            min_dim=1, max_dim=5, elements=st.floats(min_value=0.5, max_value=1.0)
+        ),
+        **hu.gcs
+    )
     def test_normalize_L1(self, X, gc, dc):
         def ref(X, axis):
             norm = abs(X).sum(axis=axis, keepdims=True)
             return (X / norm,)
 
         for axis in range(-X.ndim, X.ndim):
-            print('axis: ', axis)
+            print("axis: ", axis)
             op = core.CreateOperator("NormalizeL1", "X", "Y", axis=axis)
-            self.assertReferenceChecks(
-                gc,
-                op,
-                [X],
-                functools.partial(ref, axis=axis))
+            self.assertReferenceChecks(gc, op, [X], functools.partial(ref, axis=axis))
             self.assertDeviceChecks(dc, op, [X], [0])