blob: 6ebf6897bb12decdf69a6767b8fd09497c3e77ae [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
namespace {
// Shape representation function used by the tests below. It converts the
// given TensorShape/DataType pair into an xla::Shape via
// TensorShapeToXLAShape and returns the result as-is; `use_fast_memory` is
// ignored. A failed conversion is returned to the caller as an error status.
xla::StatusOr<xla::Shape> TestShapeRepresentation(const TensorShape& shape,
                                                  DataType type,
                                                  bool use_fast_memory) {
  xla::Shape converted;
  Status conversion_status = TensorShapeToXLAShape(type, shape, &converted);
  if (!conversion_status.ok()) {
    return conversion_status;
  }
  return converted;
}
// Feeds text that is not a valid MLIR module to the compiler entry point and
// checks both the returned error code and the full status string.
TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
  constexpr char invalid_mlir_module[] =
      "totally @invalid MLIR module {here} <-";
  XlaCompiler::CompilationResult result;
  std::vector<TensorShape> no_arg_shapes;

  const Status status = CompileSerializedMlirToXlaHlo(
      invalid_mlir_module, no_arg_shapes, "XLA_CPU_JIT",
      /*use_tuple_args=*/true, TestShapeRepresentation, &result);

  EXPECT_EQ(status.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
  EXPECT_EQ(status.ToString(),
            "Invalid argument: could not parse MLIR module-:1:1: error: "
            "custom op 'totally' is unknown\n");
}
// Serialized MLIR module whose @main takes two scalar f32 tensors and returns
// their sum computed with "tf.AddV2". Shared by the TupleArgs and
// IndividualArgs tests below.
constexpr llvm::StringRef kBinaryAddModule = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
%0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
}
)";
// Compiles kBinaryAddModule with use_tuple_args=true and checks the emitted
// HLO text (exact string match) together with the CompilationResult metadata:
// input mapping, input/output shapes, output descriptions and resource
// updates.
TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
  std::vector<TensorShape> arg_shapes(2, TensorShape());
  XlaCompiler::CompilationResult compilation_result;

  Status s = CompileSerializedMlirToXlaHlo(
      kBinaryAddModule, arg_shapes, "XLA_CPU_JIT",
      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
  TF_ASSERT_OK(s);

  const xla::HloModuleConfig module_config(
      compilation_result.computation->GetProgramShape().ValueOrDie());
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      compilation_result.computation->proto(), module_config);
  TF_ASSERT_OK(status_or_hlo_module.status());

  constexpr char expected_hlo_module_string[] = R"(HloModule main.6
ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
%arg_tuple.1 = (f32[], f32[]) parameter(0)
%get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=0
%get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=1
%add.4 = f32[] add(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3)
ROOT %tuple.5 = (f32[]) tuple(f32[] %add.4)
}
)";
  EXPECT_EQ(expected_hlo_module_string,
            status_or_hlo_module.ValueOrDie()->ToString());

  // Expect an in order input mapping.
  EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));

  // Expect a single tuple-shape, containing two F32 scalars.
  // Fatal assertion: front() below must not run on an empty vector.
  ASSERT_EQ(compilation_result.xla_input_shapes.size(), 1);
  xla::Shape expected_input_shape =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}),
                                      xla::ShapeUtil::MakeShape(xla::F32, {})});
  EXPECT_EQ(compilation_result.xla_input_shapes.front(), expected_input_shape);

  // Expect output shape is a tuple shape containing a single F32 Scalar type.
  const xla::Shape output_shape =
      xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {});
  const xla::Shape tuple_output_shape =
      xla::ShapeUtil::MakeTupleShape({output_shape});
  EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);

  // Expect exactly 1 OutputDescription.
  // Fatal assertion: front() below must not run on an empty vector.
  ASSERT_EQ(compilation_result.outputs.size(), 1);
  const XlaCompiler::OutputDescription& output_desc =
      compilation_result.outputs.front();
  EXPECT_EQ(output_desc.type, DataType::DT_FLOAT);
  EXPECT_EQ(output_desc.shape, TensorShape());
  EXPECT_FALSE(output_desc.is_constant);
  EXPECT_FALSE(output_desc.is_tensor_list);

  // Expect no resource updates from computation.
  EXPECT_TRUE(compilation_result.resource_updates.empty());
}
// Compiles kBinaryAddModule with use_tuple_args=false, so each argument is
// its own HLO parameter, and checks the emitted HLO text (exact string match)
// together with the CompilationResult metadata.
TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
  std::vector<TensorShape> arg_shapes(2, TensorShape());
  XlaCompiler::CompilationResult compilation_result;

  Status s = CompileSerializedMlirToXlaHlo(
      kBinaryAddModule, arg_shapes, "XLA_CPU_JIT",
      /*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result);
  TF_ASSERT_OK(s);

  const xla::HloModuleConfig module_config(
      compilation_result.computation->GetProgramShape().ValueOrDie());
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      compilation_result.computation->proto(), module_config);
  TF_ASSERT_OK(status_or_hlo_module.status());

  constexpr char expected_hlo_module_string[] = R"(HloModule main.5
ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
%add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3)
}
)";
  EXPECT_EQ(expected_hlo_module_string,
            status_or_hlo_module.ValueOrDie()->ToString());

  // Expect an in order input mapping.
  EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));

  // Expect two inputs, each containing a F32 scalar.
  // Fatal assertion: the indexed accesses below require both elements.
  ASSERT_EQ(compilation_result.xla_input_shapes.size(), 2);
  xla::Shape expected_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
  EXPECT_EQ(compilation_result.xla_input_shapes[0], expected_input_shape);
  EXPECT_EQ(compilation_result.xla_input_shapes[1], expected_input_shape);

  // Expect output shape is a tuple shape containing a single F32 Scalar type.
  const xla::Shape output_shape =
      xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {});
  const xla::Shape tuple_output_shape =
      xla::ShapeUtil::MakeTupleShape({output_shape});
  EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);

  // Expect exactly 1 OutputDescription.
  // Fatal assertion: front() below must not run on an empty vector.
  ASSERT_EQ(compilation_result.outputs.size(), 1);
  const XlaCompiler::OutputDescription& output_desc =
      compilation_result.outputs.front();
  EXPECT_EQ(output_desc.type, DataType::DT_FLOAT);
  EXPECT_EQ(output_desc.shape, TensorShape());
  EXPECT_FALSE(output_desc.is_constant);
  EXPECT_FALSE(output_desc.is_tensor_list);

  // Expect no resource updates from computation.
  EXPECT_TRUE(compilation_result.resource_updates.empty());
}
// Checks that foldable ops get constant-folded, which in turn allows ops that
// need a compile-time constant operand to be legalized.
TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
  // tf.Reshape can be lowered only once tf.Shape has been folded into a
  // constant, and "tf.Shape" can be folded only after shape inference.
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor<10x19xf32> {
%0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64>
%1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<10x19xf32>
return %1 : tensor<10x19xf32>
}
}
)";
  constexpr char expected_hlo_module_string[] = R"(HloModule main.6
ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
%arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true}
%get-tuple-element.2 = f32[10,19]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=0
%get-tuple-element.3 = f32[19,10]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=1
%reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3)
ROOT %tuple.5 = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} %reshape.4)
}
)";

  std::vector<TensorShape> input_shapes = {TensorShape({10, 19}),
                                           TensorShape({19, 10})};
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(CompileSerializedMlirToXlaHlo(
      mlir_module, input_shapes, "XLA_CPU_JIT",
      /*use_tuple_args=*/true, TestShapeRepresentation, &result));

  // Rebuild an HloModule from the compiled computation so its text form can
  // be compared against the expected string.
  const auto& computation = *result.computation;
  const xla::HloModuleConfig config(
      computation.GetProgramShape().ValueOrDie());
  auto hlo_module_or =
      xla::HloModule::CreateFromProto(computation.proto(), config);
  TF_ASSERT_OK(hlo_module_or.status());

  const auto hlo_text = hlo_module_or.ValueOrDie()->ToString();
  EXPECT_EQ(expected_hlo_module_string, hlo_text);
}
// Compiles a tf.MatMul whose operand/result types are unranked or partially
// dynamic and checks that the HLO entry signature carries the static shapes
// derived from the supplied argument shapes.
TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<*xf32>, %arg1: tensor<?x19xf32>) -> tensor<?x19xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<?x19xf32>) -> tensor<?x19xf32>
return %0 : tensor<?x19xf32>
}
}
)";
  constexpr char expected_signature[] =
      R"((arg_tuple.1: (f32[10,17], f32[17,19])) -> (f32[10,19]))";

  std::vector<TensorShape> input_shapes = {TensorShape({10, 17}),
                                           TensorShape({17, 19})};
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(CompileSerializedMlirToXlaHlo(
      mlir_module, input_shapes, "XLA_CPU_JIT",
      /*use_tuple_args=*/true, TestShapeRepresentation, &result));

  const auto& computation = *result.computation;
  const xla::HloModuleConfig config(
      computation.GetProgramShape().ValueOrDie());
  auto hlo_module_or =
      xla::HloModule::CreateFromProto(computation.proto(), config);
  TF_ASSERT_OK(hlo_module_or.status());

  // Only the signature is checked, not the whole module text.
  const auto hlo_text = hlo_module_or.ValueOrDie()->ToString();
  EXPECT_THAT(hlo_text, ::testing::HasSubstr(expected_signature));
}
// Compiles a tf.FusedBatchNormV3 whose last result is declared unranked
// (tensor<*xf32>) and checks that the HLO result signature lists a concrete
// shape for every output.
TEST(CompileSerializedMlirToXlaHloTest, ShapeInferenceAfterLegalization) {
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<8x16x16x64xbf16>, %arg1: tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>) {
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg1, %arg1, %arg1) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> (tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>)
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x16x16x64xbf16>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<*xf32>
}
}
)";
  constexpr char expected_signature[] =
      R"(-> (bf16[8,16,16,64], f32[64], f32[64], f32[64], f32[64], f32[0]))";

  std::vector<TensorShape> input_shapes = {TensorShape({8, 16, 16, 64}),
                                           TensorShape({64})};
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(CompileSerializedMlirToXlaHlo(
      mlir_module, input_shapes, "XLA_CPU_JIT",
      /*use_tuple_args=*/true, TestShapeRepresentation, &result));

  const auto& computation = *result.computation;
  const xla::HloModuleConfig config(
      computation.GetProgramShape().ValueOrDie());
  auto hlo_module_or =
      xla::HloModule::CreateFromProto(computation.proto(), config);
  TF_ASSERT_OK(hlo_module_or.status());

  // Only the result signature is checked, not the whole module text.
  const auto hlo_text = hlo_module_or.ValueOrDie()->ToString();
  EXPECT_THAT(hlo_text, ::testing::HasSubstr(expected_signature));
}
// Compiles a module in which tf.BroadcastGradientArgs is applied to an empty
// constant and checks that the resulting HLO contains only a constant feeding
// both tuple outputs.
TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main() -> (tensor<0xi32>, tensor<0xi32>) {
%0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
%r0, %r1 = "tf.BroadcastGradientArgs"(%0, %0) {T = i32} : (tensor<0xi32>, tensor<0xi32>) -> (tensor<0xi32>, tensor<0xi32>)
return %r0, %r1 : tensor<0xi32>, tensor<0xi32>
}
}
)";
  constexpr char expected_hlo_module_string[] = R"(HloModule main.4
ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
%arg_tuple.1 = () parameter(0)
%constant.2 = s32[0]{0} constant({})
ROOT %tuple.3 = (s32[0]{0}, s32[0]{0}) tuple(s32[0]{0} %constant.2, s32[0]{0} %constant.2)
}
)";

  // NOTE(review): two scalar shapes are passed although @main declares no
  // arguments -- confirm this is intentional.
  std::vector<TensorShape> input_shapes(2, TensorShape());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(CompileSerializedMlirToXlaHlo(
      mlir_module, input_shapes, "XLA_CPU_JIT",
      /*use_tuple_args=*/true, TestShapeRepresentation, &result));

  const auto& computation = *result.computation;
  const xla::HloModuleConfig config(
      computation.GetProgramShape().ValueOrDie());
  auto hlo_module_or =
      xla::HloModule::CreateFromProto(computation.proto(), config);
  TF_ASSERT_OK(hlo_module_or.status());

  const auto hlo_text = hlo_module_or.ValueOrDie()->ToString();
  EXPECT_EQ(expected_hlo_module_string, hlo_text);
}
// The following xla::OpSharding protos are used:
// Serialized string:
// "\08\03\1A\02\01\02\22\02\00\01"
// Proto debug string:
// type: OTHER
// tile_assignment_dimensions: 1
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
//
// Serialized string:
// "\08\01\1A\01\01\22\01\00"
// Proto debug string:
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
//
// Serialized string:
// ""
// Proto debug string (empty, but would be equivalent to):
// type: REPLICATED
// Compiles a module whose three arguments carry serialized mhlo.sharding
// attributes and checks that the tuple parameter in the emitted HLO is
// annotated with the corresponding per-element shardings.
TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) {
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) {
return
}
}
)";
  constexpr char expected_hlo_module_string[] = R"(HloModule main.6
ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () {
%arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}}
%get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0
%get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1
%get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2
ROOT %tuple.5 = () tuple()
}
)";

  std::vector<TensorShape> input_shapes = {TensorShape({128, 10}),
                                           TensorShape({10, 1024}),
                                           TensorShape({128, 1024})};
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(CompileSerializedMlirToXlaHlo(
      mlir_module, input_shapes, "XLA_CPU_JIT",
      /*use_tuple_args=*/true, TestShapeRepresentation, &result));

  const auto& computation = *result.computation;
  const xla::HloModuleConfig config(
      computation.GetProgramShape().ValueOrDie());
  auto hlo_module_or =
      xla::HloModule::CreateFromProto(computation.proto(), config);
  TF_ASSERT_OK(hlo_module_or.status());

  const auto hlo_text = hlo_module_or.ValueOrDie()->ToString();
  EXPECT_EQ(expected_hlo_module_string, hlo_text);
}
// Passes an argument whose mhlo.sharding attribute is not a parseable
// sharding and checks that compilation fails with the expected message.
TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) {
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "bad_sharding"}) {
return
}
}
)";

  std::vector<TensorShape> input_shapes = {TensorShape({128, 10})};
  XlaCompiler::CompilationResult result;
  const Status status = CompileSerializedMlirToXlaHlo(
      mlir_module, input_shapes, "XLA_CPU_JIT",
      /*use_tuple_args=*/true, TestShapeRepresentation, &result);

  // The message is only meaningful when the call actually failed.
  ASSERT_FALSE(status.ok());
  EXPECT_EQ(status.error_message(),
            "failed to parse argument sharding 0 'bad_sharding'");
}
// Compiles a module whose three results carry serialized mhlo.sharding
// attributes and checks that the ROOT tuple in the emitted HLO is annotated
// with the corresponding per-element shardings.
TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) {
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} {
func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {mhlo.sharding = ""}) {
return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32>
}
}
)";
  constexpr char expected_hlo_module_string[] = R"(HloModule main.9
ENTRY %main.9 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> (f32[128,10], f32[10,1024], f32[128,1024]) {
%arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0)
%get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0
%reshape.5 = f32[128,10]{1,0} reshape(f32[128,10]{1,0} %get-tuple-element.2)
%get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1
%reshape.6 = f32[10,1024]{1,0} reshape(f32[10,1024]{1,0} %get-tuple-element.3)
%get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2
%reshape.7 = f32[128,1024]{1,0} reshape(f32[128,1024]{1,0} %get-tuple-element.4)
ROOT %tuple.8 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) tuple(f32[128,10]{1,0} %reshape.5, f32[10,1024]{1,0} %reshape.6, f32[128,1024]{1,0} %reshape.7), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}}
}
)";

  std::vector<TensorShape> input_shapes = {TensorShape({128, 10}),
                                           TensorShape({10, 1024}),
                                           TensorShape({128, 1024})};
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(CompileSerializedMlirToXlaHlo(
      mlir_module, input_shapes, "XLA_CPU_JIT",
      /*use_tuple_args=*/true, TestShapeRepresentation, &result));

  const auto& computation = *result.computation;
  const xla::HloModuleConfig config(
      computation.GetProgramShape().ValueOrDie());
  auto hlo_module_or =
      xla::HloModule::CreateFromProto(computation.proto(), config);
  TF_ASSERT_OK(hlo_module_or.status());

  const auto hlo_text = hlo_module_or.ValueOrDie()->ToString();
  EXPECT_EQ(expected_hlo_module_string, hlo_text);
}
// Verifies that a Graph which simply forwards its single argument to its
// return value compiles successfully when no shape representation function is
// supplied (nullptr), and that the emitted HLO text matches exactly.
TEST(CompileGraphToXlaHlo, Basic) {
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
  Graph graph(OpRegistry::Global());

  // Graph: _Arg(0) -> _Retval(0).
  Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT);
  test::graph::Retval(&graph, 0, arg);

  XlaCompiler::CompilationResult result;
  XlaCompiler::Argument compiler_arg;
  compiler_arg.kind = XlaCompiler::Argument::kParameter;
  compiler_arg.shape = TensorShape();

  TF_ASSERT_OK(
      CompileGraphToXlaHlo(graph, /*args=*/{compiler_arg}, "XLA_CPU_JIT",
                           /*use_tuple_args=*/false, flib_def, GraphDebugInfo(),
                           /*shape_representation_fn=*/nullptr, &result));

  const xla::HloModuleConfig module_config(
      result.computation->GetProgramShape().ValueOrDie());
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      result.computation->proto(), module_config);
  // TF_ASSERT_OK reports the failing status, unlike ASSERT_TRUE(...ok()).
  TF_ASSERT_OK(status_or_hlo_module.status());

  constexpr char expected_hlo_module_string[] = R"(HloModule main.3
ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) {
%Arg_0.1 = f32[] parameter(0)
ROOT %tuple.2 = (f32[]) tuple(f32[] %Arg_0.1)
}
)";
  EXPECT_EQ(expected_hlo_module_string,
            status_or_hlo_module.ValueOrDie()->ToString());
}
// Tests compiling a Graph that has a resource argument: the graph assigns the
// float argument (arg0) to the resource variable (arg1). Checks the reported
// resource update and the exact emitted HLO text.
TEST(CompileGraphToXlaHlo, Resources) {
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
  Graph graph(OpRegistry::Global());

  // Build the graph: AssignVariableOp(arg1 /*resource*/, arg0 /*value*/).
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto val = ops::_Arg(scope.WithOpName("arg0"), DT_FLOAT, 0);
  auto var = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
  auto assign =
      ops::AssignVariableOp(scope.WithOpName("assign_variable"), var, val);
  TF_ASSERT_OK(scope.ToGraph(&graph));

  XlaCompiler::CompilationResult result;
  XlaCompiler::Argument arg0;
  arg0.kind = XlaCompiler::Argument::kParameter;
  arg0.shape = TensorShape({2});
  XlaCompiler::Argument arg1;
  arg1.kind = XlaCompiler::Argument::kResource;
  arg1.shape = TensorShape({2});
  arg1.type = DT_FLOAT;

  TF_ASSERT_OK(
      CompileGraphToXlaHlo(graph, /*args=*/{arg0, arg1}, "XLA_CPU_JIT",
                           /*use_tuple_args=*/false, flib_def, GraphDebugInfo(),
                           /*shape_representation_fn=*/nullptr, &result));

  // No regular outputs; exactly one resource update for argument 1.
  EXPECT_TRUE(result.outputs.empty());
  // Fatal assertion: resource_updates[0] is read right below.
  ASSERT_EQ(result.resource_updates.size(), 1);
  const auto& resource_update = result.resource_updates[0];
  EXPECT_EQ(resource_update.input_index, 1);
  EXPECT_TRUE(resource_update.modified);
  EXPECT_EQ(resource_update.shape, TensorShape({2}));
  EXPECT_EQ(resource_update.type, DT_FLOAT);

  const xla::HloModuleConfig module_config(
      result.computation->GetProgramShape().ValueOrDie());
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      result.computation->proto(), module_config);
  // TF_ASSERT_OK reports the failing status, unlike ASSERT_TRUE(...ok()).
  TF_ASSERT_OK(status_or_hlo_module.status());

  constexpr char expected_hlo_module_string[] =
      R"(HloModule main.4, input_output_alias={ {0}: 1 }
ENTRY %main.4 (Arg_0.1: f32[2], Arg_1.2: f32[2]) -> (f32[2]) {
%Arg_1.2 = f32[2]{0} parameter(1)
%Arg_0.1 = f32[2]{0} parameter(0)
ROOT %tuple.3 = (f32[2]{0}) tuple(f32[2]{0} %Arg_0.1)
}
)";
  EXPECT_EQ(expected_hlo_module_string,
            status_or_hlo_module.ValueOrDie()->ToString());
}
} // namespace
} // namespace tensorflow