[tfdbg] Add Shape mode to DebugNumericSummaryV2Op.

- The TensorDebugMode added is SHAPE,
  a mode that computes a shape-[10] rank-1 tensor given any float-type tensor.
  The first element is the id of the tensor. The second element is the dtype of the
  tensor, represented by the enumerated type defined in
  tensorflow/core/framework/types.proto. The third and fourth elements are the rank
  and size of the tensor respectively, and finally the fifth to tenth elements
  represent the shape of the tensor. Shorter shapes are right-padded with zeros and
  longer shapes have the head truncated.
- The CPU and GPU kernels of the op are added.

PiperOrigin-RevId: 284243269
Change-Id: I2adc2c68792ee284ac2401bedd816c0ea960f87b
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 963b2bb..643dfda 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -514,7 +514,7 @@
     const int64 size = in.size();
     Tensor* output_tensor;
     Tout tensor_id = static_cast<Tout>(tensor_id_);
-    const float num_elem = static_cast<float>(context->input(0).NumElements());
+    const Tout num_elem = static_cast<Tout>(context->input(0).NumElements());
     // Disregard lossy cast if mode is REDUCE_INF_NAN_THREE_SLOTS because
     // that mode does not make use of tensor_id.
     if (tensor_debug_mode_ != 8) {
@@ -565,6 +565,32 @@
       output_tensor->flat<Tout>()(2) = fp_props[0];  // Slot for -inf count
       output_tensor->flat<Tout>()(3) = fp_props[1];  // Slot for inf count
       output_tensor->flat<Tout>()(4) = fp_props[2];  // Slot for nan count
+    } else if (tensor_debug_mode_ == 5) {            // SHAPE
+      TensorShape shape({10});
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, shape, &output_tensor));
+
+      int num_dims = tensor.dims();
+      output_tensor->flat<Tout>()(0) = tensor_id;
+      output_tensor->flat<Tout>()(1) = static_cast<Tout>(tensor.dtype());
+      output_tensor->flat<Tout>()(2) = static_cast<Tout>(num_dims);
+      output_tensor->flat<Tout>()(3) = num_elem;
+
+      // Tensor shape, stored in the last kShapeDims (6) slots of the output.
+      // If num_dims is less than 6, the shape is right-padded with zeros.
+      // If num_dims is greater than 6, the head (left-most dimensions) is
+      // truncated, keeping the last 6 dimensions: trailing dimensions are less
+      // predictable than leading ones (e.g. batch size as the first dimension).
+      int dim_idx = 4;
+      for (int i = std::max(0, num_dims - kShapeDims);
+           i < std::max(6, num_dims); ++i) {
+        if (i < num_dims) {
+          output_tensor->flat<Tout>()(dim_idx++) =
+              static_cast<Tout>(tensor.dim_size(i));
+        } else {
+          output_tensor->flat<Tout>()(dim_idx++) = 0.0;
+        }
+      }
     } else if (tensor_debug_mode_ == 8) {  // REDUCE_INF_NAN_THREE_SLOTS.
       TensorShape shape({3});
       OP_REQUIRES_OK(context,
@@ -605,6 +631,7 @@
  private:
   int tensor_debug_mode_;
   int64 tensor_id_;
+  static constexpr int kShapeDims = 6;
   static constexpr int kNegInfBit = 0x01;
   static constexpr int kPosInfBit = 0x02;
   static constexpr int kNaNBit = 0x04;
@@ -628,7 +655,11 @@
   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
     Tensor* output_tensor;
     Tout tensor_id = static_cast<Tout>(tensor_id_);
-    const float num_elem = static_cast<float>(context->input(0).NumElements());
+    const Tensor& tensor = context->input(0);
+    const Tout num_elem = static_cast<Tout>(tensor.NumElements());
+    const Device& d = context->eigen_device<Device>();
+    auto input = tensor.flat<Tin>();
+    auto check_cb = [this, done]() { done(); };
     // Disregard lossy cast if mode is REDUCE_INF_NAN_THREE_SLOTS because
     // that mode does not make use of tensor_id.
     if (tensor_debug_mode_ != 8) {
@@ -657,19 +688,16 @@
       stream->ThenMemZero(&output_tensor_ptr, 2 * sizeof(Tout));
       // Copy tensor_id to slot zero
       stream->ThenMemcpy(&output_tensor_ptr, &tensor_id, sizeof(Tout));
-      if (context->input(0).NumElements() == 0) {
+      if (num_elem == 0) {
         done();
         return;
       }
 
       // Call the GPU kernels for the numerical (inf/nan) checks.
-      const Device& d = context->eigen_device<Device>();
       auto input = context->input(0).flat<Tin>();
       CurtHealthLaunch<Tin, Tout>().Run(d, input.data(), input.size(),
                                         output_tensor->flat<Tout>().data() + 1);
 
-      auto check_cb = [this, done]() { done(); };
-
       context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
           stream, std::move(check_cb));
     } else if (tensor_debug_mode_ == 3) {  // CONCISE_HEALTH.
@@ -693,14 +721,43 @@
       }
 
       // Call the GPU kernels for the numerical (inf/nan) checks.
-      const Device& d = context->eigen_device<Device>();
-      auto input = context->input(0).flat<Tin>();
       ConciseHealthLaunch<Tin, Tout>().Run(
           d, input.data(), input.size(),
           output_tensor->flat<Tout>().data() + 2);
 
-      auto check_cb = [this, done]() { done(); };
+      context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+          stream, std::move(check_cb));
+    } else if (tensor_debug_mode_ == 5) {  // SHAPE
+      TensorShape shape({10});
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, shape, &output_tensor));
 
+      auto* stream = context->op_device_context()->stream();
+      OP_REQUIRES_ASYNC(context, stream != nullptr,
+                        errors::Internal("No GPU stream available."), done);
+
+      se::DeviceMemoryBase output_tensor_ptr(
+          output_tensor->flat<Tout>().data(),
+          output_tensor->flat<Tout>().size());
+
+      int num_dims = tensor.dims();
+      Tout static_output[10] = {tensor_id,
+                                static_cast<Tout>(tensor.dtype()),
+                                static_cast<Tout>(num_dims),
+                                num_elem,
+                                0.0,
+                                0.0,
+                                0.0,
+                                0.0,
+                                0.0,
+                                0.0};
+      // Tensor shape: right pad zeros, truncate head
+      int dim_idx = 4;
+      for (int i = std::max(0, num_dims - 6); i < num_dims; ++i) {
+        static_output[dim_idx++] = static_cast<Tout>(tensor.dim_size(i));
+      }
+      // Write to device stream
+      stream->ThenMemcpy(&output_tensor_ptr, &static_output, sizeof(Tout) * 10);
       context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
           stream, std::move(check_cb));
     } else if (tensor_debug_mode_ == 8) {  // REDUCE_INF_NAN_THREE_SLOTS.
@@ -717,19 +774,16 @@
           output_tensor->flat<Tout>().size());
       stream->ThenMemset32(&output_tensor_ptr, 0,
                            output_tensor->flat<Tout>().size() * sizeof(Tout));
-      if (context->input(0).NumElements() == 0) {
+      if (num_elem == 0) {
         done();
         return;
       }
 
       // Call the GPU kernels for the numerical (inf/nan) checks.
-      const Device& d = context->eigen_device<Device>();
       auto input = context->input(0).flat<Tin>();
       ReduceInfNanThreeSlotsLaunch<Tin, Tout>().Run(
           d, input.data(), input.size(), output_tensor->flat<Tout>().data());
 
-      auto check_cb = [this, done]() { done(); };
-
       context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
           stream, std::move(check_cb));
     } else {
diff --git a/tensorflow/python/debug/lib/debug_v2_ops_test.py b/tensorflow/python/debug/lib/debug_v2_ops_test.py
index 76d077c..ea3d897 100644
--- a/tensorflow/python/debug/lib/debug_v2_ops_test.py
+++ b/tensorflow/python/debug/lib/debug_v2_ops_test.py
@@ -272,6 +272,7 @@
     modes = [
         debug_event_pb2.TensorDebugMode.CURT_HEALTH,
         debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
+        debug_event_pb2.TensorDebugMode.SHAPE,
     ]
     # Maximum allowed tensor_id
     tensor_id = np.power(2, 53)
@@ -481,6 +482,70 @@
     self.assertAllEqual(tensor_1, tensor_2)
     self.assertEqual(tensor_id_1, tensor_id_2)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testDebugNumericSummaryV2OpShapeEmpty(self):
+
+    def debug_summary(x):
+      return self.evaluate(
+          gen_debug_ops.debug_numeric_summary_v2(
+              x,
+              tensor_debug_mode=(debug_event_pb2.TensorDebugMode.SHAPE),
+              tensor_id=x._id,
+              output_dtype=dtypes.float64)), x._id
+
+    tensor, tensor_id = debug_summary(constant_op.constant(0.0))
+    self.assertAllEqual(
+        tensor, [tensor_id, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+
+  @test_util.run_in_graph_and_eager_modes
+  def testDebugNumericSummaryV2OpShapeSmall(self):
+
+    def debug_summary(x):
+      return self.evaluate(
+          gen_debug_ops.debug_numeric_summary_v2(
+              x,
+              tensor_debug_mode=(debug_event_pb2.TensorDebugMode.SHAPE),
+              tensor_id=x._id,
+              output_dtype=dtypes.float64)), x._id
+
+    x = np.zeros([3, 4], dtype=np.float32)
+    tensor, tensor_id = debug_summary(constant_op.constant(x))
+    self.assertAllEqual(
+        tensor, [tensor_id, 1.0, 2.0, 12.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0])
+
+    x = np.ones([1, 2, 3, 4, 5, 6], dtype=np.float16)
+    x[0, 1, 2, 2, 2, 2] = np.nan
+    tensor, tensor_id = debug_summary(constant_op.constant(x))
+    self.assertAllEqual(
+        tensor,
+        [tensor_id, 19, 6.0, 2 * 3 * 4 * 5 * 6, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+
+    x = np.zeros([2], dtype=np.float32)
+    tensor, tensor_id = debug_summary(constant_op.constant(x))
+    self.assertAllEqual(
+        tensor, [tensor_id, 1.0, 1.0, 2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+
+    tensor, tensor_id = debug_summary(constant_op.constant([]))
+    self.assertAllEqual(
+        tensor, [tensor_id, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
+
+  @test_util.run_in_graph_and_eager_modes
+  def testDebugNumericSummaryV2OpShapeLarge(self):
+
+    def debug_summary(x):
+      return self.evaluate(
+          gen_debug_ops.debug_numeric_summary_v2(
+              x,
+              tensor_debug_mode=(debug_event_pb2.TensorDebugMode.SHAPE),
+              tensor_id=x._id,
+              output_dtype=dtypes.float64)), x._id
+
+    x = np.ones([1, 2, 3, 4, 5, 6, 7], dtype=np.double)
+    tensor, tensor_id = debug_summary(constant_op.constant(x))
+    self.assertAllEqual(tensor, [
+        tensor_id, 2.0, 7.0, 2 * 3 * 4 * 5 * 6 * 7, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0
+    ])
+
 
 if __name__ == "__main__":
   ops.enable_eager_execution()