// RUN: hlo_to_llvm_ir %s | FileCheck %s
// This tests that we do not increase the grid launch size when
// few_waves is enabled.
// The !range metadata on the NVVM block/thread-index intrinsics bounds
// ctaid.x to [0, 2) and tid.x to [0, 256), i.e. a 2-block x 256-thread
// launch for this small f32[100, 20] (2000-element) elementwise op.
// CHECK-LABEL: entry:
// CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]]
// CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]]
// CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 2}
// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 256}
HloModule Test
ENTRY main {
a = f32[100, 20]{1,0} parameter(0)
ROOT b = f32[100, 20]{1,0} round-nearest-even(a)
}
// -----
// This tests that we cap the grid launch size when few_waves is enabled.
// Despite the f32[10000, 10000] (10^8-element) input, the !range metadata
// below shows the launch is capped to 1280 blocks of 128 threads.
// CHECK-LABEL: entry:
// CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]]
// CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]]
// CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 1280}
// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128}
HloModule Test
ENTRY main {
a = f32[10000, 10000]{1,0} parameter(0)
ROOT b = f32[10000, 10000]{1,0} round-nearest-even(a)
}
// -----
// This tests that we cap the grid launch size when few_waves is enabled
// and scalar broadcasts are present.
// The kLoop fusion adds a broadcast scalar (param_0) to a large tensor;
// the !range metadata below shows the launch stays capped at
// 1280 blocks x 128 threads, same as the broadcast-free case above.
// CHECK-LABEL: entry:
// CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]]
// CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]]
// CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 1280}
// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128}
HloModule ScalarBroadcast
%fused_computation.3 (param_0: f32[], param_1: f32[10000, 10000]) -> f32[10000, 10000] {
%param_0 = f32[] parameter(0)
%broadcast = f32[10000, 10000]{1,0} broadcast(%param_0), dimensions={}
%param_1 = f32[10000, 10000]{1,0} parameter(1)
ROOT %add = f32[10000, 10000]{1,0} add(%broadcast, %param_1)
}
ENTRY main {
%param_0 = f32[] parameter(0)
%param_1 = f32[10000, 10000]{1,0} parameter(1)
ROOT %fusion.3 = f32[10000, 10000]{1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.3
}
// -----
// This tests the GELU kernel. The original kernel that
// motivated the few_waves implementation.
// The fusion is the tanh-approximation GELU over f16[6,512,4096]
// (the constants are the f16-rounded GELU coefficients 0.044715 and
// sqrt(2/pi) ~= 0.7978845); the !range metadata below shows the launch
// is capped to 1280 blocks x 128 threads.
// CHECK-LABEL: entry:
// CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]]
// CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]]
// CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 1280}
// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128}
HloModule Test
%fused_computation (param_0: f16[6,512,4096]) -> f16[6,512,4096] {
%param_0 = f16[6,512,4096]{2,1,0} parameter(0)
// x^3, built as (x*x)*x.
%power.tmp.1 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %param_0)
%power.0 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %power.tmp.1, f16[6,512,4096]{2,1,0} %param_0)
%constant_4 = f16[] constant(0.044708), metadata={op_type="Mul" op_name="mul"}
%broadcast.3 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_4), dimensions={}, metadata={op_type="Mul" op_name="mul"}
%multiply.3 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %power.0, f16[6,512,4096]{2,1,0} %broadcast.3), metadata={op_type="Mul" op_name="mul"}
// x + 0.044708 * x^3
%add.1 = f16[6,512,4096]{2,1,0} add(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %multiply.3), metadata={op_type="AddV2" op_name="add"}
%constant_2 = f16[] constant(0.79785), metadata={op_type="Mul" op_name="mul_1"}
%broadcast.2 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_2), dimensions={}, metadata={op_type="Mul" op_name="mul_1"}
%multiply.2 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %add.1, f16[6,512,4096]{2,1,0} %broadcast.2), metadata={op_type="Mul" op_name="mul_1"}
// tanh(0.79785 * (x + 0.044708 * x^3))
%tanh.0 = f16[6,512,4096]{2,1,0} tanh(f16[6,512,4096]{2,1,0} %multiply.2), metadata={op_type="Tanh" op_name="Tanh"}
%constant_1 = f16[] constant(1), metadata={op_type="AddV2" op_name="add_1"}
%broadcast.1 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_1), dimensions={}, metadata={op_type="AddV2" op_name="add_1"}
%add.0 = f16[6,512,4096]{2,1,0} add(f16[6,512,4096]{2,1,0} %tanh.0, f16[6,512,4096]{2,1,0} %broadcast.1), metadata={op_type="AddV2" op_name="add_1"}
%constant_0 = f16[] constant(0.5), metadata={op_type="Mul" op_name="mul_2"}
%broadcast.0 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_0), dimensions={}, metadata={op_type="Mul" op_name="mul_2"}
%multiply.1 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %add.0, f16[6,512,4096]{2,1,0} %broadcast.0), metadata={op_type="Mul" op_name="mul_2"}
// x * 0.5 * (1 + tanh(...)), the full GELU result.
ROOT %multiply.0 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %multiply.1), metadata={op_type="Mul" op_name="mul_3"}
}
ENTRY %cluster_0__XlaCompiledKernel_true__XlaNumConstantArgs_0__XlaNumResourceArgs_0_.24 (arg0.1: f16[6,512,4096]) -> f16[6,512,4096] {
%arg0.1 = f16[6,512,4096]{2,1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
ROOT %fusion = f16[6,512,4096]{2,1,0} fusion(f16[6,512,4096]{2,1,0} %arg0.1), kind=kLoop, calls=%fused_computation, metadata={op_type="Mul" op_name="mul_3"}
}