blob: f0874bcd8d361da7ac7ecf29db91221063b84c95 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
namespace m = xla::match;
using ::testing::ElementsAre;
class LayoutAssignmentTest : public HloTestBase {
protected:
void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
ChannelLayoutConstraints* channel_constraints = nullptr) {
LayoutAssignment layout_assignment(
entry_computation_layout,
/*channel_constraints=*/channel_constraints);
EXPECT_IS_OK(layout_assignment.Run(m).status());
}
std::vector<int64> LayoutOf(HloModule* m, absl::string_view name) {
auto minor_to_major =
FindInstruction(m, name)->shape().layout().minor_to_major();
return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
}
void ExpectLayoutIs(const Shape& shape,
absl::Span<const int64> minor_to_major) {
const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
<< "Expected layout " << expected << ", actual " << shape.layout();
}
void ExpectTupleLayoutIs(
const Shape& shape,
std::initializer_list<absl::Span<const int64>> minor_to_majors) {
int i = 0;
for (const absl::Span<const int64> minor_to_major : minor_to_majors) {
const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
<< "Expected tuple element " << i << " layout " << expected
<< ", actual " << actual;
++i;
}
}
};
TEST_F(LayoutAssignmentTest, ComputationLayout) {
  // For each of two rank-2 layouts, constrain both parameters and the result
  // of an add computation to that layout, then check that the parameter and
  // add instructions carry it after layout assignment.
  const std::vector<std::vector<int64>> candidate_orders = {{0, 1}, {1, 0}};
  for (const std::vector<int64>& order : candidate_orders) {
    const Shape array_shape = ShapeUtil::MakeShape(F32, {42, 12});

    auto builder = HloComputation::Builder(TestName());
    auto* lhs = builder.AddInstruction(
        HloInstruction::CreateParameter(0, array_shape, "param0"));
    auto* rhs = builder.AddInstruction(
        HloInstruction::CreateParameter(1, array_shape, "param1"));
    auto* sum = builder.AddInstruction(
        HloInstruction::CreateBinary(array_shape, HloOpcode::kAdd, lhs, rhs));

    auto module = CreateNewVerifiedModule();
    HloComputation* entry = module->AddEntryComputation(builder.Build());

    // One shape-with-layout serves as the constraint for every entry of the
    // computation layout.
    const Layout wanted_layout = LayoutUtil::MakeLayout(order);
    Shape shape_with_layout(array_shape);
    *shape_with_layout.mutable_layout() = wanted_layout;
    const ShapeLayout constraint(shape_with_layout);

    ComputationLayout computation_layout(entry->ComputeProgramShape());
    *computation_layout.mutable_parameter_layout(0) = constraint;
    *computation_layout.mutable_parameter_layout(1) = constraint;
    *computation_layout.mutable_result_layout() = constraint;

    AssignLayouts(module.get(), &computation_layout);

    EXPECT_TRUE(LayoutUtil::Equal(wanted_layout, lhs->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(wanted_layout, rhs->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(wanted_layout, sum->shape().layout()));
  }
}
TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
  // Parameter 0 and the result are constrained to minor-to-major {1, 0} while
  // parameter 1 is constrained to {0, 1}; the parameter and root instructions
  // must end up with exactly those layouts.
  const Shape array_shape = ShapeUtil::MakeShape(F32, {42, 12});

  auto builder = HloComputation::Builder(TestName());
  auto* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, array_shape, "param0"));
  auto* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, array_shape, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(array_shape, HloOpcode::kAdd, lhs, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Constraint with minor-to-major order {1, 0}.
  const Layout layout_10 = LayoutUtil::MakeLayout({1, 0});
  Shape shape_10(array_shape);
  *shape_10.mutable_layout() = layout_10;
  const ShapeLayout constraint_10(shape_10);

  // Constraint with minor-to-major order {0, 1}.
  const Layout layout_01 = LayoutUtil::MakeLayout({0, 1});
  Shape shape_01(array_shape);
  *shape_01.mutable_layout() = layout_01;
  const ShapeLayout constraint_01(shape_01);

  ComputationLayout computation_layout(entry->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) = constraint_10;
  *computation_layout.mutable_parameter_layout(1) = constraint_01;
  *computation_layout.mutable_result_layout() = constraint_10;

  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::Equal(layout_10, lhs->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(layout_01, rhs->shape().layout()));
  const HloInstruction* root = entry->root_instruction();
  EXPECT_TRUE(LayoutUtil::Equal(layout_10, root->shape().layout()));
}
TEST_F(LayoutAssignmentTest, FusionInstruction) {
  // For each of two layouts: fuse add -> negate -> negate over two constants,
  // constrain the result layout, and check that the fused parameters and the
  // fused root carry that layout while the instruction feeding the fused root
  // has no layout at all.
  const std::vector<std::vector<int64>> candidate_orders = {{0, 1}, {1, 0}};
  for (const std::vector<int64>& order : candidate_orders) {
    auto builder = HloComputation::Builder(TestName());

    auto literal_a = LiteralUtil::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(order));
    auto literal_b = LiteralUtil::CreateR2WithLayout<float>(
        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(order));
    // Capture the shape before the literal is moved into the constant.
    Shape array_shape = literal_a.shape();

    auto* const_a = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(literal_a)));
    auto* const_b = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(literal_b)));
    auto* sum = builder.AddInstruction(HloInstruction::CreateBinary(
        array_shape, HloOpcode::kAdd, const_a, const_b));
    auto* inner_negate = builder.AddInstruction(
        HloInstruction::CreateUnary(array_shape, HloOpcode::kNegate, sum));
    auto* outer_negate = builder.AddInstruction(HloInstruction::CreateUnary(
        array_shape, HloOpcode::kNegate, inner_negate));

    auto module = CreateNewVerifiedModule();
    HloComputation* entry = module->AddEntryComputation(builder.Build());
    auto* fusion = entry->CreateFusionInstruction(
        {outer_negate, inner_negate, sum}, HloInstruction::FusionKind::kLoop);

    const Layout wanted_layout = LayoutUtil::MakeLayout(order);
    Shape shape_with_layout(array_shape);
    *shape_with_layout.mutable_layout() = wanted_layout;
    const ShapeLayout result_constraint(shape_with_layout);

    ComputationLayout computation_layout(entry->ComputeProgramShape());
    *computation_layout.mutable_result_layout() = result_constraint;

    AssignLayouts(module.get(), &computation_layout);

    EXPECT_TRUE(LayoutUtil::Equal(
        wanted_layout, fusion->fused_parameter(0)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        wanted_layout, fusion->fused_parameter(1)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        wanted_layout, fusion->fused_expression_root()->shape().layout()));

    // The operand of the fused root is an interior node and is expected to
    // have no layout.
    EXPECT_FALSE(LayoutUtil::HasLayout(
        fusion->fused_expression_root()->operand(0)->shape()));
  }
}
TEST_F(LayoutAssignmentTest, TupleLayout) {
  // Two constants created with different layouts are packed into a tuple.
  // After layout assignment the tuple must have a layout and its element
  // layouts must agree with the constants they come from.
  auto builder = HloComputation::Builder(TestName());

  auto literal_01 = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}));
  auto* first = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(literal_01)));

  auto literal_10 = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}));
  auto* second = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(literal_10)));

  auto* packed =
      builder.AddInstruction(HloInstruction::CreateTuple({first, second}));

  // The root is an array (negate of tuple element 0) so that no tuple layout
  // has to be constructed for the ComputationLayout below.
  auto* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(first->shape(), packed, 0));
  auto* root_negate = builder.AddInstruction(HloInstruction::CreateUnary(
      first->shape(), HloOpcode::kNegate, element0));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      module->entry_computation()->ComputeProgramShape());

  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(first->shape(), second->shape()));

  EXPECT_TRUE(LayoutUtil::HasLayout(packed->shape()));

  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      root_negate->shape(), computation_layout.result_layout().shape()));

  const Shape& packed_element1 =
      ShapeUtil::GetTupleElementShape(packed->shape(), 1);
  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(packed_element1, second->shape()));
}
TEST_F(LayoutAssignmentTest, TupleSelect) {
  // A kTupleSelect between two tuples built from the same pair of constants;
  // the select's layout must match the result layout requested below.
  auto builder = HloComputation::Builder(TestName());

  auto literal_01 = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}));
  auto* first = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(literal_01)));

  auto literal_10 = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}));
  auto* second = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(literal_10)));

  auto* on_true =
      builder.AddInstruction(HloInstruction::CreateTuple({first, second}));
  auto* on_false =
      builder.AddInstruction(HloInstruction::CreateTuple({first, second}));

  auto* condition = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));

  auto* tuple_select = builder.AddInstruction(HloInstruction::CreateTernary(
      on_true->shape(), HloOpcode::kTupleSelect, condition, on_true,
      on_false));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      module->entry_computation()->ComputeProgramShape());

  // Request a result whose element layouts are those of the two constants.
  Shape wanted_result =
      ShapeUtil::MakeTupleShape({first->shape(), second->shape()});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      wanted_result));

  AssignLayouts(module.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(wanted_result, tuple_select->shape()));
}
TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
  // Construct following computation which has conflicting layouts for two
  // elements of a tuple which share the same source logical buffer:
  //
  // %constant = Constant(...)
  // %inner_tuple = Tuple(%constant)
  // %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
  //
  // The result layout requests minor-to-major {1,0} for the array in the first
  // element and {0,1} for the array in the second. This results in the
  // conflict where the element of the inner_tuple needs both layouts at once.
  // This is resolved by deep-copying the tuple and assigning the layouts of
  // the copied arrays as needed.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));
  auto nested_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, inner_tuple}));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape = nested_tuple->shape();
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  // AssignLayouts constructs and runs its own LayoutAssignment pass, so no
  // separate pass object is needed here.
  AssignLayouts(m.get(), &computation_layout);

  // Layout assignment should have deep copied the result of the computation to
  // address the layout conflict. This results in several Tuple() and
  // GetTupleElement() instructions. Running algebraic simplification should
  // clean up the code to something like:
  //
  // %constant = Constant(...) layout={1,0}
  // %tuple.0 = Tuple(%constant) layout=({1,0})
  // %copy = Copy(%constant) layout={0,1} # layout transposed
  // %tuple.1 = Tuple(%copy) layout=({0,1})
  // %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
  //
  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
  HloInstruction* root = m->entry_computation()->root_instruction();

  // Verify layout of the root and the root's operands.
  EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
                               root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
                               root->operand(1)->shape()));

  // Verify the structure of the HLO graph.
  EXPECT_THAT(root,
              GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
                                  m::Tuple(m::Copy(m::Op().Is(constant))))));
}
TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
  // Graph under test: param -> log -> reshape -> tanh
  const Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {3, 1, 2});

  auto builder = HloComputation::Builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  auto* log_instr = builder.AddInstruction(
      HloInstruction::CreateUnary(input_shape, HloOpcode::kLog, input));
  auto* reshape_instr = builder.AddInstruction(
      HloInstruction::CreateReshape(output_shape, log_instr));
  auto* tanh_instr = builder.AddInstruction(HloInstruction::CreateUnary(
      output_shape, HloOpcode::kTanh, reshape_instr));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(builder.Build(tanh_instr));

  // Fix the layouts of the parameter and of the result.
  Shape input_with_layout(input_shape);
  *input_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3});
  Shape output_with_layout(output_shape);
  *output_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(entry->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(input_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(output_with_layout);

  AssignLayouts(module.get(), &computation_layout);

  // In log's minor-to-major list, dimension 1 must appear after dimension 2.
  auto log_order = AsInt64Slice(log_instr->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(log_order, 1),
            PositionInContainer(log_order, 2));

  // In reshape's minor-to-major list, dimension 0 must appear after
  // dimension 2.
  auto reshape_order =
      AsInt64Slice(reshape_instr->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(reshape_order, 0),
            PositionInContainer(reshape_order, 2));
}
// Checks that LayoutAssignment gives elementwise operations the layout that
// keeps linear indices valid across them, and gives a transpose the layout
// that turns it into a bitcast.
TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
  // Graph under test: param -> log -> transpose -> tanh
  const Shape input_shape = ShapeUtil::MakeShape(F32, {42, 12});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {12, 42});

  auto builder = HloComputation::Builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  auto* log_instr = builder.AddInstruction(
      HloInstruction::CreateUnary(input_shape, HloOpcode::kLog, input));
  auto* transpose_instr = builder.AddInstruction(
      HloInstruction::CreateTranspose(output_shape, log_instr, {1, 0}));
  auto* tanh_instr = builder.AddInstruction(HloInstruction::CreateUnary(
      output_shape, HloOpcode::kTanh, transpose_instr));

  auto module = CreateNewVerifiedModule();
  auto entry = module->AddEntryComputation(builder.Build(tanh_instr));

  Shape input_with_layout(input_shape);
  *input_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  Shape output_with_layout(output_shape);
  *output_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  ComputationLayout computation_layout(entry->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(input_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(output_with_layout);

  AssignLayouts(module.get(), &computation_layout);

  // log follows the parameter layout; transpose and tanh follow the result
  // layout.
  const Layout& input_layout = input_with_layout.layout();
  const Layout& output_layout = output_with_layout.layout();
  EXPECT_TRUE(LayoutUtil::Equal(input_layout, log_instr->shape().layout()));
  EXPECT_TRUE(
      LayoutUtil::Equal(output_layout, transpose_instr->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(output_layout, tanh_instr->shape().layout()));
}
// Checks that LayoutAssignment picks the operand layout that makes a
// transpose a bitcast.
TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
  // Graph under test: param -> broadcast -> transpose
  const Shape param_shape = ShapeUtil::MakeShape(F32, {3, 4});
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape transpose_shape = ShapeUtil::MakeShape(F32, {4, 3, 2});

  auto builder = HloComputation::Builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  auto* broadcast_instr = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, input, {1, 2}));
  auto* transpose_instr = builder.AddInstruction(
      HloInstruction::CreateTranspose(transpose_shape, broadcast_instr,
                                      {2, 1, 0}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(builder.Build(transpose_instr));

  Shape param_with_layout(param_shape);
  *param_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  Shape result_with_layout(transpose_shape);
  *result_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(entry->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(result_with_layout);

  AssignLayouts(module.get(), &computation_layout);

  // The broadcast must get the reverse of the result's minor-to-major order.
  EXPECT_THAT(broadcast_instr->shape().layout().minor_to_major(),
              ElementsAre(0, 1, 2));
}
TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
  // param[4] -> broadcast[3x4] -----> transpose[4x3] ----------------> tuple
  //                  \                                                 /
  //                   \-> tanh[3x4] -> broadcast2[2x3x4] -------------/
  //
  // The layout of `transpose` is set to {1,0} because it provides a buffer to
  // the computation result, which has a fixed layout. Therefore `broadcast`
  // (the operand of transpose) is expected to have layout {0,1} so that the
  // transpose is a bitcast. Furthermore, `tanh` is expected to have the same
  // layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
  const Shape shape_4 = ShapeUtil::MakeShape(F32, {4});
  const Shape shape_3x4 = ShapeUtil::MakeShape(F32, {3, 4});
  const Shape shape_4x3 = ShapeUtil::MakeShape(F32, {4, 3});
  const Shape shape_2x3x4 = ShapeUtil::MakeShape(F32, {2, 3, 4});

  auto builder = HloComputation::Builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape_4, "param"));
  auto* first_broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(shape_3x4, input, {1}));
  auto* transpose_instr = builder.AddInstruction(
      HloInstruction::CreateTranspose(shape_4x3, first_broadcast, {1, 0}));
  auto* tanh_instr = builder.AddInstruction(HloInstruction::CreateUnary(
      shape_3x4, HloOpcode::kTanh, first_broadcast));
  auto* second_broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(shape_2x3x4, tanh_instr, {1, 2}));
  auto* root_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({transpose_instr, second_broadcast}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(builder.Build(root_tuple));
  ComputationLayout computation_layout(entry->ComputeProgramShape());

  // Constrain the parameter and both elements of the result tuple.
  Shape param_with_layout(shape_4);
  *param_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
  Shape transpose_with_layout(shape_4x3);
  *transpose_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  Shape broadcast2_with_layout(shape_2x3x4);
  *broadcast2_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {transpose_with_layout, broadcast2_with_layout}));

  AssignLayouts(module.get(), &computation_layout);

  EXPECT_THAT(first_broadcast->shape().layout().minor_to_major(),
              ElementsAre(0, 1));
  EXPECT_THAT(transpose_instr->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(tanh_instr->shape().layout().minor_to_major(),
              ElementsAre(0, 1));
}
// LayoutAssignment variant used by the test below: whenever a buffer layout
// constraint is propagated, every operand of the defining instruction whose
// rank equals the instruction's rank is forced to the buffer's layout.
class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
 public:
  explicit OperandsMustBeTheSameLayoutAssignment(
      ComputationLayout* entry_computation_layout)
      : LayoutAssignment(entry_computation_layout) {}

 protected:
  // Adds a mandatory operand-layout constraint for each same-rank operand of
  // the instruction that defines the constrained buffer, then forwards the
  // constraint to the buffer's uses. Returns the first error encountered.
  Status PropagateBufferConstraint(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints) override {
    const HloInstruction* producer = buffer_constraint.buffer().instruction();
    for (int64_t index = 0; index < producer->operand_count(); ++index) {
      const bool same_rank =
          producer->shape().rank() == producer->operand(index)->shape().rank();
      // Operands of a different rank cannot take the output layout; leave
      // them unconstrained.
      if (same_rank) {
        TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
            buffer_constraint.layout(), producer, index,
            /*mandatory=*/true));
      }
    }
    return PropagateBufferConstraintToUses(buffer_constraint, constraints);
  }
};
TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
  // Graph under test:
  //   param0 -> concatenate -> reshape
  //   param1 -^
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {50, 1});
  const Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 2});
  const Shape flat_shape = ShapeUtil::MakeShape(F32, {100});

  auto builder = HloComputation::Builder(TestName());
  auto* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param"));
  auto* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, operand_shape, "param"));
  auto* concat = builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, {lhs, rhs}, 1));
  auto* flattened = builder.AddInstruction(
      HloInstruction::CreateReshape(flat_shape, concat));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(builder.Build(flattened));

  // The two parameters are deliberately given different layouts.
  Shape lhs_with_layout(operand_shape);
  *lhs_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  Shape rhs_with_layout(operand_shape);
  *rhs_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  ComputationLayout computation_layout(entry->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(lhs_with_layout);
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(rhs_with_layout);

  OperandsMustBeTheSameLayoutAssignment pass(&computation_layout);
  EXPECT_IS_OK(pass.Run(module.get()).status());

  // Both operands of the concatenate, and the concatenate itself, must share
  // one minor-to-major order.
  EXPECT_EQ(concat->operand(0)->shape().layout().minor_to_major(),
            concat->operand(1)->shape().layout().minor_to_major());
  EXPECT_EQ(concat->shape().layout().minor_to_major(),
            concat->operand(1)->shape().layout().minor_to_major());
}
// A transpose whose operand is a parameter with a fixed layout must be
// assigned an output layout that makes the transpose a bitcast.
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
  const Shape param_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});

  auto builder = HloComputation::Builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  auto* transpose_instr =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), input, {2, 3, 0, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(builder.Build(transpose_instr));
  ComputationLayout computation_layout(entry->ComputeProgramShape());

  AssignLayouts(module.get(), &computation_layout);

  const Shape& operand_shape = transpose_instr->operand(0)->shape();
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(
      operand_shape, transpose_instr->shape(), {2, 3, 0, 1}));
}
// A transpose fed by a broadcast of a scalar constant: the layouts chosen for
// the operand and the transpose must make the transpose a bitcast.
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
  const Shape broadcast_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});

  auto builder = HloComputation::Builder(TestName());
  auto* scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto* broadcast_instr = builder.AddInstruction(
      HloInstruction::CreateBroadcast(broadcast_shape, scalar, {}));
  auto* transpose_instr =
      builder.AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast_instr,
          {2, 3, 0, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(builder.Build(transpose_instr));
  ComputationLayout computation_layout(entry->ComputeProgramShape());

  AssignLayouts(module.get(), &computation_layout);

  const Shape& operand_shape = transpose_instr->operand(0)->shape();
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(
      operand_shape, transpose_instr->shape(), {2, 3, 0, 1}));
}
// Calling TransposeIsBitcast on a shape without layout information is
// expected to die on a LayoutUtil::HasLayout check.
TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) {
  const Shape plain_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape shape_with_layout(plain_shape);
  *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  auto builder = HloComputation::Builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape_with_layout, "param"));
  auto* transpose_instr = builder.AddInstruction(
      HloInstruction::CreateTranspose(plain_shape, input, {0, 2, 1}));

  // Strip the layout that the instruction's shape received by default.
  LayoutUtil::ClearLayout(transpose_instr->mutable_shape());

  EXPECT_DEATH(
      ShapeUtil::TransposeIsBitcast(transpose_instr->operand(0)->shape(),
                                    transpose_instr->shape(),
                                    transpose_instr->dimensions()),
      "LayoutUtil::HasLayout");
}
// Calling ReshapeIsBitcast on a shape without layout information is expected
// to die on a LayoutUtil::HasLayout check.
TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) {
  const Shape plain_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape shape_with_layout(plain_shape);
  *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  auto builder = HloComputation::Builder(TestName());
  auto* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape_with_layout, "param"));
  auto* reshape_instr = builder.AddInstruction(
      HloInstruction::CreateReshape(plain_shape, input));

  // Strip the layout that the instruction's shape received by default.
  LayoutUtil::ClearLayout(reshape_instr->mutable_shape());

  EXPECT_DEATH(ShapeUtil::ReshapeIsBitcast(reshape_instr->operand(0)->shape(),
                                           reshape_instr->shape()),
               "LayoutUtil::HasLayout");
}
// Regression check: compiling the module below must not crash.
//
// Inside a fusion computation only the parameters and the result get a layout
// assigned. When the algebraic simplifier runs on this computation after
// layout assignment, it must not call TransposeIsBitcast on the `transpose`
// node inside the fusion computation, because TransposeIsBitcast checks that
// both input_shape and output_shape have layouts.
TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
  const char* module_str = R"(
HloModule test_module
fused_computation {
param_1 = f32[2,2,2]{2,1,0} parameter(1)
transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1}
reduce_1 = f32[] parameter(0)
broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={}
ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1)
}
ENTRY entry_computation {
fusion.1 = f32[2,2,2]{2,1,0} parameter(1)
reduce.1 = f32[] parameter(0)
fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation
ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));

  auto* compiler = backend().compiler();
  auto* executor = backend().default_stream_executor();

  // Run the HLO pass pipeline on a clone of the parsed module.
  std::unique_ptr<HloModule> optimized_module =
      compiler
          ->RunHloPasses(module->Clone(), executor,
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();

  // The backend must then compile the optimized module without error.
  EXPECT_EQ(Status::OK(),
            compiler
                ->RunBackend(std::move(optimized_module), executor,
                             /*device_allocator=*/nullptr)
                .status());
}
// A get-tuple-element inside a fusion node takes the layout of its operand
// (which, following operands far enough, should eventually be a parameter).
TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
  const char* module_str = R"(
HloModule test_module
fused_computation {
fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
gte0 = f32[2,2,2] get-tuple-element(fparam), index=0
gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1
gte1a = f32[2,2,2] get-tuple-element(gte1), index=0
gte1b = f32[2,2,2] get-tuple-element(gte1), index=1
add = f32[2,2,2] add(gte1a, gte1b)
ROOT fresult = f32[2,2,2] add(gte0, add)
}
ENTRY entry_computation {
param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
ROOT fusion =
f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      module->entry_computation()->ComputeProgramShape());

  // Give each array inside the nested parameter tuple its own layout.
  const Shape outer_element =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2});
  const Shape inner_tuple = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}),
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}),
  });
  Shape param_shape = ShapeUtil::MakeTupleShape({outer_element, inner_tuple});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({2, 1, 0}));

  AssignLayouts(module.get(), &computation_layout);

  // Array-shaped GTEs carry the layout of the tuple element they extract; the
  // fused root carries the requested result layout.
  EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2));
  EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0));
  EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1));
  EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0));

  // The tuple-shaped GTE carries both inner element layouts.
  EXPECT_THAT(FindInstruction(module.get(), "gte1")
                  ->shape()
                  .tuple_shapes(0)
                  .layout()
                  .minor_to_major(),
              ElementsAre(1, 2, 0));
  EXPECT_THAT(FindInstruction(module.get(), "gte1")
                  ->shape()
                  .tuple_shapes(1)
                  .layout()
                  .minor_to_major(),
              ElementsAre(2, 0, 1));
}
// Builds a conditional whose two branches produce their results in different
// ways and checks the layouts chosen for the branch results:
//  - the true branch adds the two elements of its tuple parameter;
//  - the false branch ignores its parameter and returns infeed data whose
//    shape carries an explicit {0, 1} layout.
// After layout assignment both branch results must have equal layouts, and the
// false branch result is expected to be a kCopy.
TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
  auto builder = HloComputation::Builder(TestName());
  auto m = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
  Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape result_tshape = ShapeUtil::MakeTupleShape({shape});

  // Entry computation operands: two data parameters packed into a tuple plus
  // the scalar predicate selecting the branch.
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));

  // True branch: tuple(add(gte0, gte1)).
  auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
  {
    auto param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    auto gte0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 0));
    auto gte1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 1));
    auto add = true_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_computation =
      m->AddEmbeddedComputation(true_builder.Build());

  // False branch: tuple(infeed data) with an explicit {0, 1} layout.
  auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
  {
    Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
    false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    // An infeed is used because layout assignment does not change its layout.
    auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
    auto infeed = false_builder.AddInstruction(
        HloInstruction::CreateInfeed(xshape, token, ""));
    auto infeed_data = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
    false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
  }
  HloComputation* false_computation =
      m->AddEmbeddedComputation(false_builder.Build());

  builder.AddInstruction(HloInstruction::CreateConditional(
      result_tshape, pred, tuple, true_computation, tuple, false_computation));
  HloComputation* computation = m->AddEntryComputation(builder.Build());
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* true_root = true_computation->root_instruction();
  const HloInstruction* false_root = false_computation->root_instruction();
  // These checks are fatal: operand(0) of each root is dereferenced below, so
  // continuing with an unexpected root instruction would be meaningless.
  ASSERT_EQ(true_root->opcode(), HloOpcode::kTuple);
  ASSERT_EQ(false_root->opcode(), HloOpcode::kTuple);

  const HloInstruction* true_result = true_root->operand(0);
  const HloInstruction* false_result = false_root->operand(0);
  EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
                                false_result->shape().layout()));
  EXPECT_EQ(false_result->opcode(), HloOpcode::kCopy);
}
// Running layout assignment on a module that contains a bitcast must fail, and
// the returned status must mention the unexpected bitcast.
TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
  auto b = HloComputation::Builder(TestName());
  // A 2x2 constant with an explicit {0, 1} layout, bitcast to its own shape.
  HloInstruction* literal_2x2 =
      b.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2WithLayout<float>(
              {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  b.AddInstruction(
      HloInstruction::CreateBitcast(literal_2x2->shape(), literal_2x2));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  ComputationLayout entry_layout(
      module->entry_computation()->ComputeProgramShape());
  LayoutAssignment pass(&entry_layout);
  Status run_status = pass.Run(module.get()).status();

  EXPECT_FALSE(run_status.ok());
  EXPECT_THAT(
      run_status.error_message(),
      ::testing::HasSubstr(
          "Unexpected bitcast operation seen during layout assignment"));
}
// A send/recv pair sharing channel_id=1 sits between a parameter pinned to
// layout {0, 1} and a result pinned to layout {1, 0}. After layout assignment
// with channel constraints, the send and recv shapes must be equal.
TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
  const char* hlo_text = R"(
HloModule test_module
ENTRY entry_computation {
param = (f32[2,2]) parameter(0)
gte = f32[2,2] get-tuple-element(param), index=0
token0 = token[] after-all()
recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1}
recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
sharding={maximal device=1}
ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1,
sharding={maximal device=0}
send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  ComputationLayout entry_layout(
      module->entry_computation()->ComputeProgramShape());

  // Pin deliberately non-matching layouts on the parameter and on the root.
  Shape pinned_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(entry_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
      pinned_param_shape));
  entry_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints constraints;
  AssignLayouts(module.get(), &entry_layout, &constraints);

  const Shape& send_shape = FindInstruction(module.get(), "send")->shape();
  const Shape& recv_shape = FindInstruction(module.get(), "recv")->shape();
  EXPECT_TRUE(ShapeUtil::Equal(send_shape, recv_shape));
}
// Two all-reduce instructions share channel_id=1. The parameter is pinned to
// layout {0, 1} and the result to {1, 0}. The test expects gte, ar.0 and ar.1
// to all end up with layout {0, 1}, while the entry root has layout {1, 0}.
TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
  const char* hlo_text = R"(
HloModule test_module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry_computation {
param = (f32[2,2]) parameter(0)
gte = f32[2,2] get-tuple-element(param), index=0
ar.0 = f32[2,2] all-reduce(gte),
channel_id=1, replica_groups={{0}}, to_apply=add,
sharding={maximal device=0}
const = f32[2,2] constant({{0,1},{2,3}})
ROOT ar.1 = f32[2,2] all-reduce(const),
channel_id=1, replica_groups={{0}}, to_apply=add,
sharding={maximal device=1}
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  ComputationLayout entry_layout(
      module->entry_computation()->ComputeProgramShape());

  // Pin deliberately non-matching layouts on the parameter and on the root.
  Shape pinned_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(entry_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
      pinned_param_shape));
  entry_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints constraints;
  AssignLayouts(module.get(), &entry_layout, &constraints);

  HloModule* raw_module = module.get();
  EXPECT_THAT(LayoutOf(raw_module, "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(raw_module, "ar.0"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(raw_module, "ar.1"), ElementsAre(0, 1));

  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
}
// Runs the backend's HLO passes on add(par0, slice(par1)) where par1 has
// layout {0,1} and the add has layout {1,0}. The slice operand is expected to
// become a copy of parameter 1 with shape f32[4,5]{1,0}.
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
  const char* hlo_text = R"(
HloModule CopySliceOperandToAvoidImplicitLayoutChange
ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
par0 = f32[3,4]{1,0} parameter(0)
par1 = f32[4,5]{0,1} parameter(1)
slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // The passes run on a clone of the parsed module.
  auto optimized = backend()
                       .compiler()
                       ->RunHloPasses(module->Clone(),
                                      backend().default_stream_executor(),
                                      /*device_allocator=*/nullptr)
                       .ConsumeValueOrDie();
  HloInstruction* optimized_root =
      optimized->entry_computation()->root_instruction();

  Shape expected_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(optimized_root,
              GmockMatch(m::Add(
                  m::Parameter(),
                  m::Slice(m::Copy(m::Parameter(1))
                               .WithShapeEqualTo(&expected_copy_shape)))));
}
// Dynamic-slice variant: after the backend's HLO passes, the sliced operand is
// expected to be a copy of parameter 1 with shape f32[4,5]{1,0}, with the two
// index parameters passed through unchanged.
TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
  const char* hlo_text = R"(
HloModule CopyDSliceOperandToAvoidImplicitLayoutChange
ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
par0 = f32[3,4]{1,0} parameter(0)
par1 = f32[4,5]{0,1} parameter(1)
par2 = s32[] parameter(2)
par3 = s32[] parameter(3)
dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // The passes run on a clone of the parsed module.
  auto optimized = backend()
                       .compiler()
                       ->RunHloPasses(module->Clone(),
                                      backend().default_stream_executor(),
                                      /*device_allocator=*/nullptr)
                       .ConsumeValueOrDie();
  HloInstruction* optimized_root =
      optimized->entry_computation()->root_instruction();

  Shape expected_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(
      optimized_root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::DynamicSlice(
              m::Copy(m::Parameter(1)).WithShapeEqualTo(&expected_copy_shape),
              m::Parameter(2), m::Parameter(3)))));
}
// Concatenate variant: par1 has layout {0,1} while the other operand and the
// consumer use {1,0}. After the backend's HLO passes, the first concatenate
// operand is expected to be a copy of parameter 1 with shape f32[3,5]{1,0};
// parameter 2 is expected to be used directly.
TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
  const char* hlo_text = R"(
HloModule CopyConcatOperandToAvoidImplicitLayoutChange
ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
par0 = f32[3,8]{1,0} parameter(0)
par1 = f32[3,5]{0,1} parameter(1)
par2 = f32[3,3]{1,0} parameter(2)
concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
dimensions={1}
ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // The passes run on a clone of the parsed module.
  auto optimized = backend()
                       .compiler()
                       ->RunHloPasses(module->Clone(),
                                      backend().default_stream_executor(),
                                      /*device_allocator=*/nullptr)
                       .ConsumeValueOrDie();
  HloInstruction* optimized_root =
      optimized->entry_computation()->root_instruction();

  Shape expected_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
  EXPECT_THAT(
      optimized_root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Concatenate(
              m::Copy(m::Parameter(1)).WithShapeEqualTo(&expected_copy_shape),
              m::Parameter(2)))));
}
// Counterpart to the Copy*OperandToAvoidImplicitLayoutChange tests: for a
// convolution whose operand layouts differ from the result layout, the
// backend's HLO passes are expected to leave both parameters as direct
// operands (no copy inserted).
TEST_F(LayoutAssignmentTest,
       ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
  const char* hlo_text = R"(
HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied
ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
feature_group_count=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // The passes run on a clone of the parsed module.
  auto optimized = backend()
                       .compiler()
                       ->RunHloPasses(module->Clone(),
                                      backend().default_stream_executor(),
                                      /*device_allocator=*/nullptr)
                       .ConsumeValueOrDie();
  HloInstruction* optimized_root =
      optimized->entry_computation()->root_instruction();

  EXPECT_THAT(optimized_root,
              GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
}
// The root slice has layout {0,1} while its parameter operand has {1,0}. After
// the backend's HLO passes, the slice is expected to read from a copy of
// parameter 0 with shape f32[4,5]{0,1}.
TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
  const char* hlo_text = R"(
HloModule PropagatingLayoutFromResultToOperand
ENTRY PropagatingLayoutFromResultToOperand {
par0 = f32[4,5]{1,0} parameter(0)
ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // The passes run on a clone of the parsed module.
  auto optimized = backend()
                       .compiler()
                       ->RunHloPasses(module->Clone(),
                                      backend().default_stream_executor(),
                                      /*device_allocator=*/nullptr)
                       .ConsumeValueOrDie();
  HloInstruction* optimized_root =
      optimized->entry_computation()->root_instruction();

  Shape expected_copy_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
  EXPECT_THAT(optimized_root,
              GmockMatch(m::Slice(m::Copy(m::Parameter(0))
                                      .WithShapeEqualTo(&expected_copy_shape))));
}
// A while loop carries a (counter, token, buffer) tuple. The infeed in the
// entry computation uses layout {0,1}, whereas the infeed inside the loop body
// uses layout {1,0}. That mismatch forces a copy of the tuple; because the
// tuple holds a token, layout assignment would fail if it tried to copy the
// tuple as a whole.
TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
  const char* hlo_text = R"(
HloModule TupleCopyOnLayoutMismatch
condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] {
tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
counter.1 = s32[] get-tuple-element(tup.1), index=0
five = s32[] constant(5)
ROOT lt = pred[] compare(counter.1, five), direction=LT
}
body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
counter.2 = s32[] get-tuple-element(tup.2), index=0
tok.2 = token[] get-tuple-element(tup.2), index=1
ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2)
next_tok = token[] get-tuple-element(ifeed.2), index=1
next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0
one = s32[] constant(1)
next_counter = s32[] add(counter.2, one)
ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf)
}
ENTRY main () -> f32[512,1024]{0,1} {
start_tok = token[] after-all()
ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok)
itok = token[] get-tuple-element(ifeed.3), index=1
ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0
zero = s32[] constant(0)
itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf)
loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2
ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  HloModule* raw_module = module.get();
  ComputationLayout entry_layout(
      module->entry_computation()->ComputeProgramShape());

  // Precondition: the two infeed buffers really do start with different
  // layouts.
  EXPECT_THAT(LayoutOf(raw_module, "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(raw_module, "next_buf"), ElementsAre(1, 0));

  AssignLayouts(raw_module, &entry_layout);

  // The mismatch must still be present afterwards; had layout assignment
  // removed it, this test would not have exercised the tuple copy at all.
  EXPECT_THAT(LayoutOf(raw_module, "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(raw_module, "next_buf"), ElementsAre(1, 0));
}
// A custom call without operand_layout_constraints should simply adopt the
// layouts of the surrounding computation. Two different parameter/result
// layout combinations are tried; each time the custom call's operand and
// result layouts must equal those set on the computation layout.
TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
  const char* hlo_text = R"(
HloModule CustomCallNotLayoutConstrained
ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
%p = f32[42,2,3] parameter(0)
ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
}
)";
  // First combination: parameter {0, 2, 1}, result {3, 2, 0, 1}.
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> module,
        ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));
    ComputationLayout entry_layout = module->entry_computation_layout();
    *entry_layout.mutable_parameter_layout(0) = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
    *entry_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
    AssignLayouts(module.get(), &entry_layout);

    HloInstruction* custom_call =
        module->entry_computation()->root_instruction();
    ASSERT_THAT(custom_call, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
    ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 2, 1});
  }
  // Second combination: parameter {0, 1, 2}, result {0, 2, 3, 1}.
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> module,
        ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));
    ComputationLayout entry_layout = module->entry_computation_layout();
    *entry_layout.mutable_parameter_layout(0) = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
    *entry_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
    AssignLayouts(module.get(), &entry_layout);

    HloInstruction* custom_call =
        module->entry_computation()->root_instruction();
    ASSERT_THAT(custom_call, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(custom_call->shape(), {0, 2, 3, 1});
    ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1, 2});
  }
}
// A custom call with explicit operand and result layout constraints is placed
// in a computation whose parameter 0 and result layouts disagree with those
// constraints. The expected outcome is copy(custom-call(copy(p0), p1)), with
// the custom call itself keeping its constrained layouts.
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
  const char* hlo_text = R"(
HloModule CustomCallLayoutConstrained
ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
%p0 = f32[4,4] parameter(0)
%p1 = f32[2,3] parameter(1)
ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  ComputationLayout entry_layout = module->entry_computation_layout();
  *entry_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *entry_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *entry_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(module.get(), &entry_layout);

  // Only the mismatching edges get a kCopy: operand 0 and the result, while
  // operand 1 already matches its constraint.
  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  ASSERT_THAT(entry_root,
              GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));

  const HloInstruction* custom_call = entry_root->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}
// A layout-constrained custom call whose output aliases operand 0
// (output_to_operand_aliasing). All layouts involved are {1, 0}; after layout
// assignment the custom call is still the entry root and its result and both
// operands have layout {1, 0}.
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedAliasedOutput) {
  const char* hlo_text = R"(
HloModule customcall.4
ENTRY %customcall.4 (parameter.1: f32[8,128], parameter.2: f32[8,128]) -> f32[8,128] {
%parameter.1 = f32[8,128]{1,0} parameter(0)
%parameter.2 = f32[8,128]{1,0} parameter(1)
ROOT %custom-call.3 = f32[8,128]{1,0} custom-call(f32[8,128]{1,0} %parameter.1, f32[8,128]{1,0} %parameter.2), custom_call_target="gpu_example_custom_call", operand_layout_constraints={f32[8,128]{1,0}, f32[8,128]{1,0}}, custom_call_has_side_effect=true, output_to_operand_aliasing={{}: (0, {})}
})";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  ComputationLayout entry_layout = module->entry_computation_layout();
  *entry_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
  *entry_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
  *entry_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
  AssignLayouts(module.get(), &entry_layout);

  const HloInstruction* aliased_call =
      module->entry_computation()->root_instruction();
  ExpectLayoutIs(aliased_call->shape(), {1, 0});
  ExpectLayoutIs(aliased_call->operand(0)->shape(), {1, 0});
  ExpectLayoutIs(aliased_call->operand(1)->shape(), {1, 0});
}
// A layout-constrained custom call with no operands. The computation result
// layout differs from the custom call's result layout, so the root is expected
// to be a copy of the custom call, which keeps its {3, 2, 0, 1} layout.
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
  const char* hlo_text = R"(
HloModule CustomCallLayoutConstrainedZeroOperands
ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  ComputationLayout entry_layout = module->entry_computation_layout();
  *entry_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(module.get(), &entry_layout);

  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  ASSERT_THAT(entry_root, GmockMatch(m::Copy(m::CustomCall())));

  const HloInstruction* custom_call = entry_root->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
}
// A layout-constrained custom call that takes a single tuple operand. The
// operand constraint is the tuple layout ({1,0}, {0,1}); the computation
// result layout is {2,1,0,3}. Expected outcome: copy(custom-call(tuple)),
// with the custom call keeping its result layout {3,2,0,1} and its operand
// having the constrained tuple layout.
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
  const char* hlo_text = R"(
HloModule CustomCallLayoutConstrainedTupleOperand
ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
%p0 = f32[4,4] parameter(0)
%p1 = f32[2,3] parameter(1)
%tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  ComputationLayout entry_layout = module->entry_computation_layout();
  *entry_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *entry_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *entry_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(module.get(), &entry_layout);

  // The entry root carries the computation's result layout...
  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  ExpectLayoutIs(entry_root->shape(), {2, 1, 0, 3});
  // ...and is a copy wrapped around the custom call.
  ASSERT_THAT(entry_root, GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));

  const HloInstruction* custom_call = entry_root->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
}
// A layout-constrained custom call that produces a tuple. The custom call's
// own result layout is ({1,0}, {0,1}), while the computation result layout is
// set to ({1,0}, {1,0}). After layout assignment the module result must have
// the computation's layout and the custom call must keep its own.
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
  const char* hlo_text = R"(
HloModule CustomCallLayoutConstrainedTupleResult
ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
%p0 = f32[4,4] parameter(0)
ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  ComputationLayout entry_layout = module->entry_computation_layout();
  *entry_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  // The second tuple element is requested as {1,0}, unlike the custom call's
  // constrained {0,1}.
  *entry_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
  AssignLayouts(module.get(), &entry_layout);

  ExpectTupleLayoutIs(module->result_shape(), {{1, 0}, {1, 0}});

  const HloInstruction* constrained_call =
      FindInstruction(module.get(), "custom-call");
  ExpectTupleLayoutIs(constrained_call->shape(), {{1, 0}, {0, 1}});
}
// Runs LayoutAssignment on `m` using the module's own entry computation
// layout. If that layout has no result layout set yet, the result layout is
// first set to the default layout. `channel_constraints` is forwarded to the
// pass and may be null. Returns the status of the pass run.
Status AssignLayoutsToComputation(
    HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
  ComputationLayout* entry_layout = m->mutable_entry_computation_layout();
  const bool result_layout_missing =
      !entry_layout->result_layout().LayoutIsSet();
  if (result_layout_missing) {
    entry_layout->mutable_result_layout()->SetToDefaultLayout();
  }
  LayoutAssignment pass(entry_layout, channel_constraints);
  return pass.Run(m).status();
}
// Checks that a diamond-shaped graph is handled correctly: `transpose` feeds
// both `add` and the root `tuple`, and `add` feeds the root `tuple` as well.
// The two tuple elements are given different result layouts ({1,0} for the
// add, {0,1} for the transpose).
TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) {
  auto builder = HloComputation::Builder(TestName());
  Shape shape_12x8 = ShapeUtil::MakeShape(F32, {12, 8});
  Shape shape_8x12 = ShapeUtil::MakeShape(F32, {8, 12});

  HloInstruction* transposed_input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape_8x12, "input"));
  HloInstruction* added_input = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape_12x8, "input"));
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(shape_12x8, transposed_input, {1, 0}));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_12x8, HloOpcode::kAdd, transpose, added_input));
  builder.AddInstruction(HloInstruction::CreateTuple({add, transpose}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Result layout: ({1,0} for the add, {0,1} for the transpose).
  Shape result_elem0 = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0});
  Shape result_elem1 = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1});
  *module->mutable_entry_computation_layout()->mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape({result_elem0, result_elem1}));

  // Both parameters are forced to layout {1,0}.
  const Layout param_layout = LayoutUtil::MakeLayout({1, 0});
  ForceParameterLayout(module.get(), 0, param_layout);
  ForceParameterLayout(module.get(), 1, param_layout);

  TF_ASSERT_OK(AssignLayoutsToComputation(module.get()));

  EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(transpose->shape().layout().minor_to_major(),
              ElementsAre(0, 1));
}
// Layout assignment must support a layout-constrained all-reduce whose
// operands use different layouts (b/146056839). The all-reduce result is
// constrained to ({0,1}, {1,0}); each operand is expected to end up with the
// layout of the corresponding result element.
TEST_F(LayoutAssignmentTest, LayoutConstrainedAllReduce) {
  const char* hlo_text = R"(
HloModule test_module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry_computation {
param = (f32[8,4]{0,1}, f32[16,2]{0,1}) parameter(0)
gte0 = f32[8,4] get-tuple-element(param), index=0
gte1 = f32[16,2] get-tuple-element(param), index=1
crs = (f32[8,4]{0,1}, f32[16,2]{1,0}) all-reduce(gte0, gte1),
replica_groups={}, constrain_layout=true, to_apply=add
gte2 = f32[8,4] get-tuple-element(crs), index=0
gte3 = f32[16,2] get-tuple-element(crs), index=1
ROOT result = (f32[8,4]{1,0}, f32[16,2]{1,0}) tuple(gte2, gte3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Keep the layouts written in the HLO text as the entry computation layout.
  ComputationLayout entry_layout(
      module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  ChannelLayoutConstraints constraints;
  AssignLayouts(module.get(), &entry_layout, &constraints);

  const HloInstruction* all_reduce = FindInstruction(module.get(), "crs");
  ExpectTupleLayoutIs(all_reduce->shape(), {{0, 1}, {1, 0}});
  ExpectLayoutIs(all_reduce->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(all_reduce->operand(1)->shape(), {1, 0});
}
// A layout-constrained collective with result layout ({1,0}, {1,0}) is fed by
// operands that come from a parameter with layouts ({0,1}, {1,0}); the module
// is parsed with a replica count of 2. After layout assignment both operands
// of the collective must have layout {1,0}.
// NOTE(review): despite the test name, the instruction called `alltoall` in
// the HLO below is an `all-reduce` op — confirm whether `all-to-all` was
// intended.
TEST_F(LayoutAssignmentTest, LayoutConstrainedAllToAll) {
  const char* hlo_text = R"(
HloModule test_module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY entry_computation {
param = (f32[16,4]{0,1}, f32[16,4]{1,0}) parameter(0)
gte0 = f32[16,4] get-tuple-element(param), index=0
gte1 = f32[16,4] get-tuple-element(param), index=1
alltoall = (f32[16,4]{1,0}, f32[16,4]{1,0}) all-reduce(gte0, gte1),
replica_groups={{0,1}}, constrain_layout=true, to_apply=add
gte2 = f32[16,4] get-tuple-element(alltoall), index=0
gte3 = f32[16,4] get-tuple-element(alltoall), index=1
ROOT concat = f32[16,8]{0,1} concatenate(gte2, gte3), dimensions={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/2));

  // Keep the layouts written in the HLO text as the entry computation layout.
  ComputationLayout entry_layout(
      module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  ChannelLayoutConstraints constraints;
  AssignLayouts(module.get(), &entry_layout, &constraints);

  const HloInstruction* collective = FindInstruction(module.get(), "alltoall");
  ExpectTupleLayoutIs(collective->shape(), {{1, 0}, {1, 0}});
  ExpectLayoutIs(collective->operand(0)->shape(), {1, 0});
  ExpectLayoutIs(collective->operand(1)->shape(), {1, 0});
}
// The entry root has a dynamic dimension (f32[1,<=16]). The dynamic shape is
// cleared on the computation's result layout before layout assignment runs;
// afterwards the root instruction must keep layout {0,1} and must still
// report dimension 1 as dynamic.
TEST_F(LayoutAssignmentTest, DynamicRoot) {
  const char* hlo_text = R"(
HloModule test_module
ENTRY entry_computation {
param = f32[1,<=16]{0,1} parameter(0)
ROOT abs = f32[1,<=16]{0,1} abs(param)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  ComputationLayout entry_layout(
      module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  entry_layout.mutable_result_layout()->ClearDynamicShape();
  AssignLayouts(module.get(), &entry_layout);

  const HloInstruction* root_abs = FindInstruction(module.get(), "abs");
  ExpectLayoutIs(root_abs->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(root_abs->shape(), {0, 1});
  EXPECT_TRUE(root_abs->shape().is_dynamic_dimension(1));
}
// Tests the ability to avoid copying across computations by reversing the
// computation traversal order. The entry result is pinned to layout {0,1,2}
// and LayoutAssignment is constructed with reverse_computation_order = true;
// the reshape in each called computation (call_1, on_true, on_false) is then
// expected to have layout {0,1,2}.
TEST_F(LayoutAssignmentTest, ReverseComputationOrderAvoidCopy) {
  const char* module_str = R"(
HloModule ComputationLayoutAvoidCopy
call_1 {
%arg_tuple.1 = (f32[93184,4]) parameter(0)
%get-tuple-element.1 = f32[93184,4] get-tuple-element(%arg_tuple.1), index=0
ROOT %reshape.8494 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0}%get-tuple-element.1)
}
on_true {
%arg_tuple.1 = (f32[93184,4]) parameter(0)
%get-tuple-element.1 = f32[93184,4] get-tuple-element(%arg_tuple.1), index=0
ROOT %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0}%get-tuple-element.1)
}
on_false {
%arg_tuple.2 = (f32[93184,4]) parameter(0)
%get-tuple-element.3 = f32[93184,4] get-tuple-element(%arg_tuple.2), index=0
%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0}%get-tuple-element.3)
ROOT %add = f32[2,512,364] add(%reshape.9717, %reshape.9717)
}
ENTRY main {
pred.1 = pred[] parameter(0)
arg.2 = f32[93184,4]{1,0} parameter(1)
arg_tuple.11 = (f32[93184,4]{1,0}) tuple(arg.2)
call.1 = f32[2,512,364] call(arg_tuple.11), to_apply=call_1
conditional = f32[2,512,364] conditional(pred.1, arg_tuple.11, arg_tuple.11),
true_computation=on_true, false_computation=on_false
ROOT add = f32[2,512,364] add(call.1, conditional)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  // Module dumps go through verbose logging (enable with --v=1) rather than
  // being printed unconditionally on every test run.
  XLA_VLOG_LINES(1, m->ToString());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShape(PRED, {}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {93184, 4}, {0, 1}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {2, 512, 364}, {0, 1, 2}));
  XLA_VLOG_LINES(1, computation_layout.ToString());

  ChannelLayoutConstraints channel_constraints;
  LayoutAssignment layout_assignment(
      &computation_layout,
      /*channel_constraints=*/&channel_constraints,
      /* reverse_computation_order = */ true);
  EXPECT_IS_OK(layout_assignment.Run(m.get()).status());
  XLA_VLOG_LINES(1, m->ToString());

  // Each reshape inside a called computation should have picked up the layout
  // pinned on the entry result.
  const HloInstruction* call_1 = FindInstruction(m.get(), "reshape.8494");
  ExpectLayoutIs(call_1->shape(), {0, 1, 2});
  const HloInstruction* on_true = FindInstruction(m.get(), "reshape.8493");
  ExpectLayoutIs(on_true->shape(), {0, 1, 2});
  const HloInstruction* on_false = FindInstruction(m.get(), "reshape.9717");
  ExpectLayoutIs(on_false->shape(), {0, 1, 2});
}
} // namespace
} // namespace xla