[tfdbg] tf.debugging.enable_check_numerics() uses CheckNumericsV2 op

- Add the CPU and GPU implementations of CheckNumericsV2Op
- CheckNumericsV2Op inherits from CheckNumericsOp, but adds the ability to
  distinguish between -Inf and +Inf.
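
A minimal usage sketch (eager mode; the exact error wording below is
illustrative):

    import tensorflow as tf

    tf.debugging.enable_check_numerics()  # now inserts CheckNumericsV2 ops

    x = tf.constant([1.0, -1.0])
    y = tf.constant([0.0, 0.0])
    # x / y yields [+Inf, -Inf]; the instrumented division raises
    # InvalidArgumentError mentioning "-Inf and +Inf values".
    z = x / y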

PiperOrigin-RevId: 281390258
Change-Id: Id45ef975da95104ef9b482bbb7034790b314262e
diff --git a/tensorflow/core/api_def/base_api/api_def_CheckNumericsV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CheckNumericsV2.pbtxt
new file mode 100644
index 0000000..2aa0d64
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CheckNumericsV2.pbtxt
@@ -0,0 +1,17 @@
+op {
+  graph_op_name: "CheckNumericsV2"
+  visibility: HIDDEN
+  attr {
+    name: "message"
+    description: <<END
+Prefix of the error message.
+END
+  }
+  summary: "Checks a tensor for NaN, -Inf and +Inf values."
+  description: <<END
+When run, reports an `InvalidArgument` error if `tensor` has any values
+that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
+Unlike CheckNumerics (V1), CheckNumericsV2 distinguishes -Inf and +Inf in the
+errors it throws.
+END
+}
diff --git a/tensorflow/core/api_def/java_api/api_def_CheckNumericsV2.pbtxt b/tensorflow/core/api_def/java_api/api_def_CheckNumericsV2.pbtxt
new file mode 100644
index 0000000..640d292
--- /dev/null
+++ b/tensorflow/core/api_def/java_api/api_def_CheckNumericsV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "CheckNumericsV2"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CheckNumericsV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_CheckNumericsV2.pbtxt
new file mode 100644
index 0000000..640d292
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CheckNumericsV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "CheckNumericsV2"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc
index 1baeefb..63c52a1 100644
--- a/tensorflow/core/kernels/check_numerics_op.cc
+++ b/tensorflow/core/kernels/check_numerics_op.cc
@@ -50,10 +50,25 @@
 extern template struct CheckNumericsLaunch<Eigen::half>;
 extern template struct CheckNumericsLaunch<float>;
 extern template struct CheckNumericsLaunch<double>;
+
+template <typename T>
+struct CheckNumericsLaunchV2 {
+  void Run(const GPUDevice& d, const T* data, int size,
+           int abnormal_detected[3]);
+};
+
+extern template struct CheckNumericsLaunchV2<Eigen::half>;
+extern template struct CheckNumericsLaunchV2<float>;
+extern template struct CheckNumericsLaunchV2<double>;
 #endif
 
 namespace {
 
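+// Bit masks OR-ed together across a tensor's elements to record which anomaly
+// categories occurred. V1 uses kInfBit and kNaNBit; V2 uses kNaNBit plus the
+// sign-specific infinity bits.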
+const int kInfBit = 0x01;
+const int kNaNBit = 0x02;
+const int kNegativeInfBit = 0x04;
+const int kPositiveInfBit = 0x08;
+
 template <typename Device, typename T>
 class CheckNumericsOp;
 
@@ -77,30 +92,11 @@
     const T* data = in.data();
     const int64 size = in.size();
     // Check to see if any element of the tensor is NaN or Inf.
-    int fp_props =
-        std::accumulate(data, data + size, 0, [](const int x, const T& y) {
-          int result = x;
-          if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
-            // Do nothing: common case
-          } else if (Eigen::numext::isinf(y)) {
-            result |= kInfBit;
-          } else if (Eigen::numext::isnan(y)) {
-            result |= kNaNBit;
-          }
-          return result;
-        });
+    int fp_props = std::accumulate(
+        data, data + size, 0,
+        [this](const int x, const T& y) { return checkFloatingElement(x, y); });
     if (fp_props != 0) {
-      string status;
-      if ((fp_props & kInfBit) && (fp_props & kNaNBit)) {
-        status = "Inf and NaN";
-      } else {
-        if (fp_props & kInfBit) {
-          status = "Inf";
-        }
-        if (fp_props & kNaNBit) {
-          status = "NaN";
-        }
-      }
+      const string& status = getErrorString(fp_props);
       if (!status.empty()) {
         context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
                                                    status, " values"));
@@ -108,10 +104,86 @@
     }
   }
 
+ protected:
+  virtual int checkFloatingElement(const int x, const T& y) {
+    int result = x;
+    if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
+      // Do nothing: common case.
+    } else {
+      if (Eigen::numext::isinf(y)) {
+        result |= kInfBit;
+      } else if (Eigen::numext::isnan(y)) {
+        result |= kNaNBit;
+      }
+    }
+    return result;
+  }
+
+  virtual const string getErrorString(const int fp_props) {
+    string status;
+    if ((fp_props & kInfBit) && (fp_props & kNaNBit)) {
+      status = "Inf and NaN";
+    } else {
+      if (fp_props & kInfBit) {
+        status = "Inf";
+      }
+      if (fp_props & kNaNBit) {
+        status = "NaN";
+      }
+    }
+    return status;
+  }
+
  private:
   string message_;
-  static const int kInfBit = 0x01;
-  static const int kNaNBit = 0x02;
+};
+
+template <typename Device, typename T>
+class CheckNumericsV2Op;
+
+// Partial specialization for CPU: v2.
+// The v2 op differs from the v1 in that it distinguishes -inf and +inf.
+template <typename T>
+class CheckNumericsV2Op<CPUDevice, T> : public CheckNumericsOp<CPUDevice, T> {
+ public:
+  explicit CheckNumericsV2Op(OpKernelConstruction* context)
+      : CheckNumericsOp<CPUDevice, T>(context) {}
+
+ protected:
+  int checkFloatingElement(const int x, const T& y) override {
+    int result = x;
+    if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
+      // Do nothing: common case.
+    } else {
+      if (Eigen::numext::isinf(y)) {
+        result |= y < static_cast<T>(0.) ? kNegativeInfBit : kPositiveInfBit;
+      } else if (Eigen::numext::isnan(y)) {
+        result |= kNaNBit;
+      }
+    }
+    return result;
+  }
+
+  const string getErrorString(const int fp_props) override {
+    std::vector<string> anomalies;
+    if (fp_props & kNegativeInfBit) {
+      anomalies.push_back("-Inf");
+    }
+    if (fp_props & kPositiveInfBit) {
+      anomalies.push_back("+Inf");
+    }
+    if (fp_props & kNaNBit) {
+      anomalies.push_back("NaN");
+    }
+    if (anomalies.size() == 3) {
+      return strings::StrCat(anomalies[0], ", ", anomalies[1], ", and ",
+                             anomalies[2]);
+    } else if (anomalies.size() == 2) {
+      return strings::StrCat(anomalies[0], " and ", anomalies[1]);
+    } else {
+      return anomalies[0];
+    }
+  }
 };
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -138,8 +210,8 @@
     auto input = context->input(0).flat<T>();
 
     // Allocate and initialize the elements to hold the check results
-    const int abnormal_detected_size = 2;
     Tensor abnormal_detected;
+    const int abnormal_detected_size = getAnomalyIndicatorSize();
     OP_REQUIRES_OK(context, context->allocate_temp(
                                 DT_INT32, TensorShape({abnormal_detected_size}),
                                 &abnormal_detected));
@@ -156,8 +228,8 @@
 
     // Call the GPU kernels for the numerical checks
     const Device& d = context->eigen_device<Device>();
-    CheckNumericsLaunch<T>().Run(d, input.data(), input.size(),
-                                 abnormal_detected.flat<int>().data());
+    RunKernel(d, input.data(), input.size(),
+              abnormal_detected.flat<int>().data());
 
     // Copy the results from device to host
     AllocatorAttributes attr;
@@ -190,42 +262,97 @@
       se::rocm::ScopedActivateExecutorContext scoped_activation{
           stream->parent()};
 #endif
-      auto abnormal_detected_host_flat = abnormal_detected_host.flat<int>();
-      int is_nan = abnormal_detected_host_flat(0);
-      int is_inf = abnormal_detected_host_flat(1);
+      TTypes<const int>::Vec abnormal_detected_host_flat =
+          abnormal_detected_host.flat<int>();
       abnormal_detected_ref.Unref();
-      if (is_nan || is_inf) {
-        string status;
-        LOG(ERROR) << "abnormal_detected_host @"
-                   << abnormal_detected_host_flat.data() << " = {" << is_nan
-                   << ", " << is_inf << "} " << message_;
-
-        // Results should always be 1 or 0.  If we see anything else then
-        // there has been some GPU memory corruption.
-        CHECK_GE(is_nan, 0);
-        CHECK_GE(is_inf, 0);
-        CHECK_LE(is_nan, 1);
-        CHECK_LE(is_inf, 1);
-
-        if (is_nan && is_inf) {
-          status = "Inf and NaN";
-        } else if (is_nan) {
-          status = "NaN";
-        } else if (is_inf) {
-          status = "Inf";
-        }
-        context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
-                                                   status, " values"));
-      }
+      checkForAnomalies(context, abnormal_detected_host_flat);
       done();
     };
     context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
         stream, std::move(check_cb));
   }
 
- private:
+ protected:
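+  // V1 records two anomaly indicators: {NaN seen, Inf seen}.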
+  virtual int getAnomalyIndicatorSize() { return 2; }
+
+  virtual void RunKernel(const GPUDevice& d, const T* data, int size,
+                         int* abnormal_detected) {
+    CheckNumericsLaunch<T>().Run(d, data, size, abnormal_detected);
+  }
+
+  virtual void checkForAnomalies(
+      OpKernelContext* context,
+      const TTypes<const int>::Vec& abnormality_indicators) {
+    const int is_nan = abnormality_indicators(0);
+    const int is_inf = abnormality_indicators(1);
+    if (is_nan || is_inf) {
+      LOG(ERROR) << "abnormal_detected_host @" << abnormality_indicators.data()
+                 << " = {" << is_nan << ", " << is_inf << "} " << message_;
+
+      string anomalies;
+      if (is_nan && is_inf) {
+        anomalies = "Inf and NaN";
+      } else if (is_nan) {
+        anomalies = "NaN";
+      } else if (is_inf) {
+        anomalies = "Inf";
+      }
+      context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
+                                                 anomalies, " values"));
+    }
+  }
+
   string message_;
 };
+
+template <typename T>
+class CheckNumericsV2Op<GPUDevice, T> : public CheckNumericsOp<GPUDevice, T> {
+ public:
+  explicit CheckNumericsV2Op(OpKernelConstruction* context)
+      : CheckNumericsOp<GPUDevice, T>(context) {}
+
+ protected:
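+  // V2 records three anomaly indicators: {NaN, -Inf, +Inf}.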
+  int getAnomalyIndicatorSize() override { return 3; }
+
+  void RunKernel(const GPUDevice& d, const T* data, int size,
+                 int* abnormal_detected) override {
+    CheckNumericsLaunchV2<T>().Run(d, data, size, abnormal_detected);
+  }
+
+  void checkForAnomalies(
+      OpKernelContext* context,
+      const TTypes<const int>::Vec& abnormality_indicators) override {
+    const int is_nan = abnormality_indicators(0);
+    const int is_negative_inf = abnormality_indicators(1);
+    const int is_positive_inf = abnormality_indicators(2);
+    if (is_negative_inf || is_positive_inf || is_nan) {
+      std::vector<string> anomalies;
+      if (is_negative_inf) {
+        anomalies.push_back("-Inf");
+      }
+      if (is_positive_inf) {
+        anomalies.push_back("+Inf");
+      }
+      if (is_nan) {
+        anomalies.push_back("NaN");
+      }
+      string all_anomalies;
+      if (anomalies.size() == 3) {
+        all_anomalies = strings::StrCat(anomalies[0], ", ", anomalies[1],
+                                        ", and ", anomalies[2]);
+      } else if (anomalies.size() == 2) {
+        all_anomalies = strings::StrCat(anomalies[0], " and ", anomalies[1]);
+      } else {
+        all_anomalies = anomalies[0];
+      }
+      context->SetStatus(errors::InvalidArgument(
+          this->message_, " : Tensor had ", all_anomalies, " values"));
+    }
+  }
+};
+
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace
@@ -239,6 +366,15 @@
 TF_CALL_float(REGISTER_CPU_KERNEL);
 TF_CALL_double(REGISTER_CPU_KERNEL);
 
+#define REGISTER_V2_CPU_KERNEL(T)                                        \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("CheckNumericsV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CheckNumericsV2Op<CPUDevice, T>);
+TF_CALL_half(REGISTER_V2_CPU_KERNEL);
+TF_CALL_bfloat16(REGISTER_V2_CPU_KERNEL);
+TF_CALL_float(REGISTER_V2_CPU_KERNEL);
+TF_CALL_double(REGISTER_V2_CPU_KERNEL);
+
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 REGISTER_KERNEL_BUILDER(
     Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
@@ -249,6 +385,16 @@
 REGISTER_KERNEL_BUILDER(
     Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<double>("T"),
     CheckNumericsOp<GPUDevice, double>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
+    CheckNumericsV2Op<GPUDevice, Eigen::half>);
+REGISTER_KERNEL_BUILDER(
+    Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+    CheckNumericsV2Op<GPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+    Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<double>("T"),
+    CheckNumericsV2Op<GPUDevice, double>);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc b/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
index 2060b64..3e41734 100644
--- a/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
@@ -54,6 +54,29 @@
   }
 }
 
+// V2 of CheckNumericsKernel for GPU.
+// Unlike CheckNumericsKernel (V1), this kernel distinguishes -Inf and +Inf.
+// The 3 elements of `abnormal_detected` are used to signify NaN, -Inf and +Inf,
+// respectively.
+template <typename T>
+__global__ void CheckNumericsKernelV2(const T* __restrict__ data, int size,
+                                      int abnormal_detected[3]) {
+  const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+  const int32 total_thread_count = gridDim.x * blockDim.x;
+
+  int32 offset = thread_id;
+
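+  // Grid-stride loop: each thread advances by the total thread count, so any
+  // grid size covers every element of the buffer.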
+  while (offset < size) {
+    if (isnan(data[offset])) {
+      abnormal_detected[0] = 1;
+    }
+    if (isinf(data[offset])) {
+      abnormal_detected[data[offset] < static_cast<T>(0.f) ? 1 : 2] = 1;
+    }
+    offset += total_thread_count;
+  }
+}
+
 }  // namespace
 
 // A simple launch pad to launch the Cuda kernels that checks the numerical
@@ -76,5 +99,24 @@
 template struct CheckNumericsLaunch<float>;
 template struct CheckNumericsLaunch<double>;
 
+template <typename T>
+struct CheckNumericsLaunchV2 {
+  void Run(const GPUDevice& d, const T* data, int size,
+           int abnormal_detected[3]) {
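+    // Size the launch to roughly fill the device: the maximum threads per
+    // block, and enough blocks to occupy all multiprocessors.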
+    const int32 block_size = d.maxGpuThreadsPerBlock();
+    const int32 num_blocks =
+        (d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor()) /
+        block_size;
+
+    TF_CHECK_OK(GpuLaunchKernel(CheckNumericsKernelV2<T>, num_blocks,
+                                block_size, 0, d.stream(), data, size,
+                                abnormal_detected));
+  }
+};
+
+template struct CheckNumericsLaunchV2<Eigen::half>;
+template struct CheckNumericsLaunchV2<float>;
+template struct CheckNumericsLaunchV2<double>;
+
 }  // namespace tensorflow
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index dbe357d..a427b8b 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1350,6 +1350,15 @@
     .SetShapeFn(shape_inference::UnchangedShape);
 
 // --------------------------------------------------------------------------
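+// Marked stateful so that graph optimizations (e.g., constant folding and
+// common-subexpression elimination) do not elide or merge the checks.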
+REGISTER_OP("CheckNumericsV2")
+    .Input("tensor: T")
+    .Output("output: T")
+    .Attr("T: {bfloat16, half, float, double}")
+    .Attr("message: string")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+// --------------------------------------------------------------------------
 REGISTER_OP("Reshape")
     .Input("tensor: T")
     .Input("shape: Tshape")
diff --git a/tensorflow/python/debug/lib/check_numerics_callback.py b/tensorflow/python/debug/lib/check_numerics_callback.py
index 8dac482..735aedb 100644
--- a/tensorflow/python/debug/lib/check_numerics_callback.py
+++ b/tensorflow/python/debug/lib/check_numerics_callback.py
@@ -246,7 +246,7 @@
       for slot, output in enumerate(outputs):
         if (output.dtype.is_floating and
             (op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
-          checked_output = array_ops.check_numerics(
+          checked_output = array_ops.check_numerics_v2(
               # TF v2 has automatic control dependencies added to stateful async
               # ops, which allows us to run check_numerics asynchronously.
               # In the above case we use debug_summary to reduce all output
@@ -268,7 +268,7 @@
           instrumented_outputs.append(output)
       return instrumented_outputs
     else:
-      if op_type_bytes == b"CheckNumerics":
+      if op_type_bytes == b"CheckNumericsV2":
         # TODO(b/140334369): Remove this special casing logic once op_callback.
         # automatically prevents infinite recursion in eager mode.
         return None
@@ -276,14 +276,10 @@
       for slot, output in enumerate(outputs):
         if (output.dtype.is_floating and
             (op_type_bytes, slot) not in IGNORE_OP_OUTPUTS):
-          array_ops.check_numerics(
+          array_ops.check_numerics_v2(
               output,
               get_check_numerics_error_message(
-                  slot,
-                  len(outputs),
-                  op_type,
-                  output,
-                  inputs,
+                  slot, len(outputs), op_type, output, inputs,
                   stack_height_limit=self._stack_height_limit,
                   path_length_limit=self._path_length_limit))
 
diff --git a/tensorflow/python/kernel_tests/numerics_test.py b/tensorflow/python/kernel_tests/numerics_test.py
index f13f9d6..8a4cdff 100644
--- a/tensorflow/python/kernel_tests/numerics_test.py
+++ b/tensorflow/python/kernel_tests/numerics_test.py
@@ -22,6 +22,7 @@
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
@@ -129,6 +130,51 @@
         r"or `tf.while_loop\(\)`\."):
       numerics.add_check_numerics_ops()
 
+  def testCheckNumericsV2OpNegativeAndPositiveInf(self):
+    """Test that CheckNumericsV2 op distinguishes negative and positive infs."""
+    with self.session(graph=ops.Graph()):
+      t1 = constant_op.constant([-1.0, 1.0])
+      t2 = constant_op.constant([0.0, 0.0])
+      checked = array_ops.check_numerics_v2(
+          t1 / t2, message="pass through test")
+      caught = None
+      try:
+        self.evaluate(checked)
+      except errors.InvalidArgumentError as error:
+        caught = error
+      self.assertIn("had -Inf and +Inf values", caught.message)
+      self.assertIn("pass through test", caught.message)
+
+  def testCheckNumericsV2OpNegativeAndPositiveInfAndNaN(self):
+    """CheckNumericsV2 op distinguishes -Inf and +Inf when NaN is present."""
+    with self.session(graph=ops.Graph()):
+      t1 = constant_op.constant([-1.0, 1.0, 0.0])
+      t2 = constant_op.constant([0.0, 0.0, 0.0])
+      checked = array_ops.check_numerics_v2(
+          t1 / t2, message="pass through test")
+      caught = None
+      try:
+        self.evaluate(checked)
+      except errors.InvalidArgumentError as error:
+        caught = error
+      self.assertIn("had -Inf, +Inf, and NaN values", caught.message)
+      self.assertIn("pass through test", caught.message)
+
+  def testCheckNumericsV2PositiveInfAndNaN(self):
+    """CheckNumericsV2 op shows the sign of Inf when NaN is present."""
+    with self.session(graph=ops.Graph()):
+      t1 = constant_op.constant([0.0, 1.0])
+      t2 = constant_op.constant([0.0, 0.0])
+      checked = array_ops.check_numerics_v2(
+          t1 / t2, message="pass through test")
+      caught = None
+      try:
+        self.evaluate(checked)
+      except errors.InvalidArgumentError as error:
+        caught = error
+      self.assertIn("had +Inf and NaN values", caught.message)
+      self.assertIn("pass through test", caught.message)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 6dd8538..0480802 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -710,6 +710,15 @@
       op.get_attr("message"))
 
 
+@ops.RegisterGradient("CheckNumericsV2")
+def _CheckNumericsV2Grad(op, grad):
+  """Gradient for check_numerics op."""
+  return array_ops.check_numerics_v2(
+      grad,
+      "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
+      op.get_attr("message"))
+
+
 @ops.RegisterGradient("PlaceholderWithDefault")
 @ops.RegisterGradient("Identity")
 def _IdGrad(_, grad):
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 32b424c..1ee71d9 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2556,8 +2556,8 @@
          [0, 0, 0, 0]], dtype=int32)>
 
   Args:
-    shape: A `list` of integers, a `tuple` of integers, or a 1-D `Tensor` of
-      type `int32`.
+    shape: A `list` of integers, a `tuple` of integers, or
+      a 1-D `Tensor` of type `int32`.
     dtype: The DType of an element in the resulting `Tensor`.
     name: Optional string. A name for the operation.
 
@@ -2787,8 +2787,8 @@
          [1, 1, 1, 1]], dtype=int32)>
 
   Args:
-    shape: A `list` of integers, a `tuple` of integers, or a 1-D `Tensor` of
-      type `int32`.
+    shape: A `list` of integers, a `tuple` of integers, or
+      a 1-D `Tensor` of type `int32`.
     dtype: Optional DType of an element in the resulting `Tensor`. Default is
       `tf.float32`.
     name: Optional string. A name for the operation.
@@ -4797,8 +4797,8 @@
       axis=axis)
 
 
-@tf_export(
-    "quantization.dequantize", v1=["quantization.dequantize", "dequantize"])
+@tf_export("quantization.dequantize", v1=["quantization.dequantize",
+                                          "dequantize"])
 @deprecation.deprecated_endpoints("dequantize")
 def dequantize(  # pylint: disable=missing-docstring
     input,  # pylint: disable=redefined-builtin
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index edeb3ba..eab0efc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -645,6 +645,10 @@
     argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "CheckNumericsV2"
+    argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "Cholesky"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index edeb3ba..eab0efc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -645,6 +645,10 @@
     argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "CheckNumericsV2"
+    argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
     name: "Cholesky"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }