blob: a7a4b1366c7363cef280cb15e390066be1879639 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include <memory>
#include <string>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
namespace m = ::xla::match;
using absl::string_view;
// One round-trip parser test case. Instances are built with positional
// aggregate initialization, so the member order below must not change.
struct TestData {
  // Human-readable identifier for this case.
  string test_name;
  // HLO module text fed to the parser.
  string module_string;
  // Replica count associated with the case; defaults to a single replica.
  int64 replica_count{1};
  // Whether verification is enabled for this case; on unless a case opts out.
  bool enable_verification{true};
};
// Maps a parameterized-test parameter to its test_name field. The signature
// matches gtest's parameter-name-generator shape; presumably it is passed to
// the test-suite instantiation elsewhere in this file (not visible here).
string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
  const TestData& test_case = data.param;
  return test_case.test_name;
}
// For each module string below, we check that:
// - it parses into an HloModule successfully, and
// - stringifying the resulting HloModule reproduces the original string
//   exactly (i.e. the parse/print round trip is lossless).
std::vector<TestData> CreateTestCases() {
// clang-format off
return std::vector<TestData>({
// ax + y
{
"AxpyParam",
R"(HloModule axpy_module
ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
%alpha = f32[] parameter(0)
%broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
%x = f32[2,4]{1,0} parameter(1)
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
%y = f32[2,4]{1,0} parameter(2)
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)"
},
// parameter replication
{
"ParamReplication",
R"(HloModule param_replication_module
ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) {
%a = f32[] parameter(0), parameter_replication={true}
%b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true}
ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b)
}
)"
},
// pred constant
{
"ConstantPred",
R"(HloModule constant_pred_module
ENTRY %constant_pred () -> pred[] {
ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar"
}
)"
},
// pred array constant
{
"ConstantPredArray",
R"(HloModule module
ENTRY %constant_pred_array () -> pred[2,3] {
ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } })
}
)"
},
// s32 constant
{
"ConstantS32",
R"(HloModule constant_s32_module
ENTRY %constant_s32 () -> s32[] {
ROOT %constant = s32[] constant(-42)
}
)"
},
// f32 constant, but the value is not a decimal and there is a backend
// configuration
{
"ConstantF32",
R"(HloModule ConstantF32_module
ENTRY %ConstantF32.v4 () -> f32[] {
ROOT %constant = f32[] constant(42), backend_config="this is a configuration"
}
)"
},
// f32 constant, rank 1 empty array.
{
"ConstantF32R1Empty",
R"(HloModule ConstantF32Empty_module
ENTRY %ConstantF32Empty.v4 () -> f32[0] {
ROOT %constant = f32[0]{0} constant({})
}
)"
},
// f32 constant, rank 4 empty array.
{
"ConstantF32R4Empty",
R"(HloModule ConstantF32R4Empty_module
ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] {
ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } })
}
)"
},
// constant 4D
{
"Constant4D",
R"(HloModule Small_3x2x1x1_module
ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] {
ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
}
)"
},
// non-finite constants: nan, inf, -inf
{
"ConstantNonFinite",
R"(HloModule IsFiniteR1F32s_module
ENTRY %IsFiniteR1F32s.v2 () -> pred[6] {
%constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf})
ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant)
}
)"
},
// constant f16
{
"ConstantF16",
R"(HloModule ConstantF16_module
ENTRY %ConstantF16.v4 () -> f16[] {
ROOT %constant = f16[] constant(500)
}
)"
},
// bf16
{
"BF16",
R"(HloModule BF16
ENTRY %BF16.v4 () -> bf16[] {
ROOT %constant = bf16[] constant(500)
}
)"
},
// constant + constant
{
"AddConstants",
R"(HloModule add_constants_module
ENTRY %add_constants () -> f32[] {
%constant = f32[] constant(3.14)
ROOT %add = f32[] add(f32[] %constant, f32[] %constant)
}
)"
},
// tuple constant
{
"TupleConstant",
R"(HloModule TupleConstant_module
ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) {
ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} ))
}
)"
},
// v1 > v2 ? v1 : v2
{
"SelectR1F32",
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
%v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
%v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
%greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, type=TOTALORDER, sharding={replicated}
ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
}
)"
},
// empty tuple
{
"EmptyTupleCreate",
R"(HloModule EmptyTupleCreate_module
ENTRY %EmptyTupleCreate.v1 () -> () {
ROOT %tuple = () tuple()
}
)"
},
// tuple
{
"TupleCreate",
R"(HloModule TupleCreate_module
ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
%v1 = f32[] parameter(0)
%v2 = f32[3]{0} parameter(1)
%v3 = f32[2,3]{1,0} parameter(2)
ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
}
)"
},
// tuple
{
"LargeTupleRoundTrip",
R"(HloModule LargeTupleRoundTrip_module
ENTRY %TupleCreate.v4 (v: f32[]) -> (f32[], f32[], f32[], f32[], f32[], /*index=5*/f32[]) {
%v = f32[] parameter(0)
ROOT %tuple = (f32[], f32[], f32[], f32[], f32[], /*index=5*/f32[]) tuple(f32[] %v, f32[] %v, f32[] %v, f32[] %v, f32[] %v, /*index=5*/f32[] %v)
}
)"
},
{
"ShardedTupleCreate",
R"(HloModule ShardedTupleCreate_module
ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
%v1 = f32[] parameter(0), sharding={manual}
%v2 = f32[3]{0} parameter(1)
%v3 = f32[2,3]{1,0} parameter(2)
ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{manual}, {maximal device=0}, {replicated}}
}
)"
},
{
"DomainParsing",
R"(HloModule DomainParsing_module
ENTRY %DomainParsing (v1: f32[]) -> f32[] {
%v1 = f32[] parameter(0)
ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
}
)"
},
// int32 result = 0;
// while (result < 5) { result = result + 1; }
{
"WhileWithScalarS32Result",
R"(HloModule WhileWithScalarS32Result_module
%body.v3 (prev.1: s32[]) -> s32[] {
%constant = s32[] constant(1)
%prev.1 = s32[] parameter(0)
ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
}
%condition.v3 (prev.2: s32[]) -> pred[] {
%constant.1 = s32[] constant(5)
%prev.2 = s32[] parameter(0)
ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
}
ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
%constant.2 = s32[] constant(0)
ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
}
)"
},
// copy-start and copy-done
{
"CopyStartAndCopyDone",
R"(HloModule CopyStartAndCopyDone_module
ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) {
%v1 = f32[] parameter(0)
%copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1), is_cross_program_prefetch=true
%copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1)
%v2 = f32[2,3]{1,0:S(1)} parameter(1)
%copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
%copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) %copy-start.2)
ROOT %tuple = (f32[], f32[2,3]{1,0:S(2)}) tuple(f32[] %copy-done.1, f32[2,3]{1,0:S(2)} %copy-done.2)
}
)"
},
// send and recv
{
"SendRecv",
R"(HloModule TwoSendRecvBothWayRecvFist_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
%constant = f32[] constant(2.1), sharding={maximal device=0}
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
}
)"
},
{
"SendRecvWithHostTransfer",
R"(HloModule HostTransferSendRecv_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true
ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true
%constant = f32[] constant(2.1), sharding={maximal device=0}
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true
}
)"
},
// get-tuple-element
{
"GetTupleElement",
R"(HloModule GetTupleElement_module
ENTRY %GetTupleElement.v4 () -> s32[2,3] {
%constant = f32[3]{0} constant({1, 2, 3})
%constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
%tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1)
ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0}
}
)"
},
// call
{
"Call",
R"(HloModule CallR0F32IdentityScalar_module
%Identity.v1 (x: f32[]) -> f32[] {
ROOT %x = f32[] parameter(0)
}
ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
%constant = f32[] constant(42)
ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
}
)"
},
// CustomCall with backend_config.
{
"CustomCallWithOpaque",
R"(HloModule custom_call
ENTRY %CustomCall () -> f32[1,2,3] {
%constant = f32[1]{0} constant({12345})
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque"
}
)"
},
// CustomCall with literal.
{
"CustomCallWithLiteral",
R"(HloModule custom_call
ENTRY %CustomCall () -> f32[1,2,3] {
%constant = f32[1]{0} constant({12345})
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=s32[2]{0} {1, 2}
}
)"
},
// CustomCall with literal tuple.
{
"CustomCallWithLiteralTuple",
R"(HloModule custom_call
ENTRY %CustomCall () -> f32[1,2,3] {
%constant = f32[1]{0} constant({12345})
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=( s32[4]{0} {4, 128, 128, 3}, pred[4]{0} {1, 0, 0, 0} )
}
)"
},
// CustomCall with literal R0.
{
"CustomCallWithLiteralR0",
R"(HloModule custom_call
ENTRY %CustomCall () -> f32[1,2,3] {
%constant = f32[1]{0} constant({12345})
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", literal=f32[] 0.1
}
)"
},
// reduce window
{
"ReduceWindow",
R"(HloModule R4UnitWindow_module
%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
%operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
%constant = f32[] constant(0)
ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
}
)"
},
// reduce window on scalar
{
"ReduceWindowScalar",
R"(HloModule reduce_window_scalar
%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
ENTRY %R4UnitWindowScalar () -> f32[] {
%constant = f32[] constant(42)
%constant.1 = f32[] constant(1)
ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
}
)"
},
// reduce window on scalar
{
"ReduceWindowVariadic",
R"(HloModule reduce_window_variadic
%add_F32.v3 (lhs1: f32[], lhs2: f32[], rhs1: f32[], rhs2: f32[]) -> (f32[], f32[]) {
%lhs1 = f32[] parameter(0)
%rhs1 = f32[] parameter(2)
%add1 = f32[] add(f32[] %lhs1, f32[] %rhs1)
%lhs2 = f32[] parameter(1)
%rhs2 = f32[] parameter(3)
%add2 = f32[] add(f32[] %lhs2, f32[] %rhs2)
ROOT %tuple1 = (f32[], f32[]) tuple(f32[] %add1, f32[] %add2)
}
ENTRY %R4UnitWindowScalar () -> (f32[], f32[]) {
%constant = f32[] constant(42)
%constant.1 = f32[] constant(1)
ROOT %reduce-window = (f32[], f32[]) reduce-window(f32[] %constant, f32[] %constant, f32[] %constant.1, f32[] %constant.1), to_apply=%add_F32.v3
}
)"
},
// convolution
{
"Convolution",
R"(HloModule Convolve1D1Window_0_module
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
}
)"
},
// convolution dynamic
{
"ConvolutionDynamic",
R"(HloModule Convolve1D1Window_0_module
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %custom-call.52 = f32[1,2,1]{2,0,1} custom-call(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}, custom_call_target="DynamicConvolutionForward", metadata={op_type="Conv2D" op_name="conv1d"}
}
)"
},
// convolution rank 2
{
"ConvolutionR2",
R"(HloModule ConvolveR2_module
ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[2,2]) -> f32[1,2] {
%input = f32[1,2]{1,0} parameter(0)
%filter = f32[2,2]{1,0} parameter(1)
ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[2,2]{1,0} %filter), dim_labels=bf_io->bf
}
)"
},
// convolution backward
{
"ConvolutionBackward",
R"(HloModule ConvolveBackward_module
ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
%input = f32[128,7,7,512]{0,3,2,1} parameter(0)
%filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
}
)"
},
// reverse(constant)
{
"Reverse4D",
R"(HloModule Reverse4DFloatArrayOnDim01_module
ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
%constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
}
)"
},
// concat
{
"Concat",
R"(HloModule Concat2x3With2x5_module
ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
%constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 1002 } })
%constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
}
)"
},
// select and scatter
{
"SelectAndScatter",
R"(HloModule R4F32OverlapSmall_module
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE, type=TOTALORDER
}
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
%lhs.1 = f32[] parameter(0)
%rhs.1 = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
}
ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
%constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
%constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
%constant.2 = f32[] constant(0)
ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
}
)"
},
// select and scatter on scalar
{
"SelectAndScatterScalar",
R"(HloModule select_and_scatter_scalar
%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
}
%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
%lhs.1 = f32[] parameter(0)
%rhs.1 = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
}
ENTRY %SelectAndScatterScalar () -> f32[] {
%constant = f32[] constant(42)
%constant.1 = f32[] constant(1)
%constant.2 = f32[] constant(2)
ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3
}
)"
},
// slice
{
"Slice",
R"(HloModule slice_module
ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
%p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]}
}
)"
},
// slice, no stride
{
"SliceNoStride",
R"(HloModule Slice3x3x3_To_1x3x3_F32_module
ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
%constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
}
)"
},
// slice R0
{
"SliceR0",
R"(HloModule SliceR0_module
ENTRY %SliceR0.v2 () -> s32[] {
%constant = s32[] constant(1)
ROOT %slice = s32[] slice(s32[] %constant), slice={}
}
)"
},
// transpose
{
"Transpose",
R"(HloModule Transpose_module
ENTRY %Transpose.v2 () -> s32[1,2,3] {
%constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } })
ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
}
)"
},
{
"TransposeC128",
R"(HloModule TransposeC128_module
ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] {
%input = c128[1,2,3]{2,1,0} parameter(0)
ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2}
}
)"
},
// Triangular solve
{
"TriangularSolve",
R"(HloModule TriangularSolve_module
ENTRY %SimpleRightLowerNotranspose.4 (a.1: f32[4,4], b.2: f32[3,4]) -> f32[3,4] {
%a.1 = f32[4,4]{1,0} parameter(0)
%b.2 = f32[3,4]{1,0} parameter(1)
ROOT %triangular-solve.3 = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a.1, f32[3,4]{1,0} %b.2), lower=true, transpose_a=NO_TRANSPOSE
}
)"
},
// Dynamic slice
{
"DynamicSlice",
R"(HloModule DynamicSlice_module
ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] {
%original_parameter = s32[2,2,258]{2,1,0} parameter(0)
%constant = s32[1]{0} constant({0})
%start_index = s32[1]{0} parameter(1)
%concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0}
ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258}
}
)"
},
// Dynamic slice with scalar indices
{
"DynamicSliceScalarIndices",
R"(HloModule DynamicSlice_module
ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
%original_parameter = s32[2,2,258]{2,1,0} parameter(0)
%constant = s32[] constant(0)
%start_index = s32[] parameter(1)
ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
}
)"
},
// Dynamic update slice
{
"DynamicUpdateSlice",
R"(HloModule DynamicSlice_module
ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] {
%input = s32[1,1,25,1]{3,2,1,0} parameter(0)
%update = s32[1,1,2,1]{3,2,1,0} parameter(1)
%start_indices = s32[4]{0} parameter(2)
ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices)
}
)"
},
// Dynamic update slice with scalar indices
{
"DynamicUpdateSliceScalarIndex",
R"(HloModule DynamicUpdateSlice_module
ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
%input = s32[1,1,25,1]{3,2,1,0} parameter(0)
%update = s32[1,1,2,1]{3,2,1,0} parameter(1)
%start_index.0 = s32[] parameter(2)
%start_index.1 = s32[] parameter(3)
%start_index.2 = s32[] parameter(4)
%start_index.3 = s32[] parameter(5)
ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, /*index=5*/s32[] %start_index.3)
}
)"
},
// batch norm training
{
"BatchNormTraining",
R"(HloModule BasicTraining_module
ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
%constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } })
%constant.1 = f32[2]{0} constant({2, 3})
%constant.2 = f32[2]{0} constant({1, 2})
ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
}
)"
},
// batch norm inference
{
"BatchNormInference",
R"(HloModule BatchNormInference_module
ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] {
%input = f32[2,2,2,2]{3,2,1,0} parameter(0)
%offset = f32[2]{0} parameter(1)
%scale = f32[2]{0} parameter(2)
%mean = f32[2]{0} parameter(3)
%variance = f32[2]{0} parameter(4)
ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0
}
)"
},
// batch norm grad
{
"BatchNormGrad",
R"(HloModule BatchNormGrad_module
ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
%input = f32[2,2,2,2]{3,2,1,0} parameter(0)
%scale = f32[2]{0} parameter(1)
%mean = f32[2]{0} parameter(2)
%variance = f32[2]{0} parameter(3)
%grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4)
ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
}
)"
},
// fft
{
"Fft",
R"(HloModule Fft_module
ENTRY %Fft (input: c64[8,32]) -> c64[8,32] {
%input = c64[8,32]{1,0} parameter(0)
ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32}
}
)"
},
// ifft
{
"Ifft2d",
R"(HloModule Ifft2d_module
ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] {
%input = c64[5,8,32]{2,1,0} parameter(0)
ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32}
}
)"
},
// rfft2d
{
"Rfft2d",
R"(HloModule Rfft2d_module
ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] {
%input = f32[5,64,32]{2,1,0} parameter(0)
ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32}
}
)"
},
// irfft3d
{
"Irfft3d",
R"(HloModule Irfft3d_module
ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
%input = c64[5,64,128,33]{3,2,1,0} parameter(0)
ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64}
}
)"
},
// pad
{
"Pad",
R"(HloModule Pad1DS3Array_module
ENTRY %Pad1DS3Array.v3 () -> f32[7] {
%constant = f32[3]{0} constant({1, 2, 3})
%constant.1 = f32[] constant(0.1)
ROOT %pad = f32[7]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
}
)"
},
// pad has interior
{
"PadHasInterior",
R"(HloModule PadHasInterior_module
ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
%input = f32[1,25,7,7]{3,2,1,0} parameter(0)
%constant = f32[] constant(-5.123)
ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
}
)"
},
// Negative padding
{
"PadHasNegativePadding",
R"(HloModule PadHasNegativePadding_module
ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,35] {
%input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
%constant = f32[] constant(-5.123)
ROOT %pad = f32[1,15,6,3,35]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
}
)"
},
// fusion
{
"Fusion",
R"(HloModule fusion_module
%fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] {
%constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
%constant.1.param_1 = f32[2]{0} parameter(1)
%broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1}
ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
}
ENTRY %fusion.v3 () -> f32[3,2,1,1] {
%constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
%constant.1 = f32[2]{0} constant({3.14, 4.25})
ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
}
)"
},
{
"Gather",
R"(HloModule StringifyGather
ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
%start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}
)"
},
{
"SortedGather",
R"(HloModule StringifyGather
ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
%start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}, indices_are_sorted=true
}
)"
},
{
"Scatter",
R"(HloModule StringifyScatter
%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
%scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
%updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3
}
)"
},
{
"SortedScatter",
R"(HloModule StringifySortedScatter
%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
%scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
%updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3
}
)"
},
{
"UniqueIndicesScatter",
R"(HloModule StringifyUniqueIndicesScatter
%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
%input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
%scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
%updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, unique_indices=true, to_apply=%add_F32.v3
}
)"
},
{
"ConstantUnsignedNoUnderflow",
R"(HloModule ConstantUnsignedNoUnderflow_module
ENTRY %ConstantUnsignedNoUnderflow () -> u64[] {
ROOT %constant = u64[] constant(1)
}
)"
},
{
"ConstantUnsignedNoOverflow",
R"(HloModule ConstantUnsignedNoOverflow_module
ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
ROOT %constant = u64[] constant(9223372036854775807)
}
)"
},
// CustomCallWithLayoutConstraints
{
"CustomCallWithLayoutConstraints",
R"(HloModule CustomCallWithLayoutConstraints
ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
%p0 = f32[42,2,3]{0,1,2} parameter(0)
%p1 = f32[123,4]{0,1} parameter(1)
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}
}
)"
},
// CustomCallWithLayoutConstraintsNoOperands
{
"CustomCallWithLayoutConstraintsNoOperands",
R"(HloModule CustomCallWithLayoutConstraintsNoOperands
ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] {
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
}
)"
},
// CustomCallWithLayoutConstraintsTupleShapes
{
"CustomCallWithLayoutConstraintsTupleShapes",
R"(HloModule CustomCallWithLayoutConstraintsTupleShapes
ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
%p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
%p1 = f32[123,4]{0,1} parameter(1)
ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}}
}
)"
},
// CustomCallWithHasSideEffect
{
"CustomCallWithHasSideEffect",
R"(HloModule CustomCallWithHasSideEffect
ENTRY %CustomCallWithHasSideEffect (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
%p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
%p1 = f32[123,4]{0,1} parameter(1)
ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true
}
)"
},
// CustomCallWithAliasing
{
"CustomCallWithAliasing",
R"(HloModule CustomCallWithAliasing
ENTRY %CustomCallWithAliasing (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[123,4], f32[2,2], f32[1,2,3]) {
%p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
%p1 = f32[123,4]{0,1} parameter(1)
ROOT %custom-call = (f32[123,4]{0,1}, f32[2,2]{0,1}, f32[1,2,3]{0,1,2}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", output_to_operand_aliasing={{0}: (1, {}), {1}: (0, {0})}
}
)"
},
// CustomCall with schedule.
{
"CustomCallWithSchedule",
R"(HloModule custom_call
ENTRY %CustomCall () -> f32[1,2,3] {
%constant = f32[1]{0} constant({12345})
%custom-call.0 = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo", schedule=SCHEDULE_EARLIEST
ROOT %custom-call.1 = f32[1,2,3]{0,2,1} custom-call(f32[1,2,3]{0,2,1} %custom-call.0), custom_call_target="bar", schedule=SCHEDULE_LATEST
}
)"
},
// Parse c64 literal
{
"ParseC64Literal",
R"(HloModule ParseC64Literal
ENTRY %ParseC64Literal () -> c64[2] {
ROOT %c = c64[2]{0} constant({(1, 2), (-inf, nan)})
}
)"
},
// Parse c128 literal
{
"ParseC128Literal",
R"(HloModule ParseC128Literal
ENTRY %ParseC128Literal () -> c128[2] {
ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)})
}
)"
},
// Indexed Conditional
{
"IndexedConditional",
R"(HloModule indexed_conditional
%Negate (x: f32[]) -> f32[] {
%x = f32[] parameter(0)
ROOT %negate = f32[] negate(f32[] %x)
}
%Identity (y: f32[]) -> f32[] {
%y = f32[] parameter(0)
ROOT %copy = f32[] copy(f32[] %y)
}
%Floor (z: f32[]) -> f32[] {
%z = f32[] parameter(0)
ROOT %floor = f32[] floor(f32[] %z)
}
ENTRY %Parameters1.v4 () -> f32[] {
%constant = s32[] constant(1)
%constant.1 = f32[] constant(56)
%constant.2 = f32[] constant(12)
%constant.3 = f32[] constant(13)
ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
}
)"
},
// rng-get-and-update-state
{
"RngGetAndUpdateState",
R"(HloModule rng_get_and_update_state
ENTRY %RngGetAndUpdateState () -> u64[2] {
ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=4096
}
)"
},
{
"RngBitGenerator",
R"(HloModule gng_bit_generator
ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[11,17]) {
%p0 = u64[2]{0} parameter(0)
ROOT %rand = (u64[2]{0}, u32[11,17]{1,0}) rng-bit-generator(u64[2]{0} %p0), algorithm=rng_three_fry
}
)"
}
});
// clang-format on
}
// Returns the round-trip test cases written in the short textual form: no '%'
// name prefixes, no operand shapes, and no computation signatures. The
// HloParserTestShort / HloParserTestShortProto suites instantiated below parse
// each module and expect it to print back identically with
// HloPrintOptions::ShortParsable().
std::vector<TestData> CreateShortTestCases() {
// clang-format off
return std::vector<TestData>({
// map
{
"Map",
R"(HloModule MapBinaryAdder_module
add_F32.v3 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY MapBinaryAdder.v3 {
param0 = f32[4]{0} parameter(0)
param1 = f32[4]{0} parameter(1)
ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3
}
)"
},
// reduce
{
"Reduce",
R"(HloModule ReduceR3ToR2_module
add_F32.v3 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY ReduceR3ToR2.v3 {
input = f32[8,16,256]{2,1,0} parameter(0)
constant = f32[] constant(0)
ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
}
)"
},
// tuple reduce: a variadic reduce over (values, indices) whose reducer
// returns a (max, argmax) tuple.
{
"TupleReduce",
R"(HloModule TupleReduce
max_argmax {
value = f32[] parameter(2)
prev_max = f32[] parameter(0)
is_next_larger = pred[] compare(value, prev_max), direction=GE
max = f32[] select(is_next_larger, value, prev_max)
index = s32[] parameter(3)
prev_argmax = s32[] parameter(1)
argmax = s32[] select(is_next_larger, index, prev_argmax)
ROOT pair = (f32[], s32[]) tuple(max, argmax)
}
ENTRY reduce_entry {
values = f32[1024]{0} parameter(0)
indices = s32[1024]{0} parameter(1)
init_value = f32[] constant(-inf)
init_index = s32[] constant(-1)
ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
}
)"
},
// infeed/outfeed
{
"InfeedOutfeed",
R"(HloModule outfeed_module
ENTRY InfeedToOutfeed {
token0 = token[] after-all()
infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
outfeed = token[] outfeed(infeed.data, token0), outfeed_shape=(u32[3]{0}, pred[])
ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0)
infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
infeed.1.token = token[] get-tuple-element(infeed.1), index=1
outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token), outfeed_shape=(u32[3]{0}, pred[])
}
)"
},
// Rng
{
"Rng",
R"(HloModule rng_module
ENTRY Rng {
constant = f32[] constant(0)
constant.1 = f32[] constant(1)
ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform
}
)"
},
// Reduce precision
{
"ReducePrecision",
R"(HloModule reduce_precision
ENTRY ReducePrecision {
constant = f32[1]{0} constant({3.14159})
ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10
}
)"
},
// Sort (Key)
{
"SortKey",
R"(HloModule sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
x = f32[1024]{0} parameter(0)
ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare
}
)"
},
// Sort (Key, Value)
{
"SortKeyValue",
R"(HloModule sort
compare {
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
keys = f32[1024]{0} parameter(0)
values = s32[1024]{0} parameter(1)
ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
}
)"
},
// R2 Sort (Key)
{
"SortKeyR2",
R"(HloModule sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
x = f32[1024,16]{0,1} parameter(0)
ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare
}
)"
},
// R2 Sort (Key, Value)
{
"SortKeyValueR2",
R"(HloModule sort
compare {
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
keys = f32[1024,16]{0,1} parameter(0)
values = s32[1024,16]{0,1} parameter(1)
ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare
}
)"
},
// Sort (Key, Value, Value, Value)
{
"SortManyValues",
R"(HloModule sort
compare {
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
p.2.lhs = u32[] parameter(4)
p.2.rhs = u32[] parameter(5)
p.3.lhs = f32[] parameter(6)
p.3.rhs = f32[] parameter(7)
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
keys = f32[1024,16]{0,1} parameter(0)
values.0 = s32[1024,16]{0,1} parameter(1)
values.1 = u32[1024,16]{0,1} parameter(2)
values.2 = f32[1024,16]{0,1} parameter(3)
ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare
}
)"
},
// Sort (Key) is_stable=true
{
"SortKeyStable",
R"(HloModule sort
compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
x = f32[1024]{0} parameter(0)
ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
}
)"
},
// Indexed Conditional (s32 branch index selecting among branch_computations)
{
"IndexedConditional",
R"(HloModule indexed_conditional
Negate {
x = f32[] parameter(0)
ROOT negate = f32[] negate(x)
}
Identity {
y = f32[] parameter(0)
ROOT copy = f32[] copy(y)
}
Floor {
z = f32[] parameter(0)
ROOT floor = f32[] floor(z)
}
ENTRY Parameters1.v4 {
constant = s32[] constant(1)
constant.1 = f32[] constant(56)
constant.2 = f32[] constant(12)
constant.3 = f32[] constant(13)
ROOT conditional = f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Identity, Floor}
}
)"
},
// Predicated Conditional (pred selecting true_computation/false_computation)
{
"PredicatedConditional",
R"(HloModule pred_conditional
Negate {
x = f32[] parameter(0)
ROOT negate = f32[] negate(x)
}
Identity {
y = f32[] parameter(0)
ROOT copy = f32[] copy(y)
}
ENTRY Parameters1.v4 {
constant = pred[] constant(true)
constant.1 = f32[] constant(56)
constant.2 = f32[] constant(12)
ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity
}
)"
},
// CustomCall (the target name contains an escaped double quote)
{
"CustomCall",
R"(HloModule custom_call
ENTRY CustomCall {
constant = f32[1]{0} constant({12345})
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
}
)"
},
// CustomCall with single computation.
// NOTE: "Custum" in this and the next test name is a misspelling; the names
// are kept as-is because they are the registered test-parameter names.
{
"CustumCallSingleComp",
R"(HloModule custom_call_with_comp
max_F32 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT maximum = f32[] maximum(lhs, rhs)
}
ENTRY CustomCall {
constant = f32[1]{0} constant({12345})
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32}
}
)"
},
// CustomCall with multiple computations.
{
"CustumCallMultipleComps",
R"(HloModule custom_call_with_comps
max_F32 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT maximum = f32[] maximum(lhs, rhs)
}
ENTRY CustomCall {
constant = f32[1]{0} constant({12345})
ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", called_computations={max_F32, max_F32}
}
)"
},
// Variables with non-default names
{
"NonDefaultNames",
R"(HloModule add_constants_module
ENTRY add_constants {
foo = f32[] constant(3.14)
ROOT bar = f32[] add(foo, foo)
}
)"
},
// dot with batch and contracting dimensions
{
"Dot",
R"(HloModule dot
ENTRY dot {
a = f32[2,10]{1,0} parameter(0)
b = f32[10,2]{1,0} parameter(1)
ROOT dot = f32[2]{0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}
}
)"
},
// gather
{
"gather",
R"(HloModule gather
ENTRY Gather {
input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}
)"
},
// all-reduce
{
"AllReduce",
R"(HloModule CRS
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add
}
)"
},
// all-reduce with subgroups
{
"AllReduceWithSubgroups",
R"(HloModule CRS_Subgroups
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY AllReduceWithSubgroups {
input = f32[128,32]{0,1} parameter(0)
ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
}
)",
/*replica_count=*/4,
},
// all-reduce with constrained layout
{
"AllReduceWithLayout",
R"(HloModule CRS
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, constrain_layout=true, to_apply=add
}
)"
},
// all-reduce with channel-id (two all-reduces sharing channel_id=1)
{
"AllReduceAllReduce",
R"(HloModule CRS
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
crs.1 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
ROOT crs.0 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
}
)"
},
// all-reduce start and done
{
"AllReduceStartAndDone",
R"(HloModule CRS
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
crs = (f32[8]{0}, f32[8]{0}) all-reduce-start(input), replica_groups={}, to_apply=add
ROOT done = f32[8]{0} all-reduce-done(crs)
}
)"
},
// all-reduce-scatter
{
"AllReduceScatter",
R"(HloModule ARS
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT ars = f32[4]{0} all-reduce-scatter(input), replica_groups={{0,1}}, dimensions={0}, to_apply=add
}
)"
},
// all-gather
{
"AllGather",
R"(HloModule AllGather
ENTRY AllGather {
input = f32[128,32]{0,1} parameter(0)
ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, dimensions={1}
}
)"
},
// all-gather with constrained layout
{
"AllGatherWithLayout",
R"(HloModule AllGather
ENTRY AllGather {
input = f32[128,32]{0,1} parameter(0)
ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, constrain_layout=true, dimensions={1}
}
)"
},
// all-gather with subgroups
{
"AllGatherWithSubgroups",
R"(HloModule AllGatherWithSubgroups
ENTRY AllGatherWithSubgroups {
input = f32[128,32]{0,1} parameter(0)
ROOT ag = f32[128,64]{0,1} all-gather(input), replica_groups={{0,1},{2,3}}, dimensions={1}
}
)",
/*replica_count=*/4,
},
// all-to-all
{
"AllToAll",
R"(HloModule AllToAll
ENTRY AllToAll {
input = f32[128,32]{0,1} parameter(0)
ROOT a2a = (f32[128,32]{0,1}) all-to-all(input), replica_groups={}
}
)"
},
// all-to-all with subgroups
{
"AllToAllWithSubgroups",
R"(HloModule AllToAllWithSubgroups
ENTRY AllToAllWithSubgroups {
p0 = f32[128,32]{0,1} parameter(0)
p1 = f32[128,32]{0,1} parameter(1)
ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
}
)",
/*replica_count=*/4,
},
// collective-permute
{
"CollectivePermute",
R"(HloModule CollectivePermute
ENTRY CollectivePermute {
input = f32[128,32]{0,1} parameter(0)
ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
}
)",
/*replica_count=*/4
},
// collective-permute with in-place updates
{
"CollectivePermuteInPlaceUpdate",
R"(HloModule CollectivePermuteInPlaceUpdate
ENTRY CollectivePermuteInPlaceUpdate {
input = f32[128,32]{0,1} parameter(0)
constant = f32[] constant(1)
output = f32[128,128]{0,1} broadcast(constant), dimensions={}
constant.1 = s32[] constant(0)
tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
constant.2 = s32[] constant(64)
tuple.2 = (s32[], s32[]) tuple(constant.1, constant.2)
ROOT root = f32[128,128]{0,1} collective-permute(input, output, tuple.1, tuple.2), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{128,32}}
}
)",
/*replica_count=*/4
},
// collective-permute tuple with in-place updates
{
"CollectivePermuteTupleInPlaceUpdate",
R"(HloModule CollectivePermuteTupleInPlaceUpdate
ENTRY CollectivePermuteInPlaceUpdate {
input = f32[128,32]{0,1} parameter(0)
tuple.input = (f32[128,32]{0,1}, f32[128,32]{0,1}) tuple(input, input)
constant = f32[] constant(1)
output = f32[128,128]{0,1} broadcast(constant), dimensions={}
tuple.output = (f32[128,128]{0,1}, f32[128,128]{0,1}) tuple(output, output)
constant.1 = s32[] constant(0)
tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
constant.2 = s32[] constant(64)
tuple.2 = (s32[], s32[]) tuple(constant.2, constant.1)
tuple.3 = ((s32[], s32[]), (s32[], s32[])) tuple(tuple.1, tuple.2)
tuple.4 = (s32[], s32[]) tuple(constant.1, constant.1)
tuple.5 = (s32[], s32[]) tuple(constant.2, constant.2)
tuple.6 = ((s32[], s32[]), (s32[], s32[])) tuple(tuple.4, tuple.5)
ROOT root = (f32[128,128]{0,1}, f32[128,128]{0,1}) collective-permute(tuple.input, tuple.output, tuple.3, tuple.6), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{64,32},{64,32}}
}
)",
/*replica_count=*/4
},
// collective-permute-start and -done
{
"CollectivePermuteStartAndDone",
R"(HloModule CollectivePermuteStartAndDone
ENTRY CollectivePermuteStartAndDone {
input = f32[128,32]{0,1} parameter(0)
collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}}
ROOT collective-permute-done.1 = f32[128,32]{0,1} collective-permute-done(collective-permute-start.1)
}
)",
/*replica_count=*/4
},
// collective-permute-start and -done with in-place update
{
"CollectivePermuteStartAndDoneInplaceUpdate",
R"(HloModule CollectivePermuteStartAndDoneInplaceUpdate
ENTRY CollectivePermuteStartAndDoneInplaceUpdate {
input = f32[128,32]{0,1} parameter(0)
constant = f32[] constant(1)
output = f32[128,128]{0,1} broadcast(constant), dimensions={}
constant.1 = s32[] constant(0)
tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
constant.2 = s32[] constant(64)
tuple.2 = (s32[], s32[]) tuple(constant.1, constant.2)
collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,128]{0,1}, u32[], u32[], s32[]) collective-permute-start(input, output, tuple.1, tuple.2), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{64,32}}
ROOT collective-permute-done.1 = f32[128,128]{0,1} collective-permute-done(collective-permute-start.1)
}
)",
/*replica_count=*/4
},
// replica-id
{
"ReplicaId",
R"(HloModule replica-id
ENTRY Replica-id {
ROOT replica-id = u32[] replica-id()
}
)"
},
// partition-id
{
"PartitionId",
R"(HloModule partition-id
ENTRY PartitionId {
ROOT id = u32[] partition-id()
}
)"
},
// Iota
{
"Iota",
R"(HloModule iota
ENTRY Iota {
ROOT iota = f32[100]{0} iota(), iota_dimension=0
}
)"
},
// custom-call with window, dim_labels and feature_group_count
{
"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount
ENTRY Computation {
ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
}
)"
},
// is_scheduled=true attribute
{
"ScheduledModule",
R"(HloModule scheduled_module, is_scheduled=true
compare {
p.1.lhs = s32[] parameter(2)
p.1.rhs = s32[] parameter(3)
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY Sort {
keys = f32[1024]{0} parameter(0)
values = s32[1024]{0} parameter(1)
ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
}
)"
},
// AfterAll with multiple operands
{
"AfterAllWithMultipleOperands",
R"(HloModule AfterAllWithMultipleOperands
ENTRY AfterAllWithMultipleOperands {
p0 = f32[] parameter(0)
token0 = token[] after-all()
token1 = token[] after-all()
ROOT after-all = token[] after-all(p0, token0, token1)
}
)"
},
// AddDependency
// A dependency chain is created from 'neg' to 'exp' using tokens.
{
"AddDependency",
R"(HloModule AddDependency
ENTRY AddDependency {
p = f32[] parameter(0)
neg = f32[] negate(p)
token0 = token[] after-all(neg)
p_after_token = f32[] add-dependency(p, token0)
exp = f32[] exponential(p_after_token)
ROOT sum = f32[] add(neg, exp)
}
)"
},
// A module containing constants equal to the min/max values of various data
// types.
{
"MinMaxValues",
R"(HloModule MinMaxValues
ENTRY MinMaxValues {
x.s8 = s8[2]{0} constant({-128, 127})
x.s16 = s16[2]{0} constant({-32768, 32767})
x.s32 = s32[2]{0} constant({-2147483648, 2147483647})
x.u8 = u8[2]{0} constant({0, 255})
x.u16 = u16[2]{0} constant({0, 65535})
x.u32 = u32[2]{0} constant({0, 4294967295})
x.f16 = f16[2]{0} constant({-65504, 65504})
x.bf16 = bf16[2]{0} constant({-3.39e+38, 3.39e+38})
x.f32 = f32[2]{0} constant({-3.40282e+38, 3.40282e+38})
x.f64 = f64[2]{0} constant({-1.79769e+308, 1.79769e+308})
x.c64 = c64[2]{0} constant({(-3.40282e+38, 3.40282e+38), (3.40282e+38, -3.40282e+38)})
ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)})
}
)"
},
// Bitcast-convert usage
{
"BitcastConvert",
R"(HloModule BitcastConvert
ENTRY BitcastConvertUsage {
p = f32[100]{0} parameter(0)
ROOT out = u32[100]{0} bitcast-convert(p)
}
)"
},
// parameter with the outer_dimension_partitions attribute
{
"OuterDimensionPartitions",
R"(HloModule OuterDimensionPartitions
ENTRY Test {
ROOT foo = f32[100]{0} parameter(0), outer_dimension_partitions={0,10,20}
}
)"
},
});
// clang-format on
}
// Round-trip fixture shared by the parameterized test tables above. Each case
// is parsed into an HloModule, optionally rebuilt from its HloProto, printed,
// and the printed text must equal the input string exactly.
//
// Template parameters:
//   short_form       - print with HloPrintOptions::ShortParsable() (used by the
//                      "short" test table); otherwise print with the default
//                      options plus print_large_constants(true).
//   proto_round_trip - before printing, additionally rebuild the module via
//                      HloModule::CreateFromProto(module->ToProto(), ...),
//                      which exercises proto serialization/deserialization.
//
// The proto_round_trip=true variant also covers the plain Parser->ToString
// round trip. Keeping the plain variant as a separate test isolates failures
// and can surface bugs that the extra proto step might otherwise mask.
template <bool short_form, bool proto_round_trip>
class HloParameterizedParserTest
    : public ::testing::Test,
      public ::testing::WithParamInterface<TestData> {
 protected:
  // Checks ToString(Parse(text)) == text for the current test case. Parsing is
  // verified unless the case sets enable_verification=false; a parse or
  // verification error is a fatal test failure.
  void ExpectEqual() {
    const TestData& test_case = GetParam();
    const string& expected_text = test_case.module_string;

    HloModuleConfig module_config;
    module_config.set_replica_count(test_case.replica_count);

    std::unique_ptr<HloModule> parsed;
    if (!test_case.enable_verification) {
      TF_ASSERT_OK_AND_ASSIGN(
          parsed, ParseAndReturnUnverifiedModule(expected_text, module_config));
    } else {
      auto verified = absl::make_unique<VerifiedHloModule>(
          test_case.test_name, module_config,
          /*verifier_layout_sensitive=*/false,
          /*allow_mixed_precision_in_hlo_verifier=*/true,
          ShapeUtil::ByteSizeOfElements);
      TF_ASSERT_OK(verified->ParseHloStringAndVerifyModule(expected_text));
      parsed = std::move(verified);
    }

    if (proto_round_trip) {
      TF_ASSERT_OK_AND_ASSIGN(parsed, HloModule::CreateFromProto(
                                          parsed->ToProto(), parsed->config()));
    }

    // Short cases are written in ShortParsable form; long cases include large
    // constants verbatim, so both must be printed the same way they were written.
    const HloPrintOptions print_options =
        short_form ? HloPrintOptions::ShortParsable()
                   : HloPrintOptions().set_print_large_constants(true);
    EXPECT_EQ(expected_text, parsed->ToString(print_options));
  }
};
// Non-template aliases for each <short_form, proto_round_trip> combination.
// They are required because TEST_P is a macro, and a template-id such as
// HloParameterizedParserTest<false, false> contains a comma that the
// preprocessor would treat as a macro-argument separator.
using HloParserTestLong = HloParameterizedParserTest<false, false>;
using HloParserTestLongProto = HloParameterizedParserTest<false, true>;
using HloParserTestShort = HloParameterizedParserTest<true, false>;
using HloParserTestShortProto = HloParameterizedParserTest<true, true>;
// Every fixture runs the same round-trip check.
TEST_P(HloParserTestLong, Run) { ExpectEqual(); }
TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); }
TEST_P(HloParserTestShort, Run) { ExpectEqual(); }
TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); }
// Long-form cases (CreateTestCases) run without and with the HloProto round
// trip. Per-case test names come from TestData::test_name via TestDataToString.
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestLong,
::testing::ValuesIn(CreateTestCases()),
TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
HloParserTestLongProto,
::testing::ValuesIn(CreateTestCases()),
TestDataToString);
// Short-form cases (CreateShortTestCases) run without and with the HloProto
// round trip.
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestShort,
::testing::ValuesIn(CreateShortTestCases()),
TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
HloParserTestShortProto,
::testing::ValuesIn(CreateShortTestCases()),
TestDataToString);
// Fixture for the individual (non-parameterized) parser tests below.
class HloParserTest : public ::testing::Test {
 protected:
  // Records a non-fatal failure unless `expected` occurs somewhere in `s`.
  static void ExpectHasSubstr(string_view s, string_view expected) {
    EXPECT_TRUE(absl::StrContains(s, expected))
        << "'" << s << "' does not contain '" << expected << "'";
  }

  // Parses `hlo_text` into a VerifiedHloModule named after the running test and
  // verifies it (layout-insensitive, mixed precision allowed). Returns the
  // parse or verification error on failure.
  StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
      absl::string_view hlo_text) {
    const char* const running_test_name =
        ::testing::UnitTest::GetInstance()->current_test_info()->name();
    auto verified = absl::make_unique<VerifiedHloModule>(
        running_test_name, HloModuleConfig(),
        /*verifier_layout_sensitive=*/false,
        /*allow_mixed_precision_in_hlo_verifier=*/true,
        ShapeUtil::ByteSizeOfElements);
    TF_RETURN_IF_ERROR(verified->ParseHloStringAndVerifyModule(hlo_text));
    // Explicit move: the return type is StatusOr<...>, not the local's type.
    return std::move(verified);
  }
};
// An empty string is not a valid HLO module and must be rejected.
TEST_F(HloParserTest, Empty) {
  const string hlo_text = "";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
}
// Text that begins with the HloModule keyword but is otherwise nonsense must
// fail to parse.
TEST_F(HloParserTest, Garbage) {
  const string hlo_text = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
}
// Uses the opcode `le`, whereas comparisons in this file are spelled
// `compare(...), direction=...`; parsing must fail.
TEST_F(HloParserTest, WrongOpcode) {
  const string hlo_text = R"(HloModule wrong_opcode:
ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
%x = f32[]{} parameter(0)
%y = f32[]{} parameter(1)
%le = pred[]{} le(f32[]{} %x, f32[]{} %y)
}
)";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
}
// The metadata={...} attribute on a cholesky instruction is parsed into the
// instruction's OpMetadata: op_type, op_name, and the profile_type list.
TEST_F(HloParserTest, MetadataWithCholesky) {
  const string original = R"(HloModule metadata_with_cholesky
ENTRY %blabla (a: f32[1,291,291]) -> f32[1,291,291] {
%a = f32[1,291,291] parameter(0)
%out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true, metadata={op_type="Cholesky" op_name="Cholesky" profile_type={1}}
}
)";
  auto result = ParseAndReturnVerifiedModule(original);
  // Fatal on failure: everything below dereferences the parsed module, and
  // ValueOrDie() on an error status would abort the whole test binary instead
  // of reporting a clean test failure.
  TF_ASSERT_OK(result.status());
  const OpMetadata& metadata =
      result.ValueOrDie()->entry_computation()->root_instruction()->metadata();
  EXPECT_EQ("Cholesky", metadata.op_name());
  EXPECT_EQ("Cholesky", metadata.op_type());
  // profile_type={1} lists a single entry; check the count before indexing so
  // an empty list fails cleanly instead of reading past the end.
  ASSERT_EQ(1, metadata.profile_type_size());
  EXPECT_EQ(WINDOW, metadata.profile_type(0));
}
// `g32` is not a valid element type, so the shape cannot be parsed.
TEST_F(HloParserTest, WrongShape) {
  const string hlo_text = R"(HloModule wrong_opcode:
ENTRY %blabla (x: g32[]) -> g32[] {
%x = g32[]{} parameter(0)
}
)";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
}
// `compare` is given a single operand instead of two; parsing must fail.
TEST_F(HloParserTest, WrongOperandsSize) {
  const string hlo_text = R"(HloModule wrong_opcode:
ENTRY %blabla (x: f32[]) -> pred[] {
%x = f32[]{} parameter(0)
%eq = pred[]{} compare(f32[]{} %x), direction=EQ
}
)";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
}
// `%y` is used as an operand but never defined; parsing must fail.
TEST_F(HloParserTest, OperandNotFound) {
  const string hlo_text = R"(HloModule operand_not_found:
ENTRY %blabla (x: f32[]) -> pred[] {
%x = f32[]{} parameter(0)
%eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ
}
)";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
}
// A module mixing pred and s32 constants, one carrying a sharding attribute,
// must parse and verify. Only success is checked; the printed form is not
// compared against the input.
TEST_F(HloParserTest, MoreConstants) {
  const string hlo_text = R"(HloModule SelectScalarS32True_module
ENTRY %SelectScalarS32True.v4 () -> s32[] {
%constant.2 = pred[] constant(true)
%constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4}
%constant = s32[] constant(42)
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}
)";
  TF_EXPECT_OK(ParseAndReturnVerifiedModule(hlo_text).status());
}
// A backend_config="..." attribute is kept verbatim and exposed through
// raw_backend_config_string() on the instruction.
TEST_F(HloParserTest, ConfigurationField) {
  const string hlo_text = R"(HloModule AModule
ENTRY %configuration_test() -> s32[] {
%constant = s32[] constant(42), backend_config="foo bar"
})";
  auto parsed = ParseAndReturnVerifiedModule(hlo_text);
  TF_ASSERT_OK(parsed.status());
  const auto* root =
      parsed.ValueOrDie()->entry_computation()->root_instruction();
  EXPECT_EQ("foo bar", root->raw_backend_config_string());
}
// A rank-1 literal containing a nested brace list must be rejected.
TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
  const string hlo_text = R"(HloModule some_2_module
ENTRY %some_2 () -> f32[2] {
ROOT %constant = f32[2]{0} constant({1,{2}})
}
)";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
  ExpectHasSubstr(status.error_message(),
                  "expects nested array in rank 1, but sees larger");
}
// A flat list cannot initialize a rank-2 literal.
TEST_F(HloParserTest, LiteralDimensionsMismatch_2) {
  const string hlo_text = R"(HloModule some_2x3_module
ENTRY %some_2x3 () -> f32[2,3] {
ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6})
}
)";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
  ExpectHasSubstr(status.error_message(),
                  "expects nested array in rank 2, but sees 1");
}
// The outermost element of this f32[2,3,2] literal has six rows instead of
// three, so the element-count check must fail.
TEST_F(HloParserTest, LiteralDimensionsMismatch_3) {
  const string hlo_text = R"(HloModule some_2x3x2_module
ENTRY %some_2x3x2 () -> f32[2,3,2] {
ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}})
}
)";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
  ExpectHasSubstr(status.error_message(),
                  "expects 3 elements in the [0]th element");
}
// -65520 is below the smallest f16 value accepted elsewhere in this file
// (-65504, see MinMaxValues), so it must be reported as out of range.
TEST_F(HloParserTest, ConstantF16Overflow) {
  const string hlo_text = R"(HloModule ConstantF16Overflow_module
ENTRY %ConstantF16Overflow.v4 () -> f16[] {
ROOT %constant = f16[] constant(-65520)
}
)";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
  ExpectHasSubstr(status.error_message(),
                  "is out of range for literal's primitive type F16");
}
// -65505 is well within bf16's range (MinMaxValues uses +/-3.39e+38 for bf16),
// so this must parse and verify.
TEST_F(HloParserTest, ConstantBf16NoOverflow) {
  const string hlo_text = R"(
HloModule test_module
ENTRY test {
ROOT c = bf16[] constant(-65505)
})";
  const Status status = ParseAndReturnVerifiedModule(hlo_text).status();
  EXPECT_EQ(Status::OK(), status);
}
// 1e100 exceeds the largest finite bf16 value; the parser must say so.
TEST_F(HloParserTest, ConstantBf16Overflow) {
  const string hlo_text = R"(
HloModule test_module
ENTRY test {
ROOT c = bf16[] constant(1e100)
})";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  ExpectHasSubstr(status.error_message(), "out of range");
}
// NOTE(review): despite its name, this test asserts that constant(-1) for a
// u64 scalar is ACCEPTED. It pins the parser's current lenient behavior rather
// than expecting an underflow error; confirm this is intended.
TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
  const string hlo_text = R"(
HloModule ConstantUnsignedUnderflow_module
ENTRY %ConstantUnsignedUnderflow () -> u64[] {
ROOT %constant = u64[] constant(-1)
})";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_EQ(Status::OK(), status);
}
// 4294967296 (2^32) is one past the largest u32 value (MinMaxValues uses
// 4294967295), so it must be reported as out of range.
TEST_F(HloParserTest, ConstantUnsignedOverflow) {
  const string hlo_text = R"(
HloModule ConstantUnsignedOverflow_module
ENTRY %ConstantUnsignedOverflow () -> u32[] {
ROOT %constant = u32[] constant(4294967296)
})";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_NE(Status::OK(), status);
  ExpectHasSubstr(status.error_message(),
                  "is out of range for literal's primitive type U32");
}
// 9223372036854775808 (2^63) is one past the signed 64-bit maximum but still
// fits in u64, so this must be accepted despite the test's name.
TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) {
  const string hlo_text = R"(
HloModule ConstantUnsignedOverflow_module
ENTRY %ConstantUnsignedOverflow () -> u64[] {
ROOT %constant = u64[] constant(9223372036854775808)
})";
  const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
  EXPECT_EQ(Status::OK(), status);
}
TEST_F(HloParserTest, ConstantC64Overflow) {
// The real component 1e100 does not fit in a c64 element; parsing must fail.
const string hlo_text = R"(
HloModule test_module
ENTRY test () -> c64[] {
ROOT c = c64[] constant((1e100, 0))
})";
const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
EXPECT_NE(Status::OK(), status);
}
TEST_F(HloParserTest, ConstantC64Underflow) {
// The imaginary component -1e100 does not fit in a c64 element; parsing must
// fail.
const string hlo_text = R"(
HloModule test_module
ENTRY test () -> c64[] {
ROOT c = c64[] constant((0, -1e100))
})";
const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
EXPECT_NE(Status::OK(), status);
}
TEST_F(HloParserTest, ConstantF64Overflow) {
// 1.8e308 exceeds the largest finite double; parsing must fail.
const string hlo_text = R"(
HloModule test_module
ENTRY test {
ROOT c = f64[] constant(1.8e308)
})";
const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
EXPECT_NE(Status::OK(), status);
}
TEST_F(HloParserTest, ConstantF64Underflow) {
// -1.8e308 is below the most negative finite double; parsing must fail.
const string hlo_text = R"(
HloModule test_module
ENTRY test {
ROOT c = f64[] constant(-1.8e308)
})";
const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
EXPECT_NE(Status::OK(), status);
}
TEST_F(HloParserTest, ConstantWithExp) {
// A float literal with an explicitly signed exponent must parse. The
// round-tripped text is deliberately not compared: "3e+2" is parsed into the
// value 300 and printed back as "300".
const string original = R"(HloModule ConstantWithExp_module
ENTRY %ConstantWithExp.v4 () -> f32[] {
%constant.1 = f32[] constant(3e+2)
}
)";
auto result = ParseAndReturnVerifiedModule(original);
TF_EXPECT_OK(result.status());
}
TEST_F(HloParserTest, ShortConstant) {
// A constant printed in elided form ("{...}") must parse and print back as
// the same text.
const string original = R"(HloModule ShortConstant_module
ENTRY %ShortConstant.v4 () -> f32[67,89] {
ROOT %constant.1 = f32[67,89]{1,0} constant({...})
}
)";
auto result = ParseAndReturnVerifiedModule(original);
// Fatal check: calling ValueOrDie() on an error status would abort the whole
// test binary instead of failing just this test.
TF_ASSERT_OK(result.status());
EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
}
TEST_F(HloParserTest, NegativeNan) {
// Negative NaNs must survive a parse/print round trip unchanged.
const string original = R"(HloModule NegativeNan_module
ENTRY %NegativeNan () -> bf16[2] {
ROOT %constant = bf16[2]{0} constant({-nan, -nan})
}
)";
auto result = ParseAndReturnUnverifiedModule(original);
// Fatal check: calling ValueOrDie() on an error status would abort the whole
// test binary instead of failing just this test.
TF_ASSERT_OK(result.status());
EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
}
TEST_F(HloParserTest, NanPayload) {
// NaNs written with an explicit payload (the hex value in parentheses) must
// survive a parse/print round trip unchanged.
const string original = R"(HloModule NanPayload_module
ENTRY %NanPayload () -> bf16[2] {
ROOT %constant = bf16[2]{0} constant({-nan(0x7f), -nan(0x3f)})
}
)";
auto result = ParseAndReturnUnverifiedModule(original);
// Fatal check: calling ValueOrDie() on an error status would abort the whole
// test binary instead of failing just this test.
TF_ASSERT_OK(result.status());
EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
}
TEST_F(HloParserTest, AttributesAnyOrder) {
// Instruction attributes may be listed in any order; the convolution below
// uses a non-canonical order and must still parse and verify.
const string hlo_text = R"(HloModule any_order_module
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,4,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %convolution = f32[1,4,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=1}
}
)";
const Status status = ParseAndReturnVerifiedModule(hlo_text).status();
TF_EXPECT_OK(status);
}
TEST_F(HloParserTest, InvalidDimLabels) {
// The module text is split around the dim_labels attribute so each case can
// splice in a different malformed label string.
const string prefix = R"(HloModule invalid_dim_labels_module
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
const string suffix = R"(
}
)";
// Parses prefix + dim_labels + suffix and returns a copy of the error text.
auto parse_error = [&](const string& dim_labels) {
return ParseAndReturnUnverifiedModule(
absl::StrCat(prefix, dim_labels, suffix))
.status()
.error_message();
};
// A dimension label is repeated.
ExpectHasSubstr(parse_error(",dim_labels=00_01->10"), "expects unique");
// The label groups disagree on the number of spatial dimensions.
ExpectHasSubstr(parse_error(",dim_labels=012_0123->210"),
"must have same number of spatial dimensions");
// '3' is not among the accepted labels reported in the error.
ExpectHasSubstr(parse_error(",dim_labels=013_0123->210"),
"expects [0-2bf?]");
}
TEST_F(HloParserTest, UnexpectedAttribute) {
// `send` is given a `calls` attribute; the error must name that attribute.
const string hlo_text = R"(HloModule unexpected_attr_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}
)";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "unexpected attribute \"calls\"");
}
TEST_F(HloParserTest, MissingAttribute) {
// %send omits its channel_id attribute; the parser must report it as missing.
const string hlo_text = R"(HloModule missing_attr_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(-2.1)
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0)
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}
)";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "attribute channel_id is expected but not seen");
}
TEST_F(HloParserTest, PredecessorUndefined) {
// control-predecessors names %done, which is never defined in the module.
const string hlo_text = R"(HloModule pre_not_found_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done}
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}
)";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "'done' is not defined");
}
TEST_F(HloParserTest, SliceAllowOmitStride1) {
// Slice ranges may omit a stride of 1 (e.g. [0:3]) alongside ranges that
// spell one out (e.g. [0:4:2]).
const string hlo_text = R"(HloModule slice_module
ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
%p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]}
}
)";
const Status status = ParseAndReturnVerifiedModule(hlo_text).status();
TF_EXPECT_OK(status);
}
TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
// Window padding takes low_high pairs only. The three-part `1_1_0` form must
// be rejected here.
const string hlo_text = R"(HloModule window_pad_module
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1}
}
)";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "expects padding_low and padding_high separated by '_'");
}
TEST_F(HloParserTest, CommaBetweenSubAttributes) {
// Sub-attributes inside metadata={...} may be separated by commas.
const string hlo_text = R"(HloModule test_comma_module
ENTRY %test_comma.v4 () -> f32[] {
ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
}
)";
const Status status = ParseAndReturnVerifiedModule(hlo_text).status();
TF_EXPECT_OK(status);
}
TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
// The signature declares f32[1] but ROOT produces f32[1,2,3]; the parser must
// report the mismatch with both shapes.
const string hlo_text = R"(HloModule custom_call:
ENTRY %CustomCall () -> f32[1] {
%constant = f32[1]{0} constant({12345})
ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
})";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error,
"Shape of computation CustomCall, f32[1], is not compatible "
"with that of its root instruction foo, f32[1,2,3]");
}
TEST_F(HloParserTest, EntryComputationWithLayout) {
// Layouts written on the ENTRY parameter and root instruction must show up in
// the module's entry computation layout.
const string hlo_text = R"(HloModule layout:
add_F32.v3 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
input = f32[8,16,256]{0,1,2} parameter(0)
constant = f32[] constant(0)
ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
})";
auto module_or = ParseAndReturnVerifiedModule(hlo_text);
TF_ASSERT_OK(module_or.status());
auto entry_layout = module_or.ValueOrDie()->entry_computation_layout();
ASSERT_EQ(entry_layout.parameter_count(), 1);
auto param_layout = entry_layout.parameter_layout(0).layout();
auto result_layout = entry_layout.result_layout().layout();
const auto expected_param_layout = LayoutUtil::MakeLayout({0, 1, 2});
const auto expected_result_layout = LayoutUtil::MakeLayout({0, 1});
EXPECT_TRUE(LayoutUtil::Equal(expected_param_layout, param_layout))
<< "actual layout of parameter(0) is "
<< LayoutUtil::HumanString(param_layout);
EXPECT_TRUE(LayoutUtil::Equal(expected_result_layout, result_layout))
<< "actual layout of result is "
<< LayoutUtil::HumanString(result_layout);
}
TEST_F(HloParserTest, NoEntry) {
// With no computation marked ENTRY, the last computation in the text (c2)
// becomes the entry computation.
const string hlo_text = R"(HloModule no_entry:
c1 {
const1 = f32[1]{0} constant({12345})
}
c2 {
const2 = f32[1]{0} constant({67890})
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
EXPECT_EQ(module->entry_computation()->name(), "c2");
}
TEST_F(HloParserTest, NoRoot) {
// With no instruction marked ROOT, the last instruction (`last`) becomes the
// root.
const string hlo_text = R"(HloModule no_root:
ENTRY consts {
first = f32[1]{0} constant({12345})
last = f32[1]{0} constant({67890})
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
const auto* root = module->entry_computation()->root_instruction();
EXPECT_EQ(root->name(), "last");
}
TEST_F(HloParserTest, Comments) {
// /* ... */ comments are accepted before the module header, between tokens,
// inside a constant literal, and after the last computation.
const string hlo_text = R"(/* module description. */
HloModule comments:
ENTRY /*comment*/ c1 {
/* blah */
ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/})
/* comment */
}
/* something else */
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
}
TEST_F(HloParserTest, MultilineComments) {
// Block comments may span several lines, including lines that look like
// instructions.
const string hlo_text = R"(HloModule multiline_comment:
ENTRY c1 {
/*
ROOT foo = f32[1]{0} constant({12345})
*/
ROOT const1 = f32[1]{0} constant({12345})
/*
a
b
c
d
*/
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
}
TEST_F(HloParserTest, UnterminatedComment) {
const string hlo_text = R"(HloModule unterminated_comment:
ENTRY c1 {
/* unterminated
ROOT const1 = f32[1]{0} constant({12345})
})";
// The error must point at the start of the unterminated comment: the caret
// appears on the line right after the echoed "/* unterminated".
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "/* unterminated\n^");
}
TEST_F(HloParserTest, SlashSlashComments) {
// `//` line comments are accepted on their own lines and after instructions.
const string hlo_text = R"(HloModule slash_slash_comment:
// Garbage
ENTRY c1 {
// Foo bar
ROOT const1 = f32[1]{0} constant({12345}) // Something else
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
}
TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
// `//` comments must end at a "\r\n" (MS-DOS) line ending.
const string hlo_text =
"HloModule slash_slash_comment:\r\n"
"// Garbage\r\n"
"ENTRY c1 {\r\n"
"// Foo bar\r\n"
"ROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n"
"}";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
}
TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
// `//` comments must end at a bare "\r" (classic Mac) line ending.
const string hlo_text =
"HloModule slash_slash_comment:\r"
"// Garbage\r"
"ENTRY c1 {\r"
"// Foo bar\r"
"ROOT const1 = f32[1]{0} constant({12345}) // Something else\r"
"}";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
}
TEST_F(HloParserTest, MultipleEntries) {
// A module may mark at most one computation as ENTRY.
const string hlo_text = R"(HloModule multiple_entries:
ENTRY c1 {
const1 = f32[1]{0} constant({12345})
}
ENTRY c2 {
const2 = f32[1]{0} constant({67890})
})";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "expects only one ENTRY");
}
TEST_F(HloParserTest, SimpleAliasing) {
// Output {0} aliases parameter 0 at {0} and is marked must-alias; output {1}
// aliases parameter 0 at {1} without the must-alias marker.
const string hlo_text = R"(
HloModule Module, input_output_alias={ {0}: (0, {0}, must-alias), {1}: (0, {1}) }
ENTRY entry {
%p = (f32[], f32[]) parameter(0)
%p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
%p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
ParseAndReturnVerifiedModule(hlo_text));
auto& alias_config = parsed_module->input_output_alias_config();
EXPECT_EQ(alias_config.GetAliasedOutput(0, {0}), ShapeIndex{0});
EXPECT_TRUE(alias_config.ParameterMustAlias(0, {0}));
EXPECT_EQ(alias_config.GetAliasedOutput(0, {1}), ShapeIndex{1});
EXPECT_FALSE(alias_config.ParameterMustAlias(0, {1}));
}
TEST_F(HloParserTest, NestedAliasing) {
// Output shape indices may be nested ({0, 0}, {1, 1}) inside a tuple of
// tuples.
const string hlo_text = R"(
HloModule Module, input_output_alias={ {0, 0}: (0, {0}), {1, 1}: (0, {1}) }
ENTRY entry {
%p = (f32[], f32[]) parameter(0)
%p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
%p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
%t0 = (f32[], f32[]) tuple(%p0, %p1)
%t1 = (f32[], f32[]) tuple(%p0, %p1)
ROOT %out = ((f32[], f32[]), (f32[], f32[])) tuple(%t0, %t1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
ParseAndReturnVerifiedModule(hlo_text));
auto& alias_config = parsed_module->input_output_alias_config();
EXPECT_EQ(alias_config.GetAliasedOutput(0, {0}), ShapeIndex({0, 0}));
EXPECT_EQ(alias_config.GetAliasedOutput(0, {1}), ShapeIndex({1, 1}));
}
TEST_F(HloParserTest, AliasingWrongIndex) {
// The first output shape index `{0` is never closed.
const string hlo_text = R"(
HloModule Module, input_output_alias={ {0 : (0, {0}), {1}: (0, {1}) }
ENTRY entry {
%p = (f32[], f32[]) parameter(0)
%p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
%p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
}
)";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "Expects '}' at the end of ShapeIndex");
}
TEST_F(HloParserTest, AliasingShapeIndexNotNumerical) {
// An output shape index element (`a`) is not an integer.
const string hlo_text = R"(
HloModule Module, input_output_alias={ {0, a}: (0, {0}), {1}: (0, {1}) }
ENTRY entry {
%p = (f32[], f32[]) parameter(0)
%p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
%p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
}
)";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "expects integer");
}
TEST_F(HloParserTest, AliasingWrongFormatNoColon) {
// The second alias entry lacks its `<output index>:` prefix, so the parser
// finds `(` where a ShapeIndex should start.
const string hlo_text = R"(
HloModule Module, input_output_alias={ {0, 0}: (0, {0}), (0, {1}) }
ENTRY entry {
%p = (f32[], f32[]) parameter(0)
%p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
%p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
}
)";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "Expects '{' at the start of ShapeIndex");
}
TEST_F(HloParserTest, AliasingWrongFormatTwoColons) {
// An alias entry carries a stray second `: {0, 1}` part.
const string hlo_text = R"(
HloModule Module, input_output_alias={ {0}: (0, {0}): {0, 1}, {1}: (0, {1}) }
ENTRY entry {
%p = (f32[], f32[]) parameter(0)
%p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
%p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
}
)";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "Expects '}' at the end of aliasing description");
}
TEST_F(HloParserTest, AliasingWrongFormatAlphaParam) {
// The output shape index is well formed, so parsing reaches the parameter
// number, which is spelled as a name ("zero") instead of an integer.
// (A non-numeric shape index is covered by AliasingShapeIndexNotNumerical;
// using `{0, a}` here would fail before the parameter is ever examined.)
const string original = R"(
HloModule Module, input_output_alias={ {0, 0}: (zero, {0}), {1}: (0, {1}) }
ENTRY entry {
%p = (f32[], f32[]) parameter(0)
%p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
%p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
}
)";
ExpectHasSubstr(
ParseAndReturnUnverifiedModule(original).status().error_message(),
"expects integer");
}
TEST_F(HloParserTest, MultipleRoots) {
// Two instructions are marked ROOT in the same computation.
const string hlo_text = R"(HloModule multiple_roots:
ENTRY consts {
ROOT const1 = f32[1]{0} constant({12345})
ROOT const2 = f32[1]{0} constant({12345})
})";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "one computation should have only one ROOT");
}
TEST_F(HloParserTest, ComputationExists) {
// Redefining computation `comp` must fail, and the error must point back at
// the first definition.
const string hlo_text = R"(HloModule comp_exists
comp {
const1 = f32[1]{0} constant({12345})
}
comp {
const2 = f32[1]{0} constant({67890})
})";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error,
R"(was parsing 2:1: error: computation previously defined here
comp {
^)");
}
TEST_F(HloParserTest, CrossComputationLookup) {
// tcallb refers to `aparam`, which is defined only in tcalla; instruction
// names must not resolve across computations.
const string hlo_text = R"(HloModule cross_computation_lookup:
tcalla (a: (s32[], s32[])) -> (s32[], s32[]) {
ROOT aparam = (s32[], s32[]) parameter(0)
}
tcallb (b: (s32[], s32[])) -> s32[] {
rparam = (s32[], s32[]) parameter(0)
ROOT gte0 = s32[] get-tuple-element(aparam), index=0
}
ENTRY entry {
param = (s32[], s32[]) parameter(0)
call0 = (s32[], s32[]) call(param), to_apply=tcalla
ROOT call1 = s32[] call(param), to_apply=tcallb
})";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error,
"was parsing 8:39: error: instruction does not exist: aparam");
}
TEST_F(HloParserTest, SameNameDiffComputations) {
// Instruction names (p0, p1, result) may be reused in different computations.
const string hlo_text = R"(HloModule same_names:
add {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT result = f32[] add(p0, p1)
}
ENTRY ReduceR3ToR2 {
p0 = f32[8,16,256]{2,1,0} parameter(0)
p1 = f32[] constant(0)
ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
const HloComputation* entry = module->entry_computation();
ASSERT_NE(entry, nullptr);
EXPECT_THAT(entry->root_instruction(), GmockMatch(m::Reduce()));
}
TEST_F(HloParserTest, ParseSharding) {
// A maximal sharding must round-trip through ParseSharding and ToString.
const string sharding_text = "{maximal device=42}";
TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(sharding_text));
EXPECT_EQ(sharding.ToString(), sharding_text);
}
TEST_F(HloParserTest, ParseShardingPartialReplication) {
// A partially replicated sharding must round-trip through text, and building
// the same sharding with HloSharding::PartialTile must print identically.
const string sharding_text = "{devices=[2,2]0,1,2,3 last_tile_dim_replicate}";
TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(sharding_text));
EXPECT_EQ(sharding.ToString(), sharding_text);

Array<int64> group_tiling({2});
group_tiling(0) = 0;
group_tiling(1) = 1;
// Group 0 holds devices {0, 1}; group 1 holds devices {2, 3}.
const std::vector<int64> group0_members = {0, 1};
const std::vector<int64> group1_members = {2, 3};
const HloSharding built =
HloSharding::PartialTile(group_tiling, {group0_members, group1_members});
EXPECT_EQ(built.ToString(), sharding_text);
}
TEST_F(HloParserTest, ParseFrontendAttributes) {
// Frontend attributes must round-trip through text, including a value that
// contains '/'.
const string attributes_text =
R"({attr_a="test_a",attr_b="b",attr_c="s64",attr_d="a/b"})";
TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes,
ParseFrontendAttributes(attributes_text));
EXPECT_EQ(FrontendAttributesToString(frontend_attributes), attributes_text);
}
TEST_F(HloParserTest, ParseWindow) {
// A window must round-trip through window_util::ToString and ParseWindow.
Window original = window_util::MakeWindow({1, 2, 3});
// Terminate the macro explicitly instead of relying on its expansion ending
// in a semicolon.
TF_ASSERT_OK_AND_ASSIGN(Window parsed,
ParseWindow(window_util::ToString(original)));
EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed));
}
TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
// Convolution dimension labels must round-trip through text.
const string dim_labels = "b0f_0io->b0f";
TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
ParseConvolutionDimensionNumbers(dim_labels));
EXPECT_EQ(dim_labels, ConvolutionDimensionNumbersToString(dnums));
}
TEST_F(HloParserTest, ParseConvolutionDimensionNumbersWithUnknownDims) {
// Labels containing '?' must round-trip through text as well.
const string dim_labels = "b0?f_?0?io->?b?0?f";
TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
ParseConvolutionDimensionNumbers(dim_labels));
EXPECT_EQ(dim_labels, ConvolutionDimensionNumbersToString(dnums));
}
TEST_F(HloParserTest, ParseReplicaGroups) {
// Replica groups must round-trip through text.
const string groups_text = "{{0,1},{2,3}}";
TF_ASSERT_OK_AND_ASSIGN(std::vector<ReplicaGroup> replica_groups,
ParseReplicaGroupsOnly(groups_text));
EXPECT_EQ(groups_text, ReplicaGroupsToString(replica_groups));
}
TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) {
// low_high pairs per dimension, with no interior padding, round-trip as is.
const string padding_text = "0_1x2_3";
TF_ASSERT_OK_AND_ASSIGN(PaddingConfig padding_config,
ParsePaddingConfig(padding_text));
EXPECT_EQ(padding_text, PaddingConfigToString(padding_config));
}
TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) {
// low_high_interior triples per dimension round-trip as is.
const string padding_text = "0_1_0x2_3_4";
TF_ASSERT_OK_AND_ASSIGN(PaddingConfig padding_config,
ParsePaddingConfig(padding_text));
EXPECT_EQ(padding_text, PaddingConfigToString(padding_config));
}
TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) {
TF_ASSERT_OK_AND_ASSIGN(PaddingConfig padding_config,
ParsePaddingConfig("0_1x2_3_4"));
// The first dimension omits interior padding; because the other dimension has
// some, the canonical string spells it out as "_0".
EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(padding_config));
}
TEST_F(HloParserTest, NontupleInfeed) {
// An infeed whose result shape is not a tuple must be rejected.
const string hlo_text = R"(HloModule nontuple_infeed:
ENTRY nontuple_infeed {
token0 = token[] after-all()
ROOT infeed = pred[] infeed(token0)
})";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "infeed must have a non-empty tuple shape");
}
TEST(HloParserSingleOpTest, SingleOp) {
// A lone instruction is wrapped in an entry computation whose operands
// (%broadcast, %x) become parameters 0 and 1.
const string hlo_text =
"%multiply = f32[2,4]{1,0} "
"multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_text));
const HloComputation* entry = module->entry_computation();
ASSERT_NE(entry, nullptr);
EXPECT_THAT(entry->root_instruction(),
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
}
TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
// Without a `name = shape` prefix the text is not a valid instruction.
const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(text);
// ASSERT_FALSE reports the offending value directly, unlike ASSERT_TRUE(!x).
ASSERT_FALSE(module.status().ok());
LOG(INFO) << "Status: " << module.status();
EXPECT_THAT(module.status().ToString(),
::testing::HasSubstr("expects '=' in instruction"));
}
TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
// In single-op form every operand must carry its shape.
const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
StatusOr<std::unique_ptr<HloModule>> module =
ParseAndReturnUnverifiedModule(text);
// ASSERT_FALSE reports the offending value directly, unlike ASSERT_TRUE(!x).
ASSERT_FALSE(module.status().ok());
LOG(INFO) << "Status: " << module.status();
EXPECT_THAT(module.status().ToString(),
::testing::HasSubstr("Operand had no shape in HLO text"));
}
TEST(HloParserSingleOpTest, SingleOpNoNames) {
// Operands given only as shapes (no names) still become parameters 0 and 1.
const string hlo_text =
"%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_text));
const HloComputation* entry = module->entry_computation();
ASSERT_NE(entry, nullptr);
EXPECT_THAT(entry->root_instruction(),
GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
}
TEST(HloParserSingleOpTest, CanonicalOp) {
// An instruction in canonical form (no name) must parse and print back
// identically with canonical print options.
const string hlo_text =
"f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_text));
const HloComputation* entry = module->entry_computation();
ASSERT_NE(entry, nullptr);
const auto* root = entry->root_instruction();
EXPECT_THAT(root, GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
EXPECT_EQ(root->ToString(HloPrintOptions::Canonical()), hlo_text);
}
TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
// A canonical-form while with inlined condition/body computations (each with
// a nested fusion) must print back exactly as written.
const string hlo_text =
R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
}, body=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_text));
const HloComputation* entry = module->entry_computation();
ASSERT_NE(entry, nullptr);
const string printed =
entry->root_instruction()->ToString(HloPrintOptions::Canonical());
EXPECT_EQ(printed, hlo_text);
}
TEST(HloParserSingleOpTest, CanonicalOpIndexedConditionalInlinedBranches) {
// An indexed conditional with three inlined branch computations must print
// back exactly as written in canonical form.
const string hlo_text =
R"(f32[5,10]{1,0} conditional(s32[], f32[5,10]{1,0}, f32[5,10]{1,0}, f32[5,10]{1,0}), branch_computations={
{
tmp_0 = f32[5,10]{1,0} parameter(0)
ROOT tmp_1 = f32[5,10]{1,0} ceil(f32[5,10]{1,0} tmp_0)
},
{
tmp_0 = f32[5,10]{1,0} parameter(0)
ROOT tmp_1 = f32[5,10]{1,0} floor(f32[5,10]{1,0} tmp_0)
},
{
tmp_0 = f32[5,10]{1,0} parameter(0)
ROOT tmp_1 = f32[5,10]{1,0} copy(f32[5,10]{1,0} tmp_0)
}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_text));
const HloComputation* entry = module->entry_computation();
ASSERT_NE(entry, nullptr);
const string printed =
entry->root_instruction()->ToString(HloPrintOptions::Canonical());
EXPECT_EQ(printed, hlo_text);
}
TEST(HloParserSingleOpTest, SingleOpWithNested) {
// A single fusion with an inlined fused computation parses into a fusion
// whose two operands are entry parameters 0 and 1.
const string hlo_text =
R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=
{
%param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
%param_1 = f32[2]{0} parameter(1)
%broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1}
ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_text));
const HloComputation* entry = module->entry_computation();
ASSERT_NE(entry, nullptr);
const auto fusion_pattern = m::Op()
.WithOpcode(HloOpcode::kFusion)
.WithNumOperands(2)
.WithOperand(0, m::Parameter(0))
.WithOperand(1, m::Parameter(1));
EXPECT_THAT(entry->root_instruction(), GmockMatch(fusion_pattern));
}
TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
// The nested computation references `x` and `y`, which are never defined.
const string hlo_text =
R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
{
result = f32[] add(f32[] x, f32[] y)
})";
const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
::testing::HasSubstr("does not exist: x"));
}
TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
// Instructions inside a nested computation must be named (no `name =` here).
const string hlo_text =
R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
{
f32[] add(f32[] x, f32[] y)
})";
const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
}
TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
// Operands inside a nested computation must be named, not just shaped.
const string hlo_text =
R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
{
result = f32[] add(f32[], f32[])
})";
const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
}
TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
// With feature_group_count omitted, the parsed convolution reports 1.
const string hlo_text =
R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_text));
const HloComputation* entry = module->entry_computation();
ASSERT_NE(entry, nullptr);
auto* root = entry->root_instruction();
EXPECT_THAT(root,
GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
EXPECT_EQ(Cast<HloConvolutionInstruction>(root)->feature_group_count(), 1);
}
TEST(HloParserSingleOpTest, MultipleOpsProducesError) {
// Single-op mode accepts exactly one instruction; a second one is an error.
const string hlo_text = R"(
param = f32[2,5,1,3] parameter(0)
transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
)";
const Status status = ParseAndReturnUnverifiedModule(hlo_text).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(), ::testing::HasSubstr("Expected eof"));
}
TEST_F(HloParserTest, IsScheduledIsFalse) {
// An explicit is_scheduled=false leaves the module without a schedule.
const string hlo_text = R"(
HloModule axpy_module, is_scheduled=false
ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
%alpha = f32[] parameter(0)
%broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
%x = f32[2,4]{1,0} parameter(1)
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
%y = f32[2,4]{1,0} parameter(2)
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
ASSERT_FALSE(module->has_schedule());
}
TEST_F(HloParserTest, IsScheduledNotPresent) {
// Without an is_scheduled attribute the module has no schedule.
const string hlo_text = R"(
HloModule axpy_module
ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
%alpha = f32[] parameter(0)
%broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
%x = f32[2,4]{1,0} parameter(1)
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
%y = f32[2,4]{1,0} parameter(2)
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
ASSERT_FALSE(module->has_schedule());
}
TEST_F(HloParserTest, IsScheduledIsTrue) {
// With is_scheduled=true the entry computation is scheduled in the order its
// instructions appear in the text.
const string hlo_text = R"(
HloModule axpy_module, is_scheduled=true
ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
%alpha = f32[] parameter(0)
%broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
%x = f32[2,4]{1,0} parameter(1)
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
%y = f32[2,4]{1,0} parameter(2)
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
ASSERT_TRUE(module->has_schedule());
auto& schedule = module->schedule();
TF_ASSERT_OK(schedule.Verify());
EXPECT_EQ(schedule.sequences().size(), 1);
auto* entry = module->entry_computation();
ASSERT_TRUE(schedule.is_computation_scheduled(entry));
EXPECT_THAT(schedule.sequence(entry).instructions(),
::testing::ElementsAre(
GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
GmockMatch(m::Parameter()), GmockMatch(m::Multiply()),
GmockMatch(m::Parameter()), GmockMatch(m::Add())));
}
TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
// Same computation as IsScheduledIsTrue with the instructions reordered; the
// schedule must follow this new textual order.
const string hlo_text = R"(
HloModule axpy_module, is_scheduled=true
ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
%alpha = f32[] parameter(0)
%x = f32[2,4]{1,0} parameter(1)
%y = f32[2,4]{1,0} parameter(2)
%broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
ASSERT_TRUE(module->has_schedule());
auto& schedule = module->schedule();
TF_ASSERT_OK(schedule.Verify());
EXPECT_EQ(schedule.sequences().size(), 1);
auto* entry = module->entry_computation();
ASSERT_TRUE(schedule.is_computation_scheduled(entry));
EXPECT_THAT(schedule.sequence(entry).instructions(),
::testing::ElementsAre(
GmockMatch(m::Parameter()), GmockMatch(m::Parameter()),
GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
GmockMatch(m::Multiply()), GmockMatch(m::Add())));
}
TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
// The custom call has two operands but only one layout constraint.
const string hlo_text = R"(HloModule CustomCallWrongNumberofOperandConstraints
ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
%p0 = f32[42,2,3]{0,1,2} parameter(0)
%p1 = f32[123,4]{0,1} parameter(1)
ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}}
}
)";
const string error =
ParseAndReturnUnverifiedModule(hlo_text).status().error_message();
ExpectHasSubstr(error, "Expected 2 operand layout constraints, 1 given");
}
// The second operand layout constraint (f32[555,5]) does not match the shape
// of operand 1 (f32[123,4]), so parsing must fail.
TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) {
  const string hlo = R"(HloModule CustomCallIncompatibleOperandConstraints
ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
%p0 = f32[42,2,3]{0,1,2} parameter(0)
%p1 = f32[123,4]{0,1} parameter(1)
ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}}
}
)";
  const auto parsed = ParseAndReturnUnverifiedModule(hlo);
  ExpectHasSubstr(parsed.status().error_message(),
                  "operand 1 is not compatible with operand shape");
}
// Arbitrary whitespace inside the dimension list and the layout braces of a
// shape must be accepted.
TEST_F(HloParserTest, AllowShapeWhitespace) {
  const string hlo = R"(
HloModule module
ENTRY entry {
ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0)
}
)";
  // Only successful parsing (and verification) matters; the module is unused.
  TF_ASSERT_OK(ParseAndReturnVerifiedModule(hlo).status());
}
// The operand shape written inline (f32[2,5]) disagrees with the shape of the
// referenced instruction (f32[2,2]); the parser must report the mismatch.
TEST_F(HloParserTest, ShapeMismatchInOperand) {
  const string hlo = R"(
HloModule foobar
ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
%p = f32[2,2] parameter(0)
%constant.1 = f32[2,2] constant({{1, 2}, {3, 4}})
ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1)
}
)";
  const string error =
      ParseAndReturnUnverifiedModule(hlo).status().error_message();
  ExpectHasSubstr(error,
                  "The declared operand shape f32[2,5]{1,0} is not compatible"
                  " with the shape of the operand instruction f32[2,2]{1,0}.");
}
// A rank-2 F32 array shape without an explicit layout parses to the same shape
// that ShapeUtil::MakeShape builds.
TEST_F(HloParserTest, ParseShapeStringR2F32) {
  string shape_string = "f32[123,456]";
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
  Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
  // The ", " separator keeps the two shapes readable in the failure message.
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << ", actual: " << ShapeUtil::HumanString(actual);
}
// A tuple of two array shapes (no whitespace between elements) parses to the
// equivalent ShapeUtil::MakeTupleShape.
TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) {
  string shape_string = "(f32[1572864],s8[5120,1024])";
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
  Shape expected =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
                                 ShapeUtil::MakeShape(S8, {5120, 1024})});
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << ", actual: " << ShapeUtil::HumanString(actual);
}
// Nested tuples mixing array, token and opaque element types parse to the
// equivalent nested ShapeUtil tuple shape.
TEST_F(HloParserTest, ParseShapeStringNestedTuple) {
  string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
  Shape expected = ShapeUtil::MakeTupleShape({
      ShapeUtil::MakeShape(F32, {1}),
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
      ShapeUtil::MakeOpaqueShape(),
      ShapeUtil::MakeShape(F32, {3}),
  });
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << ", actual: " << ShapeUtil::HumanString(actual);
}
// An explicit minor-to-major layout in braces is carried into the parsed
// shape.
TEST_F(HloParserTest, ParseShapeStringWithLayout) {
  string shape_string = "f32[123,456]{0,1}";
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
  Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << ", actual: " << ShapeUtil::HumanString(actual);
}
// Expects "f32[123,456]invalid{}" to parse to a shape compatible (dimensions
// and element type; layout not compared) with f32[123,456].
// NOTE(review): ParseInvalidShapeString expects the structurally similar
// "f32[123,456]foobar{0,1}" to FAIL to parse; confirm which behavior the
// parser is meant to have for a word between the dimensions and the layout.
TEST_F(HloParserTest, ParseShapeStringWithInvalidLayout) {
  string shape_string = "f32[123,456]invalid{}";
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
  Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
  ASSERT_TRUE(ShapeUtil::Compatible(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << ", actual: " << ShapeUtil::HumanString(actual);
}
// Covers the layout annotations after ':' inside a layout: tiles (T(...)),
// element size in bits (E(...)), and the check that minor_to_major lists as
// many dimensions as the shape has.
TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) {
  // One tile.
  string shape_string = "f32[123,456]{0,1:T(2,128)}";
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
  Shape expected =
      ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}, {Tile({2, 128})});
  EXPECT_EQ(expected, actual)
      << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
      << ", actual: " << ShapeUtil::HumanStringWithLayout(actual);
  // A '*' tile entry parses to Tile::kCombineDimension (a negative dimension
  // size used for combining dimensions); whitespace inside T(...) is accepted.
  shape_string = "f32[123,456,789]{0,1,2:T(2, * , 128)}";
  TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
  expected =
      ShapeUtil::MakeShapeWithLayout(F32, {123, 456, 789}, {0, 1, 2},
                                     {Tile({2, Tile::kCombineDimension, 128})});
  EXPECT_EQ(expected, actual)
      << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
      << ", actual: " << ShapeUtil::HumanStringWithLayout(actual);
  // Two tiles.
  shape_string = "bf16[123,456,789]{2,1,0:T(2,*,128)(2,1)}";
  TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
  expected = ShapeUtil::MakeShapeWithLayout(
      BF16, {123, 456, 789}, {2, 1, 0},
      {Tile({2, Tile::kCombineDimension, 128}), Tile({2, 1})});
  EXPECT_EQ(expected, actual)
      << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
      << ", actual: " << ShapeUtil::HumanStringWithLayout(actual);
  // Tile with element size in bits.
  shape_string = "pred[123,456]{1,0:T(2,128)E(1)}";
  TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
  expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0},
                                            {Tile({2, 128})}, 1);
  EXPECT_EQ(expected, actual)
      << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
      << ", actual: " << ShapeUtil::HumanStringWithLayout(actual);
  // Element size in bits without tile.
  shape_string = "pred[123,456]{1,0:E(1)}";
  TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
  expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1);
  EXPECT_EQ(expected, actual)
      << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
      << ", actual: " << ShapeUtil::HumanStringWithLayout(actual);
  // Wrong minor_to_major: a rank-3 shape with a single-entry layout must fail.
  shape_string = "f32[123,456,789]{1:T(2, * , 128)}";
  auto result = ParseShape(shape_string);
  // Assert failure first so an unexpected success is reported as such rather
  // than as a confusing substring mismatch on an empty error message.
  ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
  ExpectHasSubstr(result.status().error_message(),
                  "Dimensions size is 3, but minor to major size is 1.");
}
// The S(...) layout annotation sets the memory space; it may be combined with
// tiles and element size or appear on its own.
TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) {
  // Tile, element size, and memory space.
  string shape_string = "pred[123,456]{1,0:T(2,128)E(1)S(3)}";
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
  Shape expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0},
                                                  {Tile({2, 128})}, 1, 3);
  EXPECT_EQ(expected, actual)
      << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
      << ", actual: " << ShapeUtil::HumanStringWithLayout(actual);
  // Element size and memory space.
  shape_string = "pred[123,456]{1,0:E(1)S(3)}";
  TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
  expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1, 3);
  EXPECT_EQ(expected, actual)
      << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
      << ", actual: " << ShapeUtil::HumanStringWithLayout(actual);
  // Memory space only; element size is then expected to be 0.
  shape_string = "pred[123,456]{1,0:S(3)}";
  TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
  expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 0, 3);
  EXPECT_EQ(expected, actual)
      << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
      << ", actual: " << ShapeUtil::HumanStringWithLayout(actual);
}
// "opaque[]" parses to the opaque shape.
TEST_F(HloParserTest, ParseOpaqueType) {
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]"));
  Shape expected = ShapeUtil::MakeOpaqueShape();
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << ", actual: " << ShapeUtil::HumanString(actual);
}
// "token[]" parses to the token shape.
TEST_F(HloParserTest, ParseTokenType) {
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]"));
  Shape expected = ShapeUtil::MakeTokenShape();
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << ", actual: " << ShapeUtil::HumanString(actual);
}
// Each malformed shape string below must be rejected by ParseShape.
TEST_F(HloParserTest, ParseInvalidShapeString) {
  const char* const kInvalidShapes[] = {"f32[123,456]foobar{0,1}",
                                        "f32[123,456]{foo}",
                                        "f32[123,456]dense{foo}"};
  for (const char* shape_string : kInvalidShapes) {
    const StatusOr<Shape> parsed = ParseShape(shape_string);
    ASSERT_FALSE(parsed.ok()) << "shape: " << shape_string;
  }
}
// A "<=N" dimension parses to size N with its dynamic_dimensions flag set;
// plain dimensions stay static.
TEST_F(HloParserTest, ParseDynamicArray) {
  string shape_string = "f32[123,<=456]";
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
  Shape expected = ShapeUtil::MakeShape(F32, {123, 456}, {false, true});
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << ", actual: " << ShapeUtil::HumanString(actual);
}
// Dynamic ("<=N") dimensions are also honored for array elements of a tuple.
TEST_F(HloParserTest, ParseDynamicTuple) {
  string shape_string = "(f32[42], u32[<=123,<=456])";
  TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
  Shape expected = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {42}),
       ShapeUtil::MakeShape(U32, {123, 456}, {true, true})});
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
      << "expected: " << ShapeUtil::HumanString(expected)
      << ", actual: " << ShapeUtil::HumanString(actual);
}
// A single parameter instruction with a negative parameter number is rejected.
TEST_F(HloParserTest, NegativeParameterNumber) {
  const string hlo_string = "par0 = f32[3,5] parameter(-1)";
  const Status status = ParseAndReturnUnverifiedModule(hlo_string).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("parameter number must be >= 0"));
}
// parameter_replication must list exactly one entry per leaf buffer of the
// parameter's shape: the tuple below has 2 leaves but 3 entries are given.
TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) {
  const string hlo_string =
      "par0 = (f32[3,5], f32[]) parameter(0), "
      "parameter_replication={true,false,true}";
  const Status status = ParseAndReturnUnverifiedModule(hlo_string).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("parameter has 2 leaf buffers, but "
                                   "parameter_replication has 3 elements"));
}
// An indexed conditional whose branch index operand is s32[2] (not a scalar)
// must be rejected.
TEST_F(HloParserTest, CheckIndexedConditionalDimension) {
  const char* const hlo_string = R"(
HloModule Module
branch0 {
tparam = f32[4] parameter(0)
ROOT tgte1 = f32[4] ceil(tparam)
}
branch1 {
fparam = f32[4] parameter(0)
ROOT fgte1 = f32[4] floor(fparam)
}
ENTRY entry {
p0 = f32[4] parameter(0)
b0 = s32[2] parameter(1)
ROOT conditional = f32[4] conditional(b0, p0, p0),
branch_computations={branch0, branch1}
}
)";
  const auto parsed = ParseAndReturnUnverifiedModule(hlo_string);
  const Status& status = parsed.status();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr("The first operand must be a scalar"));
}
// An indexed conditional whose branch selector is an f32 scalar must be
// rejected: only PRED or S32 selectors are allowed.
TEST_F(HloParserTest, CheckIndexedConditionalElementType) {
  const char* const hlo_string = R"(
HloModule Module
branch0 {
tparam = f32[4] parameter(0)
ROOT tgte1 = f32[4] ceil(tparam)
}
branch1 {
fparam = f32[4] parameter(0)
ROOT fgte1 = f32[4] floor(fparam)
}
ENTRY entry {
p0 = f32[4] parameter(0)
b0 = f32[] parameter(1)
ROOT conditional = f32[4] conditional(b0, p0, p0),
branch_computations={branch0, branch1}
}
)";
  const auto parsed = ParseAndReturnUnverifiedModule(hlo_string);
  const Status& status = parsed.status();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              ::testing::HasSubstr(
                  "The first operand must be a scalar of PRED or S32"));
}
// With a PRED selector, the conditional must use true/false computations;
// supplying branch_computations instead is reported as an unexpected
// attribute.
TEST_F(HloParserTest,
       CheckPredicatedConditionalRequiresTrueAndFalseComputation) {
  const char* const hlo_string = R"(
HloModule Module
branch0 {
tparam = f32[4] parameter(0)
ROOT tgte1 = f32[4] ceil(tparam)
}
branch1 {
fparam = f32[4] parameter(0)
ROOT fgte1 = f32[4] floor(fparam)
}
ENTRY entry {
p0 = f32[4] parameter(0)
b0 = pred[] parameter(1)
ROOT conditional = f32[4] conditional(b0, p0, p0),
branch_computations={branch0, branch1}
}
)";
  const auto parsed = ParseAndReturnUnverifiedModule(hlo_string);
  const Status& status = parsed.status();
  EXPECT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      ::testing::HasSubstr("unexpected attribute \"branch_computations\""));
}
// Result shape inference test cases.
// The result shape of `abs` is omitted in the text and must be inferred from
// its operand.
TEST_F(HloParserTest, InferUnaryShape) {
  constexpr char text[] = R"(HloModule InferUnaryShapeTest
ENTRY InferUnaryShape {
a = f32[2,10]{1,0} parameter(0)
ROOT v = abs(a)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
  // Pin the inferred result shape, as the other Infer*Shape tests do.
  EXPECT_TRUE(ShapeUtil::Equal(
      module->entry_computation()->ComputeProgramShape().result(),
      ShapeUtil::MakeShapeWithLayout(F32, {2, 10}, {1, 0})));
}
// The result shape of `add` is omitted and must be inferred from its two
// f32[2,10]{1,0} operands.
TEST_F(HloParserTest, InferBinaryShape) {
  constexpr char text[] = R"(HloModule InferBinaryShapeTest
ENTRY InferBinaryShape {
a = f32[2,10]{1,0} parameter(0)
b = f32[2,10]{1,0} parameter(1)
ROOT sum = add(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
  const ProgramShape program_shape =
      module->entry_computation()->ComputeProgramShape();
  const Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {2, 10}, {1, 0});
  EXPECT_TRUE(ShapeUtil::Equal(program_shape.result(), expected));
}
// The result shape of `select` is omitted and must be inferred as an s32
// scalar from its pred/s32/s32 scalar operands.
TEST_F(HloParserTest, InferTernaryShape) {
  constexpr char text[] = R"(HloModule InferTernaryShapeTest
ENTRY InferTernaryShape {
p = pred[] constant(true)
f = s32[] constant(-42)
t = s32[] constant(42)
ROOT select = select(p, f, t)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
  const ProgramShape program_shape =
      module->entry_computation()->ComputeProgramShape();
  const Shape expected = ShapeUtil::MakeScalarShape(S32);
  EXPECT_TRUE(ShapeUtil::Equal(program_shape.result(), expected));
}
// The result shape of `dot` is omitted and must be inferred from the operands
// and the batch/contracting dimension attributes: batching dim 0 of `a` with
// dim 1 of `b` and contracting the remaining dims yields f32[2].
TEST_F(HloParserTest, InferDotShape) {
  constexpr char text[] = R"(HloModule InferDotShapeTest
ENTRY InferDotShape {
a = f32[2,10]{1,0} parameter(0)
b = f32[10,2]{1,0} parameter(1)
ROOT dot = dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
  // Use MakeShapeWithLayout so {0} is explicitly the minor-to-major layout;
  // the third argument of MakeShape is the dynamic_dimensions vector<bool>,
  // where `{0}` only coincidentally meant "static".
  EXPECT_TRUE(ShapeUtil::Equal(
      module->entry_computation()->ComputeProgramShape().result(),
      ShapeUtil::MakeShapeWithLayout(F32, {2}, {0})));
}
// Both the `tuple` and the `get-tuple-element` result shapes are omitted; the
// root's inferred shape must be element 1 of the tuple, s32[2,3]{1,0}.
TEST_F(HloParserTest, InferTupleShape) {
  constexpr char text[] = R"(HloModule InferTupleShapeTest
ENTRY InferTupleShape () -> s32[2,3] {
c0 = f32[3]{0} constant({1, 2, 3})
c1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
tuple = tuple(c0, c1)
ROOT get = get-tuple-element(tuple), index=1, sharding={maximal device=0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
  const ProgramShape program_shape =
      module->entry_computation()->ComputeProgramShape();
  const Shape expected = ShapeUtil::MakeShapeWithLayout(S32, {2, 3}, {1, 0});
  EXPECT_TRUE(ShapeUtil::Equal(program_shape.result(), expected));
}
// Mixes instructions with omitted result shapes (negate, copy, conditional)
// and an explicit one (`c = f32[] add`); the conditional's result must be
// inferred as an f32 scalar.
TEST_F(HloParserTest, InferShapeMixedExplicitShape) {
  // Module/entry names match this test (they were copy-pasted from
  // InferUnaryShape); the names do not affect parsing.
  constexpr char text[] = R"(HloModule InferShapeMixedExplicitShapeTest
Negate {
x = f32[] parameter(0)
ROOT negate = negate(x)
}
Identity {
y = f32[] parameter(0)
ROOT copy = copy(y)
}
ENTRY InferShapeMixedExplicitShape {
a = f32[] parameter(0)
b = f32[] parameter(1)
p = pred[] parameter(2)
c = f32[] add(a, b)
ROOT conditional = conditional(p, a, c), true_computation=Negate, false_computation=Identity
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
  EXPECT_TRUE(ShapeUtil::Equal(
      module->entry_computation()->ComputeProgramShape().result(),
      ShapeUtil::MakeScalarShape(F32)));
}
} // namespace
} // namespace xla