Track op names in dtensor emitted ops
This change allows the original op name to show up in xprof of DTensor
emitted reduce ops.
PiperOrigin-RevId: 460516949
diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD
index 593d77f..9f02ea3 100644
--- a/tensorflow/dtensor/mlir/BUILD
+++ b/tensorflow/dtensor/mlir/BUILD
@@ -137,6 +137,7 @@
srcs = ["dtensor_location.cc"],
hdrs = ["dtensor_location.h"],
deps = [
+ "//tensorflow/compiler/mlir:name_utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
@@ -464,6 +465,7 @@
":tf_dtensor_dialect",
":value_utils",
"//tensorflow/compiler/mlir:array_container_utils",
+ "//tensorflow/compiler/mlir:name_utils",
"//tensorflow/compiler/mlir/hlo:convert_op_folder",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
@@ -512,6 +514,7 @@
srcs = ["dtensor_location_test.cc"],
deps = [
":dtensor_location",
+ "//tensorflow/compiler/mlir:name_utils",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//mlir:IR",
diff --git a/tensorflow/dtensor/mlir/dtensor_location.cc b/tensorflow/dtensor/mlir/dtensor_location.cc
index 41a5db8..194a1d4 100644
--- a/tensorflow/dtensor/mlir/dtensor_location.cc
+++ b/tensorflow/dtensor/mlir/dtensor_location.cc
@@ -16,16 +16,26 @@
#include "tensorflow/dtensor/mlir/dtensor_location.h"
#include <algorithm>
+#include <queue>
#include <string>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
+#include "tensorflow/compiler/mlir/utils/name_utils.h"
namespace tensorflow {
namespace dtensor {
+namespace {
+std::string CreateLocalLocationString(mlir::FileLineColLoc loc) {
+ return llvm::formatv(">> {0}:{1}:{2}", loc.getFilename(), loc.getLine(),
+ loc.getColumn())
+ .str();
+}
+} // namespace
+
mlir::Location DTensorLocation(mlir::Location loc, llvm::StringRef file,
unsigned int line) {
// Strip dirname.
@@ -33,6 +43,11 @@
if (!split.second.empty()) file = split.second;
mlir::Location callee_loc =
mlir::FileLineColLoc::get(loc.getContext(), file, line, 0);
+ const std::string name = GetNameFromLoc(loc);
+ if (!name.empty()) {
+ callee_loc = mlir::NameLoc::get(
+ mlir::StringAttr::get(loc.getContext(), name), callee_loc);
+ }
return mlir::CallSiteLoc::get(/*callee=*/callee_loc, /*caller=*/loc);
}
@@ -41,25 +56,24 @@
return DTensorLocation(op->getLoc(), file, line);
}
-std::string CreateLocalLocationString(mlir::FileLineColLoc loc) {
- return llvm::formatv(">> {0}:{1}:{2}", loc.getFilename(), loc.getLine(),
- loc.getColumn())
- .str();
-}
-
std::string DTensorLocationToString(mlir::Location loc) {
llvm::SmallVector<std::string, 4> stack;
- while (auto callsite_loc = loc.dyn_cast<mlir::CallSiteLoc>()) {
- if (auto callee_loc =
- callsite_loc.getCallee().dyn_cast<mlir::FileLineColLoc>())
- stack.push_back(CreateLocalLocationString(callee_loc));
+ std::queue<mlir::Location> queue;
+ queue.push(loc);
- loc = callsite_loc.getCaller();
+ while (!queue.empty()) {
+ mlir::Location& front = queue.front();
+ if (auto name_loc = front.dyn_cast<mlir::NameLoc>()) {
+ queue.push(name_loc.getChildLoc());
+ } else if (auto callsite_loc = front.dyn_cast<mlir::CallSiteLoc>()) {
+ queue.push(callsite_loc.getCallee());
+ queue.push(callsite_loc.getCaller());
+ } else if (auto line_loc = front.dyn_cast<mlir::FileLineColLoc>()) {
+ stack.push_back(CreateLocalLocationString(line_loc));
+ }
+ queue.pop();
}
- if (auto file_line_col_loc = loc.dyn_cast<mlir::FileLineColLoc>())
- stack.push_back(CreateLocalLocationString(file_line_col_loc));
-
std::reverse(stack.begin(), stack.end());
std::string s;
llvm::raw_string_ostream ss(s);
diff --git a/tensorflow/dtensor/mlir/dtensor_location_test.cc b/tensorflow/dtensor/mlir/dtensor_location_test.cc
index 9c44bac..e66c17a 100644
--- a/tensorflow/dtensor/mlir/dtensor_location_test.cc
+++ b/tensorflow/dtensor/mlir/dtensor_location_test.cc
@@ -18,6 +18,7 @@
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/core/platform/test.h"
namespace {
@@ -72,4 +73,28 @@
EXPECT_EQ(tensorflow::dtensor::DTensorLocationToString(test_loc), stack);
}
+TEST(DTensorLocationTest, HandlesNameLoc) {
+ mlir::MLIRContext ctx;
+ mlir::Location test_loc =
+ mlir::NameLoc::get(mlir::StringAttr::get(&ctx, "op@"),
+ mlir::FileLineColLoc::get(&ctx, "test.cc", 10, 20));
+ test_loc = tensorflow::dtensor::DTensorLocation(test_loc, "test.cc", 21);
+ ASSERT_EQ(mlir::GetNameFromLoc(test_loc), "op");
+ ASSERT_TRUE(test_loc.isa<mlir::CallSiteLoc>());
+ auto callsite_loc = test_loc.cast<mlir::CallSiteLoc>();
+ mlir::Location caller_loc = test_loc.cast<mlir::CallSiteLoc>().getCaller();
+ ASSERT_TRUE(caller_loc.isa<mlir::NameLoc>());
+ CheckFileLineColLocation(caller_loc.cast<mlir::NameLoc>().getChildLoc(), 10,
+ 20);
+
+ mlir::Location callee_loc = callsite_loc.getCallee();
+ ASSERT_TRUE(callee_loc.isa<mlir::NameLoc>());
+ CheckFileLineColLocation(callee_loc.cast<mlir::NameLoc>().getChildLoc(), 21,
+ 0);
+
+ constexpr char stack[] = R"stack(>> test.cc:10:20
+>> test.cc:21:0)stack";
+ EXPECT_EQ(tensorflow::dtensor::DTensorLocationToString(test_loc), stack);
+}
+
} // namespace