blob: 4ea40d204c40a0e80a4325643ae651e33fdc9eef [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
using ::testing::ElementsAre;
using ::testing::UnorderedElementsAre;
// Fixture shared by the HloInstruction tests in this file.
class HloInstructionTest : public HloTestBase {
 protected:
  // Scalar (rank-0) F32 shape used when building the test graphs.
  Shape r0f32_ = ShapeUtil::MakeScalarShape(F32);
};
// Simple visitor that collects the number of users and operands for certain HLO
// nodes. It also verifies some of the DFS visiting invariants (operands visited
// before their users, nodes not visited twice, etc.)
class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
 public:
  // Opcodes without a dedicated handler below are reported as unimplemented.
  Status DefaultAction(HloInstruction* hlo_instruction) override {
    return Unimplemented("not implemented %s",
                         HloOpcodeString(hlo_instruction->opcode()));
  }

  // Leaf handlers: the node must not have been recorded before.
  Status HandleParameter(HloInstruction* parameter) override {
    EXPECT_FALSE(count_.contains(parameter));
    count_[parameter] = GetCountsForNode(parameter);
    return OkStatus();
  }

  Status HandleConstant(HloInstruction* constant) override {
    EXPECT_FALSE(count_.contains(constant));
    count_[constant] = GetCountsForNode(constant);
    return OkStatus();
  }

  // Non-leaf handlers: the node itself must be new, and every operand must
  // already have been recorded (operands are visited before their users).
  Status HandleAdd(HloInstruction* add) override {
    const HloInstruction* lhs = add->operand(0);
    const HloInstruction* rhs = add->operand(1);
    EXPECT_FALSE(count_.contains(add));
    EXPECT_TRUE(count_.contains(lhs));
    EXPECT_TRUE(count_.contains(rhs));
    count_[add] = GetCountsForNode(add);
    return OkStatus();
  }

  Status HandleNegate(HloInstruction* negate) override {
    const HloInstruction* operand = negate->operand(0);
    EXPECT_FALSE(count_.contains(negate));
    EXPECT_TRUE(count_.contains(operand));
    count_[negate] = GetCountsForNode(negate);
    return OkStatus();
  }

  Status HandleMap(HloInstruction* map) override {
    EXPECT_FALSE(count_.contains(map));
    for (HloInstruction* arg : map->operands()) {
      EXPECT_TRUE(count_.contains(arg));
    }
    count_[map] = GetCountsForNode(map);
    return OkStatus();
  }

  Status HandleReduce(HloInstruction* reduce) override {
    const HloInstruction* arg = reduce->operand(0);
    const HloInstruction* init_value = reduce->operand(1);
    EXPECT_FALSE(count_.contains(reduce));
    EXPECT_TRUE(count_.contains(arg));
    EXPECT_TRUE(count_.contains(init_value));
    count_[reduce] = GetCountsForNode(reduce);
    return OkStatus();
  }

  // Returns the operand count recorded when `node` was visited. If `node` was
  // never recorded, a non-fatal test failure is reported and -1 is returned
  // (rather than dereferencing the end iterator).
  int64_t NumOperands(const HloInstruction* node) {
    auto count_iterator = count_.find(node);
    EXPECT_NE(count_.end(), count_iterator);
    if (count_iterator == count_.end()) {
      return -1;
    }
    return count_iterator->second.operand_count;
  }

  // Returns the user count recorded when `node` was visited. If `node` was
  // never recorded, a non-fatal test failure is reported and -1 is returned
  // (rather than dereferencing the end iterator).
  int64_t NumUsers(const HloInstruction* node) {
    auto count_iterator = count_.find(node);
    EXPECT_NE(count_.end(), count_iterator);
    if (count_iterator == count_.end()) {
      return -1;
    }
    return count_iterator->second.user_count;
  }

 private:
  struct NumOpsAndUsers {
    int64_t operand_count;
    int64_t user_count;
  };

  // Helper function to count operands and users for the given HLO.
  NumOpsAndUsers GetCountsForNode(const HloInstruction* node) {
    NumOpsAndUsers counts{node->operand_count(), node->user_count()};
    return counts;
  }

  // Counters for HLOs. Maps HLO to a NumOpsAndUsers.
  absl::flat_hash_map<const HloInstruction*, NumOpsAndUsers> count_;
};
TEST_F(HloInstructionTest, BasicProperties) {
  // A standalone scalar F32 parameter exposes its opcode and shape and has no
  // operands.
  const auto param = HloInstruction::CreateParameter(1, r0f32_, "foo");

  EXPECT_EQ(HloOpcode::kParameter, param->opcode());
  EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(param->shape(), F32));
  EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(param->shape(), S32));
  EXPECT_EQ(0, param->operand_count());
}
TEST_F(HloInstructionTest, UserWithTwoOperands) {
  // A single add consuming two distinct parameters:
  //
  //   [Param foo] ---+
  //                  +--> [Add]
  //   [Param bar] ---+
  HloComputation::Builder graph(TestName());
  HloInstruction* lhs_param = graph.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* rhs_param = graph.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* sum = graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, lhs_param, rhs_param));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(graph.Build());

  // Operand/user edges as reported by the instructions themselves.
  EXPECT_THAT(sum->operands(), UnorderedElementsAre(lhs_param, rhs_param));
  EXPECT_THAT(lhs_param->users(), UnorderedElementsAre(sum));
  EXPECT_THAT(rhs_param->users(), UnorderedElementsAre(sum));

  // Counts collected by a DFS rooted at the add.
  OpAndUserCollectingVisitor counter;
  ASSERT_IS_OK(sum->Accept(&counter));
  EXPECT_EQ(2, counter.NumOperands(sum));
  EXPECT_EQ(0, counter.NumUsers(sum));
  EXPECT_EQ(1, counter.NumUsers(lhs_param));
  EXPECT_EQ(1, counter.NumUsers(rhs_param));
}
TEST_F(HloInstructionTest, MultipleUsers) {
  // Parameter foo has three users, parameter bar has one:
  //
  //   foo --> [exp]
  //   foo --> [exp]
  //   foo --> [add] <-- bar
  HloComputation::Builder graph(TestName());
  HloInstruction* foo_param = graph.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar_param = graph.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* first_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo_param));
  HloInstruction* second_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo_param));
  HloInstruction* sum = graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, foo_param, bar_param));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(graph.Build());

  // User counts reported directly by the instructions.
  EXPECT_EQ(3, foo_param->user_count());
  EXPECT_EQ(1, bar_param->user_count());
  EXPECT_EQ(0, first_exp->user_count());
  EXPECT_EQ(0, second_exp->user_count());
  EXPECT_EQ(0, sum->user_count());

  // Counts collected by a DFS rooted at the add.
  OpAndUserCollectingVisitor counter;
  ASSERT_IS_OK(sum->Accept(&counter));
  EXPECT_EQ(2, counter.NumOperands(sum));
  EXPECT_EQ(3, counter.NumUsers(foo_param));
}
TEST_F(HloInstructionTest, RepeatedUser) {
  // The 'add' node uses the same HLO for both of its operands. It must be
  // counted as a single user of that HLO, not as two distinct users.
  //
  //   [Param foo] ==(operand 0 and operand 1)==> [add]
  HloComputation::Builder graph(TestName());
  HloInstruction* foo_param = graph.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* doubled = graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, foo_param, foo_param));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(graph.Build());

  EXPECT_EQ(1, foo_param->user_count());
  // The add nevertheless has two operand slots, both holding the same HLO.
  EXPECT_EQ(2, doubled->operand_count());
}
TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
  // Constant c0 is shared by two adds whose results are added together:
  //
  //   addleft  = add(param0, c0)
  //   addright = add(c0, param1)
  //   addtotal = add(addleft, addright)
  HloComputation::Builder graph(TestName());
  HloInstruction* left_param = graph.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  HloInstruction* right_param = graph.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "param1"));
  HloInstruction* shared_const = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* left_sum = graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, left_param, shared_const));
  HloInstruction* right_sum = graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, shared_const, right_param));
  HloInstruction* total = graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, left_sum, right_sum));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(graph.Build());

  // Walk from the final add and check the collected counts.
  OpAndUserCollectingVisitor counter;
  ASSERT_IS_OK(total->Accept(&counter));
  EXPECT_EQ(2, counter.NumUsers(shared_const));
  EXPECT_EQ(2, counter.NumOperands(left_sum));
  EXPECT_EQ(2, counter.NumOperands(right_sum));
  EXPECT_EQ(2, counter.NumOperands(total));
}
TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
  // Same diamond as MultipleUsersAndOperands, but the shared value is a
  // negated constant and the final sum is negated as well:
  //
  //   neg1     = neg(c0)
  //   addleft  = add(param0, neg1)
  //   addright = add(neg1, param1)
  //   addtotal = add(addleft, addright)
  //   neg2     = neg(addtotal)
  HloComputation::Builder graph(TestName());
  HloInstruction* left_param = graph.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  HloInstruction* right_param = graph.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "param1"));
  HloInstruction* const_value = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* negated_const = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, const_value));
  HloInstruction* left_sum = graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, left_param, negated_const));
  HloInstruction* right_sum = graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, negated_const, right_param));
  HloInstruction* total = graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, left_sum, right_sum));
  HloInstruction* negated_total = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, total));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(graph.Build());

  // Walk from the final negate and check the collected counts.
  OpAndUserCollectingVisitor counter;
  ASSERT_IS_OK(negated_total->Accept(&counter));
  EXPECT_EQ(1, counter.NumUsers(const_value));
  EXPECT_EQ(2, counter.NumUsers(negated_const));
  EXPECT_EQ(2, counter.NumOperands(left_sum));
  EXPECT_EQ(2, counter.NumOperands(right_sum));
  EXPECT_EQ(2, counter.NumOperands(total));
  EXPECT_EQ(1, counter.NumOperands(negated_total));
  EXPECT_EQ(0, counter.NumUsers(negated_total));
}
TEST_F(HloInstructionTest, TrivialMap) {
  // Builds a graph whose only operation is a map applying x+1:
  //
  //   param0[100x10] ---> (map x+1)
  //
  Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  Shape matrix_f32 = ShapeUtil::MakeShape(F32, {100, 10});
  auto hlo_module = CreateNewVerifiedModule();

  // Embedded computation: x + 1.0, applied by the map.
  auto mapper_builder = HloComputation::Builder("f32+1");
  HloInstruction* x = mapper_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "x"));
  HloInstruction* one = mapper_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  mapper_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, x, one));
  HloComputation* mapper =
      hlo_module->AddEmbeddedComputation(mapper_builder.Build());

  // Entry computation: a parameter fed straight into the map.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_f32, "p"));
  HloInstruction* map_op = entry_builder.AddInstruction(
      HloInstruction::CreateMap(matrix_f32, {input}, mapper));
  hlo_module->AddEntryComputation(entry_builder.Build());

  OpAndUserCollectingVisitor counter;
  ASSERT_IS_OK(map_op->Accept(&counter));

  // Check counts. We aren't walking the mapper computation yet.
  EXPECT_EQ(1, counter.NumUsers(input));
  EXPECT_EQ(0, counter.NumUsers(map_op));
  EXPECT_EQ(1, counter.NumOperands(map_op));
  // TODO(dehnert): Add walking and counters for the wrapped computation.
}
TEST_F(HloInstructionTest, TrivialReduce) {
  // Builds a graph whose only operation is a reduce applying x+y:
  //
  //   param0[100x10] ---> (reduce x+y)
  //
  Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  Shape vector_f32 = ShapeUtil::MakeShape(F32, {100});
  Shape matrix_f32 = ShapeUtil::MakeShape(F32, {100, 10});

  // Embedded computation: x + y, applied by the reduce.
  auto reducer_builder = HloComputation::Builder("f32+f32");
  HloInstruction* x = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "x"));
  HloInstruction* y = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "y"));
  reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, x, y));
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* reducer =
      hlo_module->AddEmbeddedComputation(reducer_builder.Build());

  // Entry computation: a parameter and an initial value feeding the reduce.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_f32, "p"));
  HloInstruction* init = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // An additional constant that the reduce does not consume.
  entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* reduce_op =
      entry_builder.AddInstruction(HloInstruction::CreateReduce(
          vector_f32, input, init, /*dimensions_to_reduce=*/{1}, reducer));
  hlo_module->AddEntryComputation(entry_builder.Build());

  OpAndUserCollectingVisitor counter;
  ASSERT_IS_OK(reduce_op->Accept(&counter));

  // Check counts. We aren't walking the reducer computation.
  EXPECT_EQ(1, counter.NumUsers(input));
  EXPECT_EQ(1, counter.NumUsers(init));
  EXPECT_EQ(0, counter.NumUsers(reduce_op));
  EXPECT_EQ(2, counter.NumOperands(reduce_op));
}
TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) {
  // Construct a graph of a few binary ops using two different
  // parameters. Replace one of the parameters with the other parameter in one
  // of the instructions.
  HloComputation::Builder builder(TestName());
  auto foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  auto bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  auto add_foobar = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  auto add_foofoo = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
                                                      add_foobar, add_foofoo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(2, foo->user_count());
  EXPECT_EQ(1, bar->user_count());

  // Replace the use of foo in add_foofoo with bar.
  ASSERT_IS_OK(foo->ReplaceUseWith(add_foofoo, bar));

  EXPECT_EQ(1, foo->user_count());
  EXPECT_EQ(2, bar->user_count());

  // foo is now used only by add_foobar; bar is used by both adds.
  EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar));
  EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, add_foofoo));

  // add_foobar is untouched, while both operands of add_foofoo became bar.
  EXPECT_THAT(add_foobar->operands(), ElementsAre(foo, bar));
  EXPECT_THAT(add_foofoo->operands(), ElementsAre(bar, bar));
}
TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
  // Builds a tuple over several parameters (one of them appearing twice) and
  // swaps one parameter for another inside the tuple.
  HloComputation::Builder graph(TestName());
  HloInstruction* foo_param = graph.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar_param = graph.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* baz_param = graph.AddInstruction(
      HloInstruction::CreateParameter(2, r0f32_, "baz"));
  HloInstruction* tuple_op = graph.AddInstruction(HloInstruction::CreateTuple(
      {foo_param, bar_param, baz_param, foo_param}));
  HloInstruction* foo_plus_bar =
      graph.AddInstruction(HloInstruction::CreateBinary(
          r0f32_, HloOpcode::kAdd, foo_param, bar_param));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(graph.Build());

  EXPECT_EQ(2, foo_param->user_count());
  EXPECT_THAT(foo_param->users(), UnorderedElementsAre(tuple_op, foo_plus_bar));

  // Swap foo for bar within the tuple only.
  ASSERT_IS_OK(foo_param->ReplaceUseWith(tuple_op, bar_param));

  EXPECT_THAT(foo_param->users(), UnorderedElementsAre(foo_plus_bar));
  // Every occurrence of foo in the tuple must now be bar.
  EXPECT_THAT(tuple_op->operands(),
              ElementsAre(bar_param, bar_param, baz_param, bar_param));
}
TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
  // Two unary instructions consume the same parameter. The parameter is
  // replaced by another one in just one of them.
  HloComputation::Builder graph(TestName());
  HloInstruction* foo_param = graph.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar_param = graph.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* exp_op = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo_param));
  HloInstruction* log_op = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo_param));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(graph.Build());

  EXPECT_EQ(2, foo_param->user_count());
  EXPECT_THAT(foo_param->users(), UnorderedElementsAre(exp_op, log_op));
  EXPECT_EQ(0, bar_param->user_count());

  // Swap foo for bar in the exp only.
  ASSERT_IS_OK(foo_param->ReplaceUseWith(exp_op, bar_param));

  // The log must still read foo.
  EXPECT_EQ(1, foo_param->user_count());
  EXPECT_THAT(foo_param->users(), UnorderedElementsAre(log_op));
  EXPECT_THAT(log_op->operands(), ElementsAre(foo_param));

  // The exp must now read bar.
  EXPECT_EQ(1, bar_param->user_count());
  EXPECT_EQ(*bar_param->users().begin(), exp_op);
  EXPECT_EQ(1, exp_op->operands().size());
  EXPECT_EQ(*exp_op->operands().begin(), bar_param);
}
TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) {
  // A small graph of binary ops over two parameters; every use of one
  // parameter is then redirected to the other.
  HloComputation::Builder graph(TestName());
  HloInstruction* foo_param = graph.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar_param = graph.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* foo_plus_bar =
      graph.AddInstruction(HloInstruction::CreateBinary(
          r0f32_, HloOpcode::kAdd, foo_param, bar_param));
  HloInstruction* foo_plus_foo =
      graph.AddInstruction(HloInstruction::CreateBinary(
          r0f32_, HloOpcode::kAdd, foo_param, foo_param));
  graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, foo_plus_bar, foo_plus_foo));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(graph.Build());

  EXPECT_EQ(2, foo_param->user_count());
  EXPECT_EQ(1, bar_param->user_count());

  // Redirect every use of foo to bar.
  ASSERT_IS_OK(foo_param->ReplaceAllUsesWith(bar_param));

  EXPECT_EQ(0, foo_param->user_count());
  EXPECT_EQ(2, bar_param->user_count());
  EXPECT_THAT(bar_param->users(),
              UnorderedElementsAre(foo_plus_bar, foo_plus_foo));
}
TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
  // Two parameters are consumed by a unary, a binary and a variadic op; every
  // use of one parameter is then redirected to the other.
  HloComputation::Builder graph(TestName());
  HloInstruction* foo_param = graph.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar_param = graph.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* foo_plus_bar =
      graph.AddInstruction(HloInstruction::CreateBinary(
          r0f32_, HloOpcode::kAdd, foo_param, bar_param));
  HloInstruction* exp_op = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo_param));
  HloInstruction* tuple_op = graph.AddInstruction(
      HloInstruction::CreateTuple({foo_param, bar_param}));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(graph.Build());

  EXPECT_EQ(3, foo_param->user_count());
  EXPECT_EQ(2, bar_param->user_count());

  // Redirect every use of foo to bar.
  ASSERT_IS_OK(foo_param->ReplaceAllUsesWith(bar_param));

  EXPECT_EQ(0, foo_param->user_count());
  EXPECT_EQ(3, bar_param->user_count());
  EXPECT_THAT(bar_param->users(),
              UnorderedElementsAre(foo_plus_bar, exp_op, tuple_op));
}
// Visitor that records, in order, every node passed to DefaultAction and every
// node passed to Postprocess.
class NodeCollectorAndPostProcessor : public DfsHloVisitorWithDefault {
 public:
  NodeCollectorAndPostProcessor() = default;

  // Appends `instruction` to the post-processed list.
  Status Postprocess(HloInstruction* instruction) override {
    post_processed_nodes_.push_back(instruction);
    return OkStatus();
  }

  // Appends `instruction` to the visited list.
  Status DefaultAction(HloInstruction* instruction) override {
    visited_nodes_.push_back(instruction);
    return OkStatus();
  }

  // Nodes in the order DefaultAction received them.
  const std::vector<const HloInstruction*>& visited_nodes() {
    return visited_nodes_;
  }

  // Nodes in the order Postprocess received them.
  const std::vector<const HloInstruction*>& post_processed_nodes() {
    return post_processed_nodes_;
  }

 private:
  std::vector<const HloInstruction*> visited_nodes_;
  std::vector<const HloInstruction*> post_processed_nodes_;
};
// Returns true if no node pointer occurs more than once in "vec".
bool Distinct(const std::vector<const HloInstruction*>& vec) {
  std::set<const HloInstruction*> seen;
  for (const HloInstruction* node : vec) {
    // insert() reports false in `.second` when the pointer was already there.
    if (!seen.insert(node).second) {
      return false;
    }
  }
  return true;
}
TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
  // Checks that nodes are post-processed in the same order they are visited,
  // and that no node is visited more than once.
  //
  //   foo --> exp --+
  //                 +--> add
  //   foo --> log --+
  HloComputation::Builder graph(TestName());
  HloInstruction* foo_param = graph.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* exp_op = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo_param));
  HloInstruction* log_op = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo_param));
  HloInstruction* sum = graph.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp_op, log_op));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(graph.Build());

  NodeCollectorAndPostProcessor collector;
  ASSERT_IS_OK(sum->Accept(&collector));

  // Visit order and post-process order must match.
  EXPECT_EQ(collector.visited_nodes(), collector.post_processed_nodes());
  // No node may appear twice in the visit list.
  EXPECT_TRUE(Distinct(collector.visited_nodes()));
}
TEST_F(HloInstructionTest, SingletonFusionOp) {
  // Fuses a single unary op; the fusion takes over the op's operand.
  HloComputation::Builder graph(TestName());
  HloInstruction* input = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* exp_op = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, input));
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(graph.Build());

  HloInstruction* fused = entry->CreateFusionInstruction(
      {exp_op}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fused->operands(), ElementsAre(input));
  EXPECT_THAT(input->users(), ElementsAre(fused));
}
TEST_F(HloInstructionTest, BinaryFusionOp) {
  // Fuses a single binary op; the fusion takes over both of its operands.
  HloComputation::Builder graph(TestName());
  HloInstruction* lhs = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* rhs = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* sum = graph.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(graph.Build());

  HloInstruction* fused = entry->CreateFusionInstruction(
      {sum}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fused->operands(), ElementsAre(lhs, rhs));
  EXPECT_THAT(lhs->users(), ElementsAre(fused));
  EXPECT_THAT(rhs->users(), ElementsAre(fused));
}
TEST_F(HloInstructionTest, ChainFusionOp) {
  // Fuses a chain of three unary ops; only the chain's input remains as an
  // operand of the fusion.
  HloComputation::Builder graph(TestName());
  HloInstruction* input = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* first_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, input));
  HloInstruction* second_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, first_exp));
  HloInstruction* third_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, second_exp));
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(graph.Build());

  HloInstruction* fused = entry->CreateFusionInstruction(
      {third_exp, second_exp, first_exp}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fused->operands(), ElementsAre(input));
  EXPECT_THAT(input->users(), ElementsAre(fused));
}
TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
  HloComputation::Builder builder(TestName());
  // Create a chain of fused unary ops.
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));

  // Tag both ops with the same metadata before fusing them.
  OpMetadata metadata;
  metadata.set_op_name("tf_op");
  exp1->set_metadata(metadata);
  exp2->set_metadata(metadata);

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {exp2, exp1}, HloInstruction::FusionKind::kLoop);

  // The fusion instruction, its fused root, and the root's operand must all
  // carry the metadata.
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->operand(0)->metadata()));

  // Cloning must leave the original's metadata intact and copy it onto the
  // clone. The clone itself was previously never inspected.
  auto cloned = fusion->CloneWithNewOperands(fusion->shape(), {});
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, cloned->metadata()));
}
TEST_F(HloInstructionTest, BinaryCallOp) {
  // Wraps a single binary op in a call; the call takes over both operands.
  HloComputation::Builder graph(TestName());
  HloInstruction* lhs = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* rhs = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* sum = graph.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(graph.Build());

  HloInstruction* call_op = entry->CreateCallInstruction({sum});

  EXPECT_THAT(call_op->operands(), ElementsAre(lhs, rhs));
  EXPECT_THAT(lhs->users(), ElementsAre(call_op));
  EXPECT_THAT(rhs->users(), ElementsAre(call_op));
}
TEST_F(HloInstructionTest, ChainCallOp) {
  // Wraps a chain of three unary ops in a call; only the chain's input
  // remains as an operand of the call.
  HloComputation::Builder graph(TestName());
  HloInstruction* input = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* first_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, input));
  HloInstruction* second_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, first_exp));
  HloInstruction* third_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, second_exp));
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(graph.Build());

  HloInstruction* call_op =
      entry->CreateCallInstruction({third_exp, second_exp, first_exp});

  EXPECT_THAT(call_op->operands(), ElementsAre(input));
  EXPECT_THAT(input->users(), ElementsAre(call_op));
}
TEST_F(HloInstructionTest, MultiOutputCallOp) {
  // A chain of three unary ops plus a separate unary op on the same input,
  // with an add consuming both results. The chain is wrapped in a call and
  // the separate op is then appended to it as an additional output.
  HloComputation::Builder graph(TestName());
  HloInstruction* input = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* first_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, input));
  HloInstruction* second_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, first_exp));
  HloInstruction* third_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, second_exp));
  HloInstruction* side_exp = graph.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, input));
  HloInstruction* sum = graph.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, third_exp, side_exp));
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(graph.Build());

  auto* call_op =
      entry->CreateCallInstruction({third_exp, second_exp, first_exp});
  call_op->AppendInstructionIntoCalledComputation(side_exp,
                                                  /*add_output=*/true);

  EXPECT_THAT(call_op->operands(), ElementsAre(input));
  // Both operands of the add must now be get-tuple-elements of the call.
  EXPECT_EQ(sum->operand(0)->opcode(), HloOpcode::kGetTupleElement);
  EXPECT_THAT(sum->operand(0)->operands(), ElementsAre(call_op));
  EXPECT_EQ(sum->operand(1)->opcode(), HloOpcode::kGetTupleElement);
  EXPECT_THAT(sum->operand(1)->operands(), ElementsAre(call_op));
}
TEST_F(HloInstructionTest, AsyncOp) {
  // Wraps a single binary op in an async start/done pair.
  HloComputation::Builder graph(TestName());
  HloInstruction* lhs = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* rhs = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* sum = graph.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(graph.Build());

  TF_ASSERT_OK_AND_ASSIGN(
      auto* async_done,
      entry->CreateAsyncInstructions(sum, {ShapeUtil::MakeScalarShape(U32)},
                                     "parallel_thread"));
  const auto* async_start = async_done->operand(0);

  // The start's shape is a 3-tuple whose last element is the U32 scalar
  // passed to CreateAsyncInstructions.
  EXPECT_EQ(async_start->shape().tuple_shapes_size(), 3);
  EXPECT_EQ(async_start->async_thread_name(), "parallel_thread");
  EXPECT_EQ(async_done->async_thread_name(), "parallel_thread");
  EXPECT_TRUE(ShapeUtil::Equal(async_start->shape().tuple_shapes(2),
                               ShapeUtil::MakeScalarShape(U32)));

  // The wrapped computation carries the thread name on both ends.
  EXPECT_EQ(async_start->async_wrapped_computation()->thread_name(),
            "parallel_thread");
  EXPECT_EQ(async_done->async_wrapped_computation()->thread_name(),
            "parallel_thread");

  // The start takes over the add's operands, and the done becomes the root.
  EXPECT_THAT(async_start->operands(), ElementsAre(lhs, rhs));
  EXPECT_THAT(lhs->users(), ElementsAre(async_start));
  EXPECT_THAT(rhs->users(), ElementsAre(async_start));
  EXPECT_EQ(entry->root_instruction(), async_done);
}
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
  // Two outfeeds of the same constant differ only in the layout of their
  // outfeed shape; each clone must keep its own outfeed shape.
  HloComputation::Builder graph(TestName());
  HloInstruction* matrix = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  auto layout10_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  auto layout01_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  HloInstruction* token_op =
      graph.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* outfeed_10 = graph.AddInstruction(
      HloInstruction::CreateOutfeed(layout10_shape, matrix, token_op, ""));
  HloInstruction* outfeed_01 = graph.AddInstruction(
      HloInstruction::CreateOutfeed(layout01_shape, matrix, token_op, ""));

  HloInstruction* cloned_01 = graph.AddInstruction(outfeed_01->Clone());
  HloInstruction* cloned_10 = graph.AddInstruction(outfeed_10->Clone());

  EXPECT_TRUE(ShapeUtil::Equal(cloned_01->outfeed_shape(), layout01_shape));
  EXPECT_TRUE(ShapeUtil::Equal(cloned_10->outfeed_shape(), layout10_shape));
}
TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
  // Gives the two tuple elements different layouts and checks that the
  // clone's shape equals the original's.
  HloComputation::Builder graph(TestName());
  HloInstruction* matrix = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* pair =
      graph.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  *ShapeUtil::GetMutableSubshape(pair->mutable_shape(), {0})
       ->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *ShapeUtil::GetMutableSubshape(pair->mutable_shape(), {1})
       ->mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  auto pair_clone = pair->Clone();

  EXPECT_TRUE(ShapeUtil::Equal(pair_clone->shape(), pair->shape()));
}
TEST_F(HloInstructionTest, PreserveShardingThroughCompatibleClone) {
  HloSharding device_sharding = HloSharding::AssignDevice(5);
  HloComputation::Builder graph(TestName());
  HloInstruction* matrix = graph.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* pair =
      graph.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  pair->set_sharding(device_sharding);

  // The new shape is compatible with the original one: the tuple tree
  // structure and the leaf ranks are the same, only the dimensions differ.
  auto leaf_shape = ShapeUtil::MakeShape(F32, {3, 3});
  auto new_shape = ShapeUtil::MakeTupleShape({leaf_shape, leaf_shape});
  auto pair_clone = pair->CloneWithNewOperands(new_shape, {});

  EXPECT_EQ(pair_clone->sharding(), device_sharding);
}
// A clone created with a new shape whose tuple tree structure differs from the
// original's is expected to have no sharding.
TEST_F(HloInstructionTest,
       DoNotPreserveShardingThroughTupleTreeIncompatibleClone) {
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* pair =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  const HloSharding device_sharding = HloSharding::AssignDevice(5);
  pair->set_sharding(device_sharding);

  // Three tuple elements instead of the original's two, so the tuple tree
  // structure no longer matches.
  const auto leaf_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const auto new_shape =
      ShapeUtil::MakeTupleShape({leaf_shape, leaf_shape, leaf_shape});
  const auto cloned_pair = pair->CloneWithNewOperands(new_shape, {});
  EXPECT_FALSE(cloned_pair->has_sharding());
}
// A clone created with a new shape whose leaf ranks differ from the original's
// is expected to have no sharding.
TEST_F(HloInstructionTest,
       DoNotPreserveShardingThroughLeafRankIncompatibleClone) {
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* pair =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  const HloSharding device_sharding = HloSharding::AssignDevice(5);
  pair->set_sharding(device_sharding);

  // Incompatible with the original shape: the tuple still has two elements,
  // but each leaf is now rank 3 instead of rank 2.
  const auto leaf_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  const auto new_shape = ShapeUtil::MakeTupleShape({leaf_shape, leaf_shape});
  const auto cloned_pair = pair->CloneWithNewOperands(new_shape, {});
  EXPECT_FALSE(cloned_pair->has_sharding());
}
// Builds a chain of three map instructions, fuses them one at a time into a
// loop fusion, and checks after every step that called_computations() of the
// fusion contains exactly the fused computation.
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();

  // Adds a new embedded computation holding a single scalar parameter.
  auto add_param_only_computation = [&]() {
    auto map_builder = HloComputation::Builder("FusionMap");
    map_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "param"));
    return module->AddEmbeddedComputation(map_builder.Build());
  };
  HloComputation* computation_x = add_param_only_computation();
  HloComputation* computation_y = add_param_only_computation();

  // constant -> map(x) -> map(x) -> map(y)
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* map_1_x = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {constant}, computation_x));
  HloInstruction* map_2_x = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {map_1_x}, computation_x));
  HloInstruction* map_3_y = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {map_2_x}, computation_y));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Start the fusion from the last map, then pull in its producers.
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {map_3_y}, HloInstruction::FusionKind::kLoop);
  HloComputation* fused_computation = fusion->fused_instructions_computation();
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));

  fusion->FuseInstruction(map_2_x);
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));

  fusion->FuseInstruction(map_1_x);
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
}
// Fuses every non-constant instruction of a small expression graph and checks
// the resulting fusion operands and the users of one of the constants.
TEST_F(HloInstructionTest, ComplexFusionOp) {
  HloComputation::Builder builder(TestName());
  // Expression being fused:
  //
  //   add = Add(C1, C2)
  //   clamp = Clamp(C2, add, add)
  //   exp = Exp(add)
  //   mul = Mul(exp, C3)
  //   sub = Sub(mul, clamp)
  //   tuple = Tuple({sub, sub, mul, C1})
  //
  // Notable complexities are repeated operands in the same instruction,
  // different shapes, use of value in different expressions.

  // Adds a scalar F32 constant with the given value to the builder.
  auto add_scalar_constant = [&builder](float value) {
    return builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(value)));
  };
  HloInstruction* c1 = add_scalar_constant(1.1f);
  HloInstruction* c2 = add_scalar_constant(2.1f);
  HloInstruction* c3 = add_scalar_constant(9.0f);

  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
  HloInstruction* clamp = builder.AddInstruction(
      HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3));
  HloInstruction* sub = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

  // Operands in the fusion instruction's operands() vector should be in the
  // order in which their users were fused.
  EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2));
  EXPECT_THAT(c1->users(), ElementsAre(fusion));
}
// Convenience wrapper around HloInstruction::Identical for two instructions.
// Besides returning the comparison result, it asserts (via gtest
// expectations) that Identical is reflexive and symmetric for the given pair.
static bool Identical(const HloInstruction& instruction1,
                      const HloInstruction& instruction2) {
  // Reflexivity: each instruction must be identical to itself.
  EXPECT_TRUE(instruction1.Identical(instruction1));
  EXPECT_TRUE(instruction2.Identical(instruction2));

  const bool forward = instruction1.Identical(instruction2);
  const bool backward = instruction2.Identical(instruction1);
  // Symmetry: swapping the two sides must not change the answer.
  EXPECT_EQ(forward, backward);
  return forward;
}
// Convenience function for comparing two HloInstructions for structural
// equality. It calls HloInstruction::Identical with an operand comparator that
// only compares operand shapes and a computation comparator that uses
// HloComputation's operator==, and additionally asserts that the comparison is
// reflexive and symmetric.
static bool StructuralEqual(const HloInstruction& instruction1,
                            const HloInstruction& instruction2) {
  auto same_operand_shape = [](const HloInstruction* lhs,
                               const HloInstruction* rhs) {
    return ShapeUtil::Equal(lhs->shape(), rhs->shape());
  };
  auto same_computation = [](const HloComputation* lhs,
                             const HloComputation* rhs) {
    return *lhs == *rhs;
  };

  // Reflexivity: each instruction must be identical to itself.
  EXPECT_TRUE(instruction1.Identical(instruction1, same_operand_shape,
                                     same_computation));
  EXPECT_TRUE(instruction2.Identical(instruction2, same_operand_shape,
                                     same_computation));

  const bool forward = instruction1.Identical(instruction2, same_operand_shape,
                                              same_computation);
  const bool backward = instruction2.Identical(
      instruction1, same_operand_shape, same_computation);
  // Symmetry: swapping the two sides must not change the answer.
  EXPECT_EQ(forward, backward);
  return forward;
}
// Exercises HloInstruction::Identical (through the Identical() helper above)
// on a subset of instruction kinds: unary ops, tuples, broadcasts and binary
// ops.
TEST_F(HloInstructionTest, IdenticalInstructions) {
  // Two distinct constant operands to use below. They are matrices so that
  // dimensions are interesting.
  auto operand1 = HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
  auto operand2 = HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
  Shape shape = operand1->shape();

  // Convenient short names for the operands.
  HloInstruction* op1 = operand1.get();
  HloInstruction* op2 = operand2.get();

  // Operations which only depend on their operands and opcode.
  EXPECT_TRUE(
      Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
                *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1)));
  EXPECT_FALSE(
      Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
                *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2)));
  EXPECT_FALSE(
      Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
                *HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1)));

  // Tuples: operand order matters.
  EXPECT_TRUE(Identical(*HloInstruction::CreateTuple({op1, op2}),
                        *HloInstruction::CreateTuple({op1, op2})));
  EXPECT_FALSE(Identical(*HloInstruction::CreateTuple({op1, op2}),
                         *HloInstruction::CreateTuple({op2, op1})));

  // Broadcasts: both the broadcast dimensions and the result shape matter.
  EXPECT_TRUE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
                        *HloInstruction::CreateBroadcast(shape, op1, {0, 1})));
  EXPECT_FALSE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
                         *HloInstruction::CreateBroadcast(shape, op1, {1, 0})));
  Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42});
  Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123});
  EXPECT_FALSE(
      Identical(*HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}),
                *HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1})));

  // Binary operations: a different opcode makes them differ, with or without
  // swapped operands.
  EXPECT_TRUE(Identical(
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2)));
  EXPECT_FALSE(Identical(
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
      *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1)));
  EXPECT_FALSE(Identical(
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
      *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2)));
}
// Two call instructions applying the same computation must be structurally
// equal; a call applying a different computation (cosine instead of sine) must
// not be.
TEST_F(HloInstructionTest, IdenticalCallInstructions) {
  const char* const module_text = R"(
HloModule Module
subcomp1 (x: f32[]) -> f32[] {
x = f32[] parameter(0)
ROOT n = f32[] sine(x)
}
subcomp2 (x: f32[]) -> f32[] {
x = f32[] parameter(0)
ROOT n = f32[] cosine(x)
}
ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) {
p = f32[] parameter(0)
t1 = f32[] call(p), to_apply=subcomp1
t2 = f32[] call(p), to_apply=subcomp1
t3 = f32[] call(p), to_apply=subcomp2
ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_text));

  // The root tuple's operands are the three calls, in order t1, t2, t3.
  auto* root_tuple = module->entry_computation()->root_instruction();
  auto* call_sine_a = root_tuple->operand(0);
  auto* call_sine_b = root_tuple->operand(1);
  auto* call_cosine = root_tuple->operand(2);

  EXPECT_TRUE(StructuralEqual(*call_sine_a, *call_sine_b));
  EXPECT_FALSE(StructuralEqual(*call_sine_a, *call_cosine));
}
// Verifies that HloInstruction::Accept with a FunctionVisitor visits every
// instruction reachable from the root exactly once, operands first. The graph
// is a diamond: `param` feeds both `negate` and `exp`, and `add` consumes
// both of them.
TEST_F(HloInstructionTest, FunctionVisitor) {
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, f32, "0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Records, for every visited instruction, its position in the visit
  // sequence; visiting the same instruction twice is a test failure.
  int next_position = 0;
  absl::flat_hash_map<HloInstruction*, int> position_of;
  FunctionVisitor visitor([&](HloInstruction* inst) {
    EXPECT_FALSE(position_of.contains(inst));
    position_of[inst] = next_position++;
    return OkStatus();
  });
  EXPECT_IS_OK(add->Accept(&visitor));

  EXPECT_EQ(0, position_of.at(param));
  // negate and exp can be visited in an arbitrary order, but they must occupy
  // the two middle positions.
  const int exp_position = position_of.at(exp);
  const int negate_position = position_of.at(negate);
  EXPECT_TRUE(exp_position == 1 || exp_position == 2);
  EXPECT_TRUE(negate_position == 1 || negate_position == 2);
  EXPECT_NE(exp_position, negate_position);
  EXPECT_EQ(3, position_of.at(add));
}
// Checks that an add of two same-shaped parameters reports itself as
// elementwise, both as a whole and on each individual operand.
TEST_F(HloInstructionTest, FullyElementwise) {
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
  HloComputation::Builder builder(TestName());
  auto x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
  auto y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(add->IsElementwise());
  // Use an int64_t index, like the operand loops in the other elementwise
  // tests of this file, instead of comparing a plain int to operand_count().
  for (int64_t i = 0; i < add->operand_count(); ++i) {
    EXPECT_TRUE(add->IsElementwiseOnOperand(i));
  }
}
// Checks that a map instruction over a 10x10 parameter reports itself as
// elementwise. The mapped computation consists of a single scalar parameter.
TEST_F(HloInstructionTest, MapIsElementwise) {
  auto module = CreateNewVerifiedModule();
  const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0});
  HloComputation::Builder builder(TestName());

  // Computation applied by the map: just its scalar parameter.
  HloComputation::Builder map_builder("id");
  map_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
  auto* mapped_computation =
      module->AddEmbeddedComputation(map_builder.Build());

  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x"));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(r2f32, {x}, mapped_computation));
  module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(map->IsElementwise());
}
// Checks IsElementwise / IsElementwiseOnOperand on a fusion that contains a
// broadcast in addition to elementwise operations.
TEST_F(HloInstructionTest, PartiallyElementwise) {
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
  const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5});

  // Fused expression:
  //
  //    p0    p1   p2   p3
  //      \  /    /     |
  //      mul    /      |
  //        \   /       |
  //         div     broadcast
  //            \    /
  //             max
  //
  // The fusion instruction is not elementwise on p3 because the broadcast is
  // not elementwise.
  HloComputation::Builder builder("PartiallyElementwise");
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r2f32, "p1"));
  HloInstruction* p2 =
      builder.AddInstruction(HloInstruction::CreateParameter(2, r2f32, "p2"));
  HloInstruction* p3 =
      builder.AddInstruction(HloInstruction::CreateParameter(3, r1f32, "p3"));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, p0, p1));
  HloInstruction* div = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, mul, p2));
  // Dimension 0 of shape [5] is mapped to dimension 1 of shape [3x5].
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, p3, {1}));
  HloInstruction* max = builder.AddInstruction(
      HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);

  EXPECT_FALSE(fusion->IsElementwise());
  // Every fusion operand except p3 must be reported as elementwise.
  const int64_t num_operands = fusion->operand_count();
  for (int64_t idx = 0; idx < num_operands; ++idx) {
    const bool expect_elementwise = fusion->operand(idx) != p3;
    EXPECT_EQ(expect_elementwise, fusion->IsElementwiseOnOperand(idx));
  }
}
// Same idea as PartiallyElementwise, but the broadcast result is consumed by
// two instructions inside the fusion.
TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
  // Fused expression:
  //
  //                y
  //               /
  //     x   broadcast
  //      \   /  |
  //       min   |
  //         \  /
  //          sub
  //
  const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});

  HloComputation::Builder builder("PartiallyElementwiseWithReuse");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {}));
  HloInstruction* min = builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, broadcast));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      r1f32, HloOpcode::kSubtract, min, broadcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {sub, broadcast, min}, HloInstruction::FusionKind::kLoop);

  EXPECT_FALSE(fusion->IsElementwise());
  // Only the broadcast input `y` is expected to be non-elementwise.
  const int64_t num_operands = fusion->operand_count();
  for (int64_t idx = 0; idx < num_operands; ++idx) {
    const bool expect_elementwise = fusion->operand(idx) != y;
    EXPECT_EQ(expect_elementwise, fusion->IsElementwiseOnOperand(idx));
  }
}
// Tests that shapes inside a fusion aren't mangled by Clone().
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
  // Fused expression:
  //
  //   x     y
  //   |     |
  //   |  transpose
  //    \   /
  //     dot
  //
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* y_transposed =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, y_transposed, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {dot, y_transposed}, HloInstruction::FusionKind::kLoop);
  auto fusion_clone = fusion->Clone();

  // Compare the fused roots, their two operands, and the operand of the
  // second operand, between the original and the clone.
  const HloInstruction* original_root = fusion->fused_expression_root();
  const HloInstruction* cloned_root = fusion_clone->fused_expression_root();
  EXPECT_TRUE(ShapeUtil::Equal(original_root->shape(), cloned_root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(original_root->operand(0)->shape(),
                               cloned_root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(original_root->operand(1)->shape(),
                               cloned_root->operand(1)->shape()));
  EXPECT_TRUE(
      ShapeUtil::Equal(original_root->operand(1)->operand(0)->shape(),
                       cloned_root->operand(1)->operand(0)->shape()));
  EXPECT_TRUE(StructuralEqual(*fusion, *fusion_clone));
}
// After FuseInstruction(), the fused producer loses its user but remains in
// the same computation as the fusion instruction.
TEST_F(HloInstructionTest, FuseInstructionKeepsInstruction) {
  constexpr char kModuleText[] = R"(
HloModule test_module
fused_add {
p0 = f32[32,32]{1,0} parameter(0)
p1 = f32[32,32]{1,0} parameter(1)
ROOT add = f32[32,32]{1,0} add(p0, p1)
}
ENTRY reduce {
p2 = f32[32,32]{1,0} parameter(0)
p3 = f32[32,32]{1,0} parameter(1)
c1 = f32[] constant(1)
broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={}
mul = f32[32,32]{1,0} multiply(p2, p3)
ROOT add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // The entry root is the fusion; its first operand is the multiply.
  HloInstruction* fusion = module->entry_computation()->root_instruction();
  HloInstruction* producer = fusion->mutable_operand(0);
  EXPECT_EQ(1, producer->user_count());

  fusion->FuseInstruction(producer);
  EXPECT_EQ(0, producer->user_count());

  // The fused instruction is still present in the computation.
  EXPECT_EQ(fusion->parent(), producer->parent());
}
// After FuseInstructionIntoMultiOutput(), the fused producer loses all of its
// users but remains in the same computation as the root.
TEST_F(HloInstructionTest, FuseInstructionIntoMultiOutputKeepsInstruction) {
  constexpr char kModuleText[] = R"(
HloModule test_module
fused_add {
p0 = f32[32,32]{1,0} parameter(0)
p1 = f32[32,32]{1,0} parameter(1)
ROOT add = f32[32,32]{1,0} add(p0, p1)
}
ENTRY reduce {
p2 = f32[32,32]{1,0} parameter(0)
p3 = f32[32,32]{1,0} parameter(1)
c1 = f32[] constant(1)
mul = f32[32,32]{1,0} multiply(p2, p3)
broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={}
add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add
ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(mul, add)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // The root tuple holds the multiply (operand 0) and the fusion (operand 1);
  // the multiply is used by both the tuple and the fusion.
  HloInstruction* root_tuple = module->entry_computation()->root_instruction();
  HloInstruction* producer = root_tuple->mutable_operand(0);
  HloInstruction* fusion = root_tuple->mutable_operand(1);
  EXPECT_EQ(2, producer->user_count());

  fusion->FuseInstructionIntoMultiOutput(producer);
  EXPECT_EQ(0, producer->user_count());

  // The fused instruction is still present in the computation.
  EXPECT_EQ(root_tuple->parent(), producer->parent());
}
// After replacing all uses of `x` with `y`, the fusion (which previously took
// both as operands) must end up with a single operand and a single fused
// parameter.
TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
  // Fused expression:
  //
  //   x     y
  //   |     |
  //   |  transpose
  //    \   /
  //     dot
  const Shape s = ShapeUtil::MakeShape(F32, {10, 10});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y"));
  HloInstruction* y_transposed =
      builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      s, x, y_transposed, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {dot, y_transposed}, HloInstruction::FusionKind::kLoop);

  const auto replace_status = x->ReplaceAllUsesWith(y);
  EXPECT_TRUE(replace_status.ok());

  EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y));
  EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1);
}
// Two fusions wrapping different unary ops must not be structurally equal,
// while a fusion and its clone must be.
TEST_F(HloInstructionTest, FusionEquality) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // One parameter feeding two different unary operations.
  HloInstruction* parameter =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Wrap each unary operation in its own loop fusion.
  HloInstruction* exp_fusion =
      entry->CreateFusionInstruction({exp}, HloInstruction::FusionKind::kLoop);
  HloInstruction* neg_fusion =
      entry->CreateFusionInstruction({neg}, HloInstruction::FusionKind::kLoop);
  EXPECT_FALSE(StructuralEqual(*exp_fusion, *neg_fusion));

  auto exp_fusion_clone = exp_fusion->Clone();
  EXPECT_TRUE(StructuralEqual(*exp_fusion, *exp_fusion_clone));
}
// Structural equality for fusions that themselves contain a fusion: a clone
// must be equal to its source, while an output fusion built around a different
// consumer (subtract instead of add) must not be.
TEST_F(HloInstructionTest, NestedFusionEquality) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Build a nested fusion computation.
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  HloInstruction* b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  HloInstruction* b_t = builder.AddInstruction(
      HloInstruction::CreateTranspose(data_shape, b, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2)));
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* add_operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, dot, add_operand));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kSubtract, dot, add_operand));
  // Final multiply consuming both add and sub; its handle is not needed.
  builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Inner fusion of dot + transpose, then two output fusions that each pull
  // the inner fusion in together with a different consumer.
  HloInstruction* nested_fusion = entry->CreateFusionInstruction(
      {dot, b_t}, HloInstruction::FusionKind::kLoop);
  HloInstruction* add_fusion = entry->CreateFusionInstruction(
      {add, nested_fusion}, HloInstruction::FusionKind::kOutput);
  HloInstruction* sub_fusion = entry->CreateFusionInstruction(
      {sub, nested_fusion}, HloInstruction::FusionKind::kOutput);

  auto add_fusion_clone = add_fusion->Clone();
  EXPECT_TRUE(StructuralEqual(*add_fusion, *add_fusion_clone));
  EXPECT_FALSE(StructuralEqual(*add_fusion, *sub_fusion));
}
// Test that the suffix string added to cloned instructions is not
// duplicated. Rather a numeric incrementing value should be appended. That
// is, we want "foo.clone2", not "foo.clone.clone".
TEST_F(HloInstructionTest, CloneSuffixNames) {
  // Creates a standalone scalar F32 parameter number 0 with the given name.
  auto make_scalar_param = [](const char* name) {
    return HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}),
                                           name);
  };

  // Cloning the same instruction repeatedly with the default suffix.
  const auto foo = make_scalar_param("foo");
  EXPECT_EQ(foo->Clone()->name(), "foo.clone");
  EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2");
  EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3");

  // Custom suffixes, including switching back to the default one.
  EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar");
  EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2");
  EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone");

  // An instruction name that already contains a dot.
  const auto foo_baz = make_scalar_param("foo.baz");
  EXPECT_EQ(foo_baz->Clone()->name(), "foo.baz.clone");

  // Incrementing a large number after the suffix.
  const auto foo_clone234 = make_scalar_param("foo.clone234");
  EXPECT_EQ(foo_clone234->Clone()->name(), "foo.clone235");

  // A non-numeric string after the cloning suffix.
  const auto foo_clonexyz = make_scalar_param("foo.clonexyz");
  EXPECT_EQ(foo_clonexyz->Clone()->name(), "foo.clonexyz.clone");

  // A name with multiple appearances of the suffix.
  const auto foo_clone_clone3 = make_scalar_param("foo.clone.clone3");
  EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4");
}
// Tests stringification of a simple op (dot), a while, and a conditional.
TEST_F(HloInstructionTest, Stringification) {
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* y_transposed =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, y_transposed, dot_dnums, DefaultPrecisionConfig(2)));

  // Printing with everything except metadata.
  const auto no_metadata = HloPrintOptions().set_print_metadata(false);
  EXPECT_EQ(dot->ToString(no_metadata),
            "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} "
            "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}");

  // Terse printing: no metadata, operand shapes, percent signs or layouts.
  const auto terse = HloPrintOptions()
                         .set_print_metadata(false)
                         .set_print_operand_shape(false)
                         .set_print_percent(false)
                         .set_include_layout_in_shapes(false);
  EXPECT_EQ(dot->ToString(terse),
            "dot = f32[5,20] dot(x, transpose), "
            "lhs_contracting_dims={1}, rhs_contracting_dims={0}");

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Note: the instructions below are added to `builder` after Build() has
  // already been called on it; they are only used for their string form.
  HloInstruction* loop = builder.AddInstruction(
      HloInstruction::CreateWhile(sout, computation, computation, x));
  EXPECT_EQ(loop->ToString(no_metadata),
            "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), "
            "condition=%TransposeDot, body=%TransposeDot");

  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          sout, pred, x, computation, x, computation));
  EXPECT_EQ(conditional->ToString(no_metadata),
            "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, "
            "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), "
            "true_computation=%TransposeDot, false_computation=%TransposeDot");
}
// Checks the string form of a gather instruction whose index vector dimension
// is the last dimension (4) of the start indices.
TEST_F(HloInstructionTest, StringifyGather_0) {
  Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  Shape start_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  Shape gather_result_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Gather");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* start_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, start_indices_tensor_shape, "start_indices"));

  const auto gather_dnums = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  HloInstruction* gather_instruction =
      builder.AddInstruction(HloInstruction::CreateGather(
          gather_result_shape, input, start_indices, gather_dnums,
          /*slice_sizes=*/{30, 29, 28, 27, 26},
          /*indices_are_sorted=*/false));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(gather_instruction->ToString(),
            "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
            "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
            "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), "
            "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
            "start_index_map={0,1,2,3,4}, "
            "index_vector_dim=4, slice_sizes={30,29,28,27,26}");
}
// Checks the string form of a gather instruction whose index vector dimension
// (2) is in the middle of the start indices shape.
TEST_F(HloInstructionTest, StringifyGather_1) {
  Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  Shape start_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  Shape gather_result_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Gather");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* start_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, start_indices_tensor_shape, "start_indices"));

  const auto gather_dnums = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/2);
  HloInstruction* gather_instruction =
      builder.AddInstruction(HloInstruction::CreateGather(
          gather_result_shape, input, start_indices, gather_dnums,
          /*slice_sizes=*/{30, 29, 28, 27, 26},
          /*indices_are_sorted=*/false));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(gather_instruction->ToString(),
            "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
            "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
            "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), "
            "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
            "start_index_map={0,1,2,3,4}, "
            "index_vector_dim=2, slice_sizes={30,29,28,27,26}");
}
// Checks the string form of a scatter instruction, including its dimension
// numbers and the name of the update computation it applies.
TEST_F(HloInstructionTest, StringifyScatter) {
  Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  Shape scatter_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  Shape scatter_updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Scatter");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* scatter_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, scatter_indices_tensor_shape, "scatter_indices"));
  HloInstruction* scatter_updates =
      builder.AddInstruction(HloInstruction::CreateParameter(
          2, scatter_updates_shape, "scatter_updates"));

  // Update computation: two scalar F32 parameters and nothing else.
  HloComputation::Builder update_builder("Scatter.update");
  update_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1"));
  update_builder.AddInstruction(
      HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2"));

  auto module = CreateNewVerifiedModule();
  HloComputation* update_computation =
      module->AddEmbeddedComputation(update_builder.Build());

  const auto scatter_dnums = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/2);
  HloInstruction* scatter_instruction =
      builder.AddInstruction(HloInstruction::CreateScatter(
          input_tensor_shape, input, scatter_indices, scatter_updates,
          update_computation, scatter_dnums,
          /*indices_are_sorted=*/false,
          /*unique_indices=*/false));
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(
      scatter_instruction->ToString(),
      "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} "
      "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
      "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, "
      "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), "
      "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, "
      "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, "
      "to_apply=%Scatter.update");
}
// Verifies how async-start / async-update / async-done ops that wrap a custom
// call are printed, both with the default async syntax sugar and with it
// disabled through HloPrintOptions::set_syntax_sugar_async_ops(false).
TEST_F(HloInstructionTest, StringifyAsyncOps) {
const Shape s1 = ShapeUtil::MakeShape(F32, {10});
const Shape s2 = ShapeUtil::MakeShape(F32, {20});
// Shape used for both async-start and async-update: ((s1), s2, s32[]).
const Shape s_tuple = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeTupleShape({s1}), s2, ShapeUtil::MakeShape(S32, {})});
// The wrapped computation: a single custom call to "foo" applied to p0.
HloComputation::Builder async_builder("AsyncOp");
HloInstruction* param = async_builder.AddInstruction(
HloInstruction::CreateParameter(0, s1, "p0"));
async_builder.AddInstruction(
HloInstruction::CreateCustomCall(s2, {param},
/*custom_call_target=*/"foo"));
std::unique_ptr<HloComputation> async_computation = async_builder.Build();
// Entry computation: start -> update -> done, all referring to the same
// wrapped computation and the same thread name.
HloComputation::Builder entry_builder("Entry");
HloInstruction* entry_param = entry_builder.AddInstruction(
HloInstruction::CreateParameter(0, s1, "p0"));
HloInstruction* async_start =
entry_builder.AddInstruction(HloInstruction::CreateAsyncStart(
s_tuple, {entry_param}, async_computation.get(),
/*async_group_id=*/std::nullopt,
/*async_thread_name=*/"parallel_thread"));
HloInstruction* async_update =
entry_builder.AddInstruction(HloInstruction::CreateAsyncUpdate(
s_tuple, async_start, async_computation.get(),
/*async_group_id=*/std::nullopt,
/*async_thread_name=*/"parallel_thread"));
entry_builder.AddInstruction(
HloInstruction::CreateAsyncDone(s2, async_update, async_computation.get(),
/*async_group_id=*/std::nullopt,
/*async_thread_name=*/"parallel_thread"));
// Ownership of the wrapped computation moves to the module only here; the
// async ops above were given it as a raw pointer.
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(entry_builder.Build());
module->AddEmbeddedComputation(std::move(async_computation));
// Default printing: the async ops are shown as custom-call-start/-update/-done
// carrying the wrapped call's custom_call_target, and the expected text
// contains no separate %AsyncOp computation.
const std::string expected_with_syntax_sugar =
R"(HloModule StringifyAsyncOps, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
ENTRY %Entry (p0: f32[10]) -> f32[20] {
%p0 = f32[10]{0} parameter(0)
%async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_thread_name="parallel_thread", custom_call_target="foo"
%async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_thread_name="parallel_thread", custom_call_target="foo"
ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_thread_name="parallel_thread", custom_call_target="foo"
}
)";
EXPECT_EQ(module->ToString(), expected_with_syntax_sugar);
// Without syntax sugar: %AsyncOp is printed as its own computation with a
// thread_name attribute, and the async ops refer to it via calls=%AsyncOp.
const std::string expected_without_syntax_sugar =
R"(HloModule StringifyAsyncOps, entry_computation_layout={(f32[10]{0})->f32[20]{0}}
%AsyncOp (p0.1: f32[10]) -> f32[20] {
%p0.1 = f32[10]{0} parameter(0)
ROOT %custom-call = f32[20]{0} custom-call(f32[10]{0} %p0.1), custom_call_target="foo"
}, thread_name="parallel_thread"
ENTRY %Entry (p0: f32[10]) -> f32[20] {
%p0 = f32[10]{0} parameter(0)
%async-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(f32[10]{0} %p0), async_thread_name="parallel_thread", calls=%AsyncOp
%async-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_thread_name="parallel_thread", calls=%AsyncOp
ROOT %async-done = f32[20]{0} async-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_thread_name="parallel_thread", calls=%AsyncOp
}
)";
auto options = HloPrintOptions().set_syntax_sugar_async_ops(false);
EXPECT_EQ(module->ToString(options), expected_without_syntax_sugar);
}
TEST_F(HloInstructionTest, CanonicalStringificationFusion) {
// Tests canonical stringification of a simple op (dot) and of a loop fusion.
const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
HloComputation::Builder builder("TransposeDot");
HloInstruction* x =
builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
HloInstruction* y =
builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
// Despite its name, `reshape` is a transpose of y from [20,10] to [10,20].
HloInstruction* reshape =
builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
// In canonical form the expected text for a standalone instruction carries no
// instruction name and no operand names, only shapes and attributes.
auto options = HloPrintOptions().Canonical();
EXPECT_EQ(dot->ToString(options),
"f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), "
"lhs_contracting_dims={1}, rhs_contracting_dims={0}");
auto module = CreateNewVerifiedModule();
auto* computation = module->AddEntryComputation(builder.Build());
constexpr char kParallelThreadName[] = "parallel_thread";
computation->SetThreadName(kParallelThreadName);
// Fuse the dot together with the transpose feeding it.
HloInstruction* fusion = computation->CreateFusionInstruction(
{dot, reshape}, HloInstruction::FusionKind::kLoop);
fusion->set_called_computations_thread_name(
kParallelThreadName,
/*skip_async_thread_name_overwrite=*/false);
// The called computation is printed inline after "calls=", its instructions
// are named tmp_<n>, and the text ends with the thread_name attribute.
const std::string expected_fusion =
R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}, thread_name="parallel_thread")";
EXPECT_EQ(fusion->ToString(options), expected_fusion);
}
TEST_F(HloInstructionTest, CanonicalStringificationWhile) {
// Tests canonical stringification of a while instruction whose condition and
// body are the same computation, which itself contains a loop fusion.
const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
HloComputation::Builder builder("TransposeDot");
HloInstruction* x =
builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
HloInstruction* y =
builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
// Despite its name, `reshape` is a transpose of y from [20,10] to [10,20].
HloInstruction* reshape =
builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewVerifiedModule();
auto* computation = module->AddEntryComputation(builder.Build());
// Replace dot + transpose with a loop fusion so the printed condition/body
// contain a nested called computation.
computation->CreateFusionInstruction({dot, reshape},
HloInstruction::FusionKind::kLoop);
// NOTE(review): the while is added to `builder` after builder.Build() was
// already called above; presumably it exists only so that it can be printed
// below - confirm this is intentional.
HloInstruction* loop = builder.AddInstruction(
HloInstruction::CreateWhile(sout, computation, computation, x));
auto options = HloPrintOptions().Canonical();
// Condition and body are printed inline, each with its own nested inline
// fusion computation; instruction names restart at tmp_0 in every computation.
const std::string expected_loop =
R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
}, body=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
})";
EXPECT_EQ(loop->ToString(options), expected_loop);
}
TEST_F(HloInstructionTest, CanonicalStringificationConditional) {
// Tests canonical stringification of a conditional whose true and false
// branches are the same computation, which itself contains a loop fusion.
const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
HloComputation::Builder builder("TransposeDot");
HloInstruction* x =
builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
HloInstruction* y =
builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
// Despite its name, `reshape` is a transpose of y from [20,10] to [10,20].
HloInstruction* reshape =
builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewVerifiedModule();
auto* computation = module->AddEntryComputation(builder.Build());
// Replace dot + transpose with a loop fusion so the printed branches contain
// a nested called computation.
computation->CreateFusionInstruction({dot, reshape},
HloInstruction::FusionKind::kLoop);
// NOTE(review): the result of this while is never used in this test; it looks
// like a leftover from a combined fusion/while/conditional test - confirm
// whether it can be removed.
builder.AddInstruction(
HloInstruction::CreateWhile(sout, computation, computation, x));
auto pred = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloInstruction* conditional =
builder.AddInstruction(HloInstruction::CreateConditional(
sout, pred, x, computation, x, computation));
auto options = HloPrintOptions().Canonical();
// Both branches are printed inline, each with its own nested inline fusion
// computation; instruction names restart at tmp_0 in every computation.
const std::string expected_conditional =
R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
}, false_computation=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
})";
EXPECT_EQ(conditional->ToString(options), expected_conditional);
}
TEST_F(HloInstructionTest, CheckDeepClone) {
  // A module with several levels of nesting: the entry's while loop has a body
  // that calls `calla`, which in turn applies `addy` through reduce-window.
  const char* const module_text = R"(
HloModule Module
addy (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
rhs = s32[] parameter(1)
ROOT zadd = s32[] add(lhs, rhs)
}
calla (x: s32[]) -> s32[] {
x = s32[] parameter(0)
reduce = s32[] reduce-window(x, x), to_apply=addy
ROOT xadd = s32[] add(x, reduce)
}
body (bparam: s32[]) -> s32[] {
constant = s32[] constant(1)
bparam = s32[] parameter(0)
v = s32[] call(bparam), to_apply=calla
ROOT add = s32[] add(constant, bparam)
}
condition (cparam: s32[]) -> pred[] {
xconstant = s32[] constant(5)
cparam = s32[] parameter(0)
ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT
}
ENTRY entry (param: s32[]) -> s32[] {
eparam = s32[] parameter(0)
ROOT while = s32[] while(eparam), condition=condition, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> original,
                          ParseAndReturnVerifiedModule(module_text));
  std::unique_ptr<HloModule> cloned = original->Clone();

  // A deep clone must own every computation and every instruction itself;
  // nothing reachable from the clone may still point back at the old module.
  HloModule* const expected_owner = cloned.get();
  for (HloComputation* comp : cloned->computations()) {
    EXPECT_EQ(comp->parent(), expected_owner);
    for (HloInstruction* instr : comp->instructions()) {
      EXPECT_EQ(instr->parent()->parent(), expected_owner);
    }
  }
}
TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) {
  HloComputation::Builder builder("test");
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {42});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "p"));

  // Two structurally identical adds of the same parameter.
  HloInstruction* first_add = builder.AddInstruction(
      HloInstruction::CreateBinary(vec_shape, HloOpcode::kAdd, param, param));
  HloInstruction* second_add = builder.AddInstruction(
      HloInstruction::CreateBinary(vec_shape, HloOpcode::kAdd, param, param));
  EXPECT_TRUE(first_add->Identical(*second_add));

  // Once only one of them carries a backend config, they must differ.
  first_add->set_raw_backend_config_string("abc");
  EXPECT_FALSE(first_add->Identical(*second_add));
}
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) {
  auto original =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  auto copy = original->Clone();
  EXPECT_TRUE(original->Identical(*copy));

  // Setting a window on only one of the two custom calls must make them
  // compare as different.
  Window custom_window = window_util::MakeWindow({1, 2, 3});
  original->set_window(custom_window);
  EXPECT_FALSE(original->Identical(*copy));
}
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) {
  auto original =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  auto copy = original->Clone();
  EXPECT_TRUE(original->Identical(*copy));

  // Setting convolution dimension numbers on only one of the two custom calls
  // must make them compare as different.
  ConvolutionDimensionNumbers conv_dnums;
  conv_dnums.set_output_batch_dimension(42);
  original->set_convolution_dimension_numbers(conv_dnums);
  EXPECT_FALSE(original->Identical(*copy));
}
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallHasSideEffect) {
  auto original =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  auto copy = original->Clone();
  EXPECT_TRUE(original->Identical(*copy));

  // Flipping the side-effect flag on only one of the two custom calls must
  // make them compare as different.
  Cast<HloCustomCallInstruction>(original.get())
      ->set_custom_call_has_side_effect(true);
  EXPECT_FALSE(original->Identical(*copy));
}
TEST_F(HloInstructionTest, CloneWindowOnCustomCall) {
  auto custom_call =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  Window expected_window = window_util::MakeWindow({1, 2, 3});
  custom_call->set_window(expected_window);

  // The clone must carry the same window as the instruction it came from.
  auto cloned = custom_call->Clone();
  EXPECT_TRUE(
      protobuf_util::ProtobufEquals(cloned->window(), expected_window))
      << cloned->window().DebugString();
}
TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) {
  auto custom_call =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  ConvolutionDimensionNumbers expected_dnums;
  expected_dnums.set_output_batch_dimension(42);
  custom_call->set_convolution_dimension_numbers(expected_dnums);

  // The clone must carry the same convolution dimension numbers as the
  // instruction it came from.
  auto cloned = custom_call->Clone();
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      cloned->convolution_dimension_numbers(), expected_dnums))
      << cloned->convolution_dimension_numbers().DebugString();
}
TEST_F(HloInstructionTest, CloneHasSideEffectOnCustomCall) {
  auto source =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  auto* source_call = Cast<HloCustomCallInstruction>(source.get());

  // The flag starts out false and can be switched on.
  EXPECT_FALSE(source_call->custom_call_has_side_effect());
  source_call->set_custom_call_has_side_effect(true);
  EXPECT_TRUE(source_call->custom_call_has_side_effect());

  // A clone taken afterwards must see the flag as set.
  auto cloned = source->Clone();
  EXPECT_TRUE(Cast<HloCustomCallInstruction>(cloned.get())
                  ->custom_call_has_side_effect());
}
TEST_F(HloInstructionTest, CustomCallHasSideEffect) {
  auto custom_call =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  auto* as_custom_call = Cast<HloCustomCallInstruction>(custom_call.get());

  // The generic HasSideEffect() query must follow the custom-call flag.
  EXPECT_FALSE(custom_call->HasSideEffect());
  as_custom_call->set_custom_call_has_side_effect(true);
  EXPECT_TRUE(custom_call->HasSideEffect());
}
TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) {
  // A convolution whose two operands have different precisions.
  constexpr char kConvModuleText[] = R"(
HloModule test_module
ENTRY test {
arg0 = f32[1,2,1] parameter(0)
arg1 = f32[1,1,1] parameter(1)
ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1},
dim_labels=b0f_0io->b0f, operand_precision={high,default}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kConvModuleText));
  auto* original_conv = parsed_module->entry_computation()->root_instruction();

  // Cloning must keep the per-operand precisions, in order.
  auto cloned_conv = original_conv->Clone();
  EXPECT_THAT(
      cloned_conv->precision_config().operand_precision(),
      ::testing::ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT));
}
TEST_F(HloInstructionTest, PreserveOuterDimensionPartitionsOnClone) {
  // An iota annotated with outer_dimension_partitions.
  constexpr char kIotaModuleText[] = R"(
HloModule test_module
ENTRY test {
ROOT iota = f32[100] iota(), iota_dimension=0, outer_dimension_partitions={0, 50}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kIotaModuleText));
  auto* original_iota = parsed_module->entry_computation()->root_instruction();

  // Cloning must keep the partition list unchanged.
  auto cloned_iota = original_iota->Clone();
  EXPECT_THAT(cloned_iota->outer_dimension_partitions(),
              ::testing::ElementsAre(0, 50));
}
TEST_F(HloInstructionTest, ReuseReshapeOfFusionParameter) {
  // The fusion applies a single reshape to its parameter and then consumes that
  // one reshape several times. Since every use goes through the same reshape,
  // the use kind is UseKind::kUsePermutingElements, which the public API
  // reports as "does not reuse this operand".
  constexpr char kFusionModuleText[] = R"(
HloModule test_module
f {
p = f32[3,2] parameter(0)
r = f32[2,3] reshape(p)
x = f32[2,3] multiply(r, r)
y = f32[2,3] add(r, r)
ROOT sum = f32[2,3] add(x, y)
}
ENTRY test {
p = f32[3,2] parameter(0)
ROOT fusion = f32[2,3] fusion(p), calls=f, kind=kLoop
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kFusionModuleText));
  const HloInstruction* fusion =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_FALSE(fusion->ReusesOperandElements(0));
}
TEST_F(HloInstructionTest, ReuseMultipleReshapesOfFusionParameter) {
  // The fusion feeds its parameter into two *different* reshapes. Because the
  // reshapes are not the same, this is not a mere element-permuting use
  // (UseKind::kUsePermutingElements); it counts as a reuse (presumably
  // UseKind::kReuse), which the public API reports as "does reuse this
  // operand".
  constexpr char kFusionModuleText[] = R"(
HloModule test_module
f {
p = f32[3,2] parameter(0)
r1 = f32[2,3] reshape(p)
r2 = f32[6,1] reshape(p)
ROOT result = (f32[2,3], f32[6,1]) tuple(r1, r2)
}
ENTRY test {
p = f32[3,2] parameter(0)
ROOT fusion = (f32[2,3], f32[6,1]) fusion(p), calls=f, kind=kLoop
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kFusionModuleText));
  const HloInstruction* fusion =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_TRUE(fusion->ReusesOperandElements(0));
}
TEST_F(HloInstructionTest, BitcastDoesNotReuseElements) {
  // A bitcast from a {0,1}-layout f32[3,2] parameter to f32[6].
  constexpr char kBitcastModuleText[] = R"(
HloModule test_module
ENTRY test {
p = f32[3,2]{0,1} parameter(0)
ROOT bitcast = f32[6] bitcast(p)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kBitcastModuleText));
  const HloInstruction* bitcast =
      parsed_module->entry_computation()->root_instruction();

  // The bitcast must not be reported as reusing elements of its operand.
  EXPECT_FALSE(bitcast->ReusesOperandElements(0));
}
TEST_F(HloInstructionTest, GatherDoesNotReuseElements) {
  constexpr char kGatherModuleText[] = R"(
HloModule test_module
ENTRY test {
input = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
idx = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}
gather(input, idx), offset_dims={4,5,6,7,8}, collapsed_slice_dims={},
start_index_map={0,1,2,3,4}, index_vector_dim=4,
slice_sizes={30,29,28,27,26}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kGatherModuleText));
  const HloInstruction* gather =
      parsed_module->entry_computation()->root_instruction();

  // Neither the input operand (0) nor the indices operand (1) may be reported
  // as having its elements reused.
  EXPECT_FALSE(gather->ReusesOperandElements(0));
  EXPECT_FALSE(gather->ReusesOperandElements(1));
}
TEST_F(HloInstructionTest, BackendConfigCanContainNonFiniteFloats) {
  HloComputation::Builder builder(TestName());
  Shape square_shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, square_shape, "p0"));

  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  HloInstruction* gemm = builder.AddInstruction(
      HloInstruction::CreateDot(square_shape, operand, operand, contraction,
                                DefaultPrecisionConfig(2)));

  // Store a config holding +infinity and NaN, then read it back.
  gpu::GemmBackendConfig written_config;
  written_config.set_alpha_real(std::numeric_limits<double>::infinity());
  written_config.set_alpha_imag(std::numeric_limits<double>::quiet_NaN());
  TF_ASSERT_OK(gemm->set_backend_config(written_config));
  TF_ASSERT_OK_AND_ASSIGN(auto read_config,
                          gemm->backend_config<gpu::GemmBackendConfig>());

  // Only +infinity exceeds the largest finite double, and only NaN compares
  // unequal to itself.
  EXPECT_GT(read_config.alpha_real(), std::numeric_limits<double>::max());
  EXPECT_NE(read_config.alpha_imag(), read_config.alpha_imag());
}
} // namespace
} // namespace xla