Track op names in dtensor emitted ops

This change allows the original op name to show up in xprof of DTensor
emitted reduce ops.

PiperOrigin-RevId: 460516949
diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD
index 593d77f..9f02ea3 100644
--- a/tensorflow/dtensor/mlir/BUILD
+++ b/tensorflow/dtensor/mlir/BUILD
@@ -137,6 +137,7 @@
     srcs = ["dtensor_location.cc"],
     hdrs = ["dtensor_location.h"],
     deps = [
+        "//tensorflow/compiler/mlir:name_utils",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Support",
@@ -464,6 +465,7 @@
         ":tf_dtensor_dialect",
         ":value_utils",
         "//tensorflow/compiler/mlir:array_container_utils",
+        "//tensorflow/compiler/mlir:name_utils",
         "//tensorflow/compiler/mlir/hlo:convert_op_folder",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
@@ -512,6 +514,7 @@
     srcs = ["dtensor_location_test.cc"],
     deps = [
         ":dtensor_location",
+        "//tensorflow/compiler/mlir:name_utils",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "@llvm-project//mlir:IR",
diff --git a/tensorflow/dtensor/mlir/dtensor_location.cc b/tensorflow/dtensor/mlir/dtensor_location.cc
index 41a5db8..194a1d4 100644
--- a/tensorflow/dtensor/mlir/dtensor_location.cc
+++ b/tensorflow/dtensor/mlir/dtensor_location.cc
@@ -16,16 +16,26 @@
 #include "tensorflow/dtensor/mlir/dtensor_location.h"
 
 #include <algorithm>
+#include <queue>
 #include <string>
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
+#include "tensorflow/compiler/mlir/utils/name_utils.h"
 
 namespace tensorflow {
 namespace dtensor {
 
+namespace {
+std::string CreateLocalLocationString(mlir::FileLineColLoc loc) {
+  return llvm::formatv(">> {0}:{1}:{2}", loc.getFilename(), loc.getLine(),
+                       loc.getColumn())
+      .str();
+}
+}  // namespace
+
 mlir::Location DTensorLocation(mlir::Location loc, llvm::StringRef file,
                                unsigned int line) {
   // Strip dirname.
@@ -33,6 +43,11 @@
   if (!split.second.empty()) file = split.second;
   mlir::Location callee_loc =
       mlir::FileLineColLoc::get(loc.getContext(), file, line, 0);
+  const std::string name = GetNameFromLoc(loc);
+  if (!name.empty()) {
+    callee_loc = mlir::NameLoc::get(
+        mlir::StringAttr::get(loc.getContext(), name), callee_loc);
+  }
   return mlir::CallSiteLoc::get(/*callee=*/callee_loc, /*caller=*/loc);
 }
 
@@ -41,25 +56,24 @@
   return DTensorLocation(op->getLoc(), file, line);
 }
 
-std::string CreateLocalLocationString(mlir::FileLineColLoc loc) {
-  return llvm::formatv(">> {0}:{1}:{2}", loc.getFilename(), loc.getLine(),
-                       loc.getColumn())
-      .str();
-}
-
 std::string DTensorLocationToString(mlir::Location loc) {
   llvm::SmallVector<std::string, 4> stack;
-  while (auto callsite_loc = loc.dyn_cast<mlir::CallSiteLoc>()) {
-    if (auto callee_loc =
-            callsite_loc.getCallee().dyn_cast<mlir::FileLineColLoc>())
-      stack.push_back(CreateLocalLocationString(callee_loc));
+  std::queue<mlir::Location> queue;
+  queue.push(loc);
 
-    loc = callsite_loc.getCaller();
+  while (!queue.empty()) {
+    mlir::Location& front = queue.front();
+    if (auto name_loc = front.dyn_cast<mlir::NameLoc>()) {
+      queue.push(name_loc.getChildLoc());
+    } else if (auto callsite_loc = front.dyn_cast<mlir::CallSiteLoc>()) {
+      queue.push(callsite_loc.getCallee());
+      queue.push(callsite_loc.getCaller());
+    } else if (auto line_loc = front.dyn_cast<mlir::FileLineColLoc>()) {
+      stack.push_back(CreateLocalLocationString(line_loc));
+    }
+    queue.pop();
   }
 
-  if (auto file_line_col_loc = loc.dyn_cast<mlir::FileLineColLoc>())
-    stack.push_back(CreateLocalLocationString(file_line_col_loc));
-
   std::reverse(stack.begin(), stack.end());
   std::string s;
   llvm::raw_string_ostream ss(s);
diff --git a/tensorflow/dtensor/mlir/dtensor_location_test.cc b/tensorflow/dtensor/mlir/dtensor_location_test.cc
index 9c44bac..e66c17a 100644
--- a/tensorflow/dtensor/mlir/dtensor_location_test.cc
+++ b/tensorflow/dtensor/mlir/dtensor_location_test.cc
@@ -18,6 +18,7 @@
 #include "mlir/IR/Location.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/utils/name_utils.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace {
@@ -72,4 +73,28 @@
   EXPECT_EQ(tensorflow::dtensor::DTensorLocationToString(test_loc), stack);
 }
 
+TEST(DTensorLocationTest, HandlesNameLoc) {
+  mlir::MLIRContext ctx;
+  mlir::Location test_loc =
+      mlir::NameLoc::get(mlir::StringAttr::get(&ctx, "op@"),
+                         mlir::FileLineColLoc::get(&ctx, "test.cc", 10, 20));
+  test_loc = tensorflow::dtensor::DTensorLocation(test_loc, "test.cc", 21);
+  ASSERT_EQ(mlir::GetNameFromLoc(test_loc), "op");
+  ASSERT_TRUE(test_loc.isa<mlir::CallSiteLoc>());
+  auto callsite_loc = test_loc.cast<mlir::CallSiteLoc>();
+  mlir::Location caller_loc = test_loc.cast<mlir::CallSiteLoc>().getCaller();
+  ASSERT_TRUE(caller_loc.isa<mlir::NameLoc>());
+  CheckFileLineColLocation(caller_loc.cast<mlir::NameLoc>().getChildLoc(), 10,
+                           20);
+
+  mlir::Location callee_loc = callsite_loc.getCallee();
+  ASSERT_TRUE(callee_loc.isa<mlir::NameLoc>());
+  CheckFileLineColLocation(callee_loc.cast<mlir::NameLoc>().getChildLoc(), 21,
+                           0);
+
+  constexpr char stack[] = R"stack(>> test.cc:10:20
+>> test.cc:21:0)stack";
+  EXPECT_EQ(tensorflow::dtensor::DTensorLocationToString(test_loc), stack);
+}
+
 }  // namespace