Update CompileGraphToXlaHlo to support resource XlaArguments and populate XlaCompilationResult.resource_updates instead of XlaCompilationResult.outputs for resource writes.
This is necessary to enable MLIR support when compiling tf.functions with XLA that use resources.
PiperOrigin-RevId: 323692573
Change-Id: If3af722e5beb98dfd8260a4cb2df8db8a2a550ff
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index fe1f47d..8dcaf23 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -1459,6 +1459,7 @@
":mlir_roundtrip_flags",
":tensorflow",
":tensorflow_dialect_registration",
+ ":tensorflow_types",
":tensorflow_passes",
":translate_utils",
"@com_google_absl//absl/types:optional",
@@ -1520,6 +1521,9 @@
srcs = ["utils/compile_mlir_util_test.cc"],
deps = [
":compile_mlir_util",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:resource_variable_ops",
+ "//tensorflow/cc:scope",
"//tensorflow/compiler/jit",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 9d6cc88..e273020 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -17,11 +17,14 @@
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
@@ -36,6 +39,7 @@
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
@@ -52,6 +56,7 @@
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@@ -79,9 +84,15 @@
return Status::OK();
}
+// Arguments to a computation can be either a tensor or resource.
+struct TensorOrResourceShape {
+ TensorShape shape;
+ bool is_resource = false;
+};
+
// Converts arg_shapes to xla::Shape's and store into xla_input_shapes.
Status GetXlaInputShapes(
- mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
+ mlir::ModuleOp module, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
bool use_tuple_args,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
std::vector<xla::Shape>* xla_input_shapes) {
@@ -103,7 +114,7 @@
DataType dtype;
TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype));
TF_ASSIGN_OR_RETURN(xla_shape,
- shape_representation_fn(arg_shapes[i], dtype,
+ shape_representation_fn(arg_shapes[i].shape, dtype,
/*use_fast_memory=*/false));
// Rewrite layout with sharding, if sharding is set.
@@ -132,11 +143,13 @@
}
// Calculates computation output shape and build OutputDescription for each
-// output based on static shapes in MLIR module
+// output based on static shapes in MLIR module. If an output is a resource
+// write, `resource_updates` is populated instead of `outputs` for that output.
Status GetOutputInfo(
mlir::ModuleOp module,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
- xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs) {
+ xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs,
+ std::vector<XlaResourceUpdate>* resource_updates) {
auto shape_representation_fn_no_fast_memory =
[shape_representation_fn](const TensorShape& shape, DataType dtype) {
return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
@@ -147,17 +160,37 @@
outputs->clear();
outputs->reserve(func_type.getNumResults());
+ resource_updates->reserve(func_type.getNumResults());
std::vector<xla::Shape> shapes;
shapes.reserve(func_type.getNumResults());
- for (mlir::Type type : func_type.getResults()) {
+ llvm::SmallDenseMap<unsigned, unsigned> resource_arg_to_write;
+ for (unsigned i = 0; i < main_func.getNumArguments(); ++i)
+ if (auto aliasing_output = main_func.getArgAttrOfType<mlir::IntegerAttr>(
+ i, "tf.aliasing_output"))
+ resource_arg_to_write.insert({aliasing_output.getInt(), i});
+
+ for (auto type_and_idx : llvm::enumerate(func_type.getResults())) {
TF_ASSIGN_OR_RETURN(
xla::Shape shape,
- xla::TypeToShape(type, shape_representation_fn_no_fast_memory));
- auto tensor_type = type.dyn_cast<mlir::RankedTensorType>();
+ xla::TypeToShape(type_and_idx.value(),
+ shape_representation_fn_no_fast_memory));
+ auto tensor_type = type_and_idx.value().dyn_cast<mlir::RankedTensorType>();
shapes.push_back(shape);
+ auto it = resource_arg_to_write.find(type_and_idx.index());
+ if (it != resource_arg_to_write.end()) {
+ // Add resource write.
+ resource_updates->emplace_back();
+ XlaResourceUpdate& resource_update = resource_updates->back();
+ resource_update.input_index = it->getSecond();
+ resource_update.modified = true;
+ TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &resource_update.type));
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &resource_update.shape));
+ continue;
+ }
+
// Construct OutputDescription for result.
outputs->emplace_back();
XlaOutputDescription& out_desc = outputs->back();
@@ -180,14 +213,6 @@
return Status::OK();
}
-// Gets information about how computation updates Tensorflow resources.
-// TODO(ycao): Implement logic to compute resource updates when we need to
-// support graphs with resource updates in MLIR-based TF compiler bridge.
-void GetResourceUpdatesForMlir(
- std::vector<XlaResourceUpdate>* resource_updates) {
- resource_updates->clear();
-}
-
// Creates a vector that maps from the parameters of the XLA computation to
// their original argument positions.
// MLIR-based TF-Compiler bridge doesn't have constant analysis yet, thus no
@@ -201,7 +226,7 @@
}
// Refine MLIR types based on new shape information.
-Status RefineShapes(llvm::ArrayRef<TensorShape> arg_shapes,
+Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
mlir::ModuleOp module) {
auto producer_or = GetTfGraphProducerVersion(module);
if (!producer_or.ok()) return producer_or.status();
@@ -212,15 +237,20 @@
{
// Convert arg_shapes to a mlir friendly format.
size_t count = 0;
- for (const TensorShape& shape : arg_shapes) {
- count += shape.dims();
+ for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
+ if (tensor_resource_shape.is_resource) continue;
+ count += tensor_resource_shape.shape.dims();
}
shape_backing.resize(count);
arg_shapes_copy.reserve(arg_shapes.size());
size_t offset = 0;
- for (const TensorShape& shape : arg_shapes) {
+ for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
+ if (tensor_resource_shape.is_resource) {
+ arg_shapes_copy.push_back(llvm::ArrayRef<int64_t>());
+ continue;
+ }
size_t start = offset;
- for (tensorflow::TensorShapeDim dim : shape) {
+ for (tensorflow::TensorShapeDim dim : tensor_resource_shape.shape) {
shape_backing[offset] = dim.size;
++offset;
}
@@ -338,7 +368,7 @@
}
static Status CompileMlirToXlaHlo(
- mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
+ mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
llvm::StringRef device_type, bool use_tuple_args,
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
@@ -372,14 +402,10 @@
shape_representation_fn,
&compilation_result->xla_input_shapes));
- // Compute all output descriptions.
- TF_RETURN_IF_ERROR(GetOutputInfo(module_op, shape_representation_fn,
- &compilation_result->xla_output_shape,
- &compilation_result->outputs));
-
- // Compute what resource variables need to be updated after XlaComputation's
- // execution.
- GetResourceUpdatesForMlir(&compilation_result->resource_updates);
+ // Compute all output descriptions and resource writes.
+ TF_RETURN_IF_ERROR(GetOutputInfo(
+ module_op, shape_representation_fn, &compilation_result->xla_output_shape,
+ &compilation_result->outputs, &compilation_result->resource_updates));
if (VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("mlir_compile_after", module_op);
@@ -399,26 +425,51 @@
TF_RETURN_IF_ERROR(
ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module));
- return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, device_type,
- use_tuple_args, shape_representation_fn,
- compilation_result,
+ llvm::SmallVector<TensorOrResourceShape, 4> tensor_or_resource_shapes;
+ tensor_or_resource_shapes.reserve(arg_shapes.size());
+ for (const auto& arg_shape : arg_shapes)
+ tensor_or_resource_shapes.push_back({arg_shape});
+ return CompileMlirToXlaHlo(mlir_module.get(), tensor_or_resource_shapes,
+ device_type, use_tuple_args,
+ shape_representation_fn, compilation_result,
std::move(custom_legalization_passes));
}
// Rewrites the given module with specified args. For each of the constant args,
// it gets inlined in the "main' function and the corresponding argument is
-// removed from the signature.
+// removed from the signature. For resource args, their subtypes are populated.
// Returns the original indices for the other arguments on success.
static StatusOr<std::vector<int>> RewriteWithArgs(
mlir::ModuleOp module, llvm::ArrayRef<const XlaArgument> args) {
mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
std::vector<int> params;
+ bool has_resource_args = false;
auto builder = mlir::OpBuilder(main_fn.getBody());
std::vector<int> args_to_erase;
for (int idx = 0; idx < args.size(); idx++) {
const XlaArgument& xla_arg = args[idx];
mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
+ if (xla_arg.kind == XlaArgument::kResource) {
+ mlir::Type element_type;
+ TF_RETURN_IF_ERROR(ConvertDataType(xla_arg.type, builder, &element_type));
+ auto resource_shape = absl::get<TensorShape>(xla_arg.shape).dim_sizes();
+ llvm::SmallVector<int64_t, 4> resource_subtype_shape(
+ resource_shape.begin(), resource_shape.end());
+ auto resource_subtype =
+ mlir::RankedTensorType::get(resource_subtype_shape, element_type);
+ auto resource_type =
+ mlir::TF::ResourceType::get({resource_subtype}, builder.getContext());
+
+ auto tensor_type = mlir_arg.getType().cast<mlir::TensorType>();
+ if (tensor_type.hasRank()) {
+ mlir_arg.setType(
+ mlir::RankedTensorType::get(tensor_type.getShape(), resource_type));
+ } else {
+ mlir_arg.setType(mlir::UnrankedTensorType::get(resource_type));
+ }
+ has_resource_args = true;
+ }
if (xla_arg.kind != XlaArgument::kConstant) {
params.push_back(idx);
continue;
@@ -433,7 +484,19 @@
args_to_erase.push_back(idx);
}
+ if (has_resource_args) {
+ llvm::SmallVector<mlir::Type, 4> updated_argument_types;
+ updated_argument_types.reserve(main_fn.getNumArguments());
+ for (mlir::BlockArgument& arg : main_fn.getArguments())
+ updated_argument_types.push_back(arg.getType());
+
+ main_fn.setType(mlir::FunctionType::get(updated_argument_types,
+ main_fn.getType().getResults(),
+ main_fn.getContext()));
+ }
+
for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);
+
return params;
}
@@ -456,10 +519,13 @@
mlir::ModuleOp module = module_or.ValueOrDie().get();
TF_ASSIGN_OR_RETURN(std::vector<int> remaining_params,
RewriteWithArgs(module, {args.data(), args.size()}));
- llvm::SmallVector<TensorShape, 4> arg_shapes;
- arg_shapes.reserve(args.size());
- for (unsigned idx : remaining_params)
- arg_shapes.push_back(absl::get<TensorShape>(args[idx].shape));
+ llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
+ arg_shapes.reserve(remaining_params.size());
+ for (unsigned idx : remaining_params) {
+ const auto& arg = args[idx];
+ arg_shapes.push_back({absl::get<TensorShape>(arg.shape),
+ /*is_resource=*/arg.kind == XlaArgument::kResource});
+ }
mlir::PassManager pm(&context);
mlir::TF::StandardPipelineOptions tf_options;
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
index 719a96f..5c64a65 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
@@ -73,6 +73,7 @@
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
// Same as the above but takes input as TensorFlow Graph.
+// TODO(lyandy): Allow populating of targets/control outputs.
Status CompileGraphToXlaHlo(
const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
llvm::StringRef device_type, bool use_tuple_args,
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
index dde2408..6ebf689 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
@@ -15,6 +15,9 @@
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -448,9 +451,6 @@
FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
Graph graph(OpRegistry::Global());
- Tensor dummy_tensor(DT_FLOAT, TensorShape({1}));
- test::FillValues<float>(&dummy_tensor, {-1.0});
-
Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT);
test::graph::Retval(&graph, 0, arg);
@@ -483,5 +483,60 @@
status_or_hlo_module.ValueOrDie()->ToString());
}
+// Tests a conversion from Graph to MLIR with resource arguments.
+TEST(CompileGraphToXlaHlo, Resources) {
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+ Graph graph(OpRegistry::Global());
+
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto val = ops::_Arg(scope.WithOpName("arg0"), DT_FLOAT, 0);
+ auto var = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
+ auto assign =
+ ops::AssignVariableOp(scope.WithOpName("assign_variable"), var, val);
+ TF_ASSERT_OK(scope.ToGraph(&graph));
+
+ XlaCompiler::CompilationResult result;
+ XlaCompiler::Argument arg0;
+ arg0.kind = XlaCompiler::Argument::kParameter;
+ arg0.shape = TensorShape({2});
+ XlaCompiler::Argument arg1;
+ arg1.kind = XlaCompiler::Argument::kResource;
+ arg1.shape = TensorShape({2});
+ arg1.type = DT_FLOAT;
+
+ TF_ASSERT_OK(
+ CompileGraphToXlaHlo(graph, /*args=*/{arg0, arg1}, "XLA_CPU_JIT",
+ /*use_tuple_args=*/false, flib_def, GraphDebugInfo(),
+ /*shape_representation_fn=*/nullptr, &result));
+
+ EXPECT_EQ(result.outputs.size(), 0);
+ ASSERT_EQ(result.resource_updates.size(), 1);
+ const auto& resource_update = result.resource_updates[0];
+ EXPECT_EQ(resource_update.input_index, 1);
+ EXPECT_EQ(resource_update.modified, true);
+ EXPECT_EQ(resource_update.shape, TensorShape({2}));
+ EXPECT_EQ(resource_update.type, DT_FLOAT);
+
+ const xla::HloModuleConfig module_config(
+ result.computation->GetProgramShape().ValueOrDie());
+ auto status_or_hlo_module = xla::HloModule::CreateFromProto(
+ result.computation->proto(), module_config);
+ ASSERT_TRUE(status_or_hlo_module.ok());
+
+ constexpr char expected_hlo_module_string[] =
+ R"(HloModule main.4, input_output_alias={ {0}: 1 }
+
+ENTRY %main.4 (Arg_0.1: f32[2], Arg_1.2: f32[2]) -> (f32[2]) {
+ %Arg_1.2 = f32[2]{0} parameter(1)
+ %Arg_0.1 = f32[2]{0} parameter(0)
+ ROOT %tuple.3 = (f32[2]{0}) tuple(f32[2]{0} %Arg_0.1)
+}
+
+)";
+
+ EXPECT_EQ(expected_hlo_module_string,
+ status_or_hlo_module.ValueOrDie()->ToString());
+}
+
} // namespace
} // namespace tensorflow