| // RUN: hlo_to_llvm_ir %s | FileCheck %s |
| // This tests that we do not increase the grid launch size when |
| // few_waves is enabled. |
| |
| // CHECK-LABEL: entry: |
| // CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] |
| // CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] |
| // CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 2} |
| // CHECK-DAG: ![[tid_range]] = !{i32 0, i32 256} |
| |
| HloModule Test |
| |
| ENTRY main { |
| // Small elementwise op (100x20 = 2000 elements): the CHECK lines above pin |
| // the launch at 2 blocks of 256 threads, i.e. few_waves did not grow it. |
| a = f32[100, 20]{1,0} parameter(0) |
| ROOT b = f32[100, 20]{1,0} round-nearest-even(a) |
| } |
| |
| // ----- |
| |
| // This tests that we cap the grid launch size when few_waves is enabled. |
| |
| // CHECK-LABEL: entry: |
| // CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] |
| // CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] |
| // CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 1280} |
| // CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} |
| |
| HloModule Test |
| |
| ENTRY main { |
| // Large elementwise op (10000x10000 elements): the CHECK lines above pin |
| // the launch at the few_waves cap of 1280 blocks of 128 threads. |
| a = f32[10000, 10000]{1,0} parameter(0) |
| ROOT b = f32[10000, 10000]{1,0} round-nearest-even(a) |
| } |
| |
| // ----- |
| |
| // This tests that we cap the grid launch size when few_waves is enabled |
| // and scalar broadcasts are present. |
| |
| // CHECK-LABEL: entry: |
| // CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] |
| // CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] |
| // CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 1280} |
| // CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} |
| |
| HloModule ScalarBroadcast |
| |
| // A scalar constant broadcast into a large elementwise add, wrapped in a |
| // kLoop fusion. The CHECK lines above pin the launch at the few_waves cap |
| // (1280 blocks of 128 threads), so the broadcast must not disable the cap. |
| %fused_computation.3 (param_0: f32[], param_1: f32[10000, 10000]) -> f32[10000, 10000] { |
| %param_0 = f32[] parameter(0) |
| %broadcast = f32[10000, 10000]{1,0} broadcast(%param_0), dimensions={} |
| %param_1 = f32[10000, 10000]{1,0} parameter(1) |
| ROOT %add = f32[10000, 10000]{1,0} add(%broadcast, %param_1) |
| } |
| |
| ENTRY main { |
| %param_0 = f32[] parameter(0) |
| %param_1 = f32[10000, 10000]{1,0} parameter(1) |
| |
| ROOT %fusion.3 = f32[10000, 10000]{1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.3 |
| } |
| |
| |
| // ----- |
| // This tests the GELU kernel, the original kernel that |
| // motivated the few_waves implementation. |
| |
| // CHECK-LABEL: entry: |
| // CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] |
| // CHECK-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] |
| // CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 1280} |
| // CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} |
| |
| HloModule Test |
| |
| // Tanh-approximation GELU as a single kLoop fusion: |
| // out = x * 0.5 * (1 + tanh(0.79785 * (x + 0.044708 * x^3))) |
| // (x^3 is built from the two multiplies %power.tmp.1 and %power.0.) |
| // NOTE(review): 0.044708 differs slightly from the usual GELU coefficient |
| // 0.044715 — presumably intentional for this test; confirm if editing. |
| %fused_computation (param_0: f16[6,512,4096]) -> f16[6,512,4096] { |
| %param_0 = f16[6,512,4096]{2,1,0} parameter(0) |
| %power.tmp.1 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %param_0) |
| %power.0 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %power.tmp.1, f16[6,512,4096]{2,1,0} %param_0) |
| %constant_4 = f16[] constant(0.044708), metadata={op_type="Mul" op_name="mul"} |
| %broadcast.3 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_4), dimensions={}, metadata={op_type="Mul" op_name="mul"} |
| %multiply.3 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %power.0, f16[6,512,4096]{2,1,0} %broadcast.3), metadata={op_type="Mul" op_name="mul"} |
| %add.1 = f16[6,512,4096]{2,1,0} add(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %multiply.3), metadata={op_type="AddV2" op_name="add"} |
| %constant_2 = f16[] constant(0.79785), metadata={op_type="Mul" op_name="mul_1"} |
| %broadcast.2 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_2), dimensions={}, metadata={op_type="Mul" op_name="mul_1"} |
| %multiply.2 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %add.1, f16[6,512,4096]{2,1,0} %broadcast.2), metadata={op_type="Mul" op_name="mul_1"} |
| %tanh.0 = f16[6,512,4096]{2,1,0} tanh(f16[6,512,4096]{2,1,0} %multiply.2), metadata={op_type="Tanh" op_name="Tanh"} |
| %constant_1 = f16[] constant(1), metadata={op_type="AddV2" op_name="add_1"} |
| %broadcast.1 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_1), dimensions={}, metadata={op_type="AddV2" op_name="add_1"} |
| %add.0 = f16[6,512,4096]{2,1,0} add(f16[6,512,4096]{2,1,0} %tanh.0, f16[6,512,4096]{2,1,0} %broadcast.1), metadata={op_type="AddV2" op_name="add_1"} |
| %constant_0 = f16[] constant(0.5), metadata={op_type="Mul" op_name="mul_2"} |
| %broadcast.0 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_0), dimensions={}, metadata={op_type="Mul" op_name="mul_2"} |
| %multiply.1 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %add.0, f16[6,512,4096]{2,1,0} %broadcast.0), metadata={op_type="Mul" op_name="mul_2"} |
| ROOT %multiply.0 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %multiply.1), metadata={op_type="Mul" op_name="mul_3"} |
| } |
| |
| // The CHECK lines above pin this fusion's launch at the few_waves cap of |
| // 1280 blocks of 128 threads. |
| ENTRY %cluster_0__XlaCompiledKernel_true__XlaNumConstantArgs_0__XlaNumResourceArgs_0_.24 (arg0.1: f16[6,512,4096]) -> f16[6,512,4096] { |
| %arg0.1 = f16[6,512,4096]{2,1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} |
| ROOT %fusion = f16[6,512,4096]{2,1,0} fusion(f16[6,512,4096]{2,1,0} %arg0.1), kind=kLoop, calls=%fused_computation, metadata={op_type="Mul" op_name="mul_3"} |
| } |