spirv-val: Add SPV_KHR_ray_tracing instructions (#4871)

diff --git a/Android.mk b/Android.mk
index cd1d7f8..c32732d 100644
--- a/Android.mk
+++ b/Android.mk
@@ -69,6 +69,7 @@
 		source/val/validate_non_uniform.cpp \
 		source/val/validate_primitives.cpp \
 		source/val/validate_ray_query.cpp \
+		source/val/validate_ray_tracing.cpp \
 		source/val/validate_scopes.cpp \
 		source/val/validate_small_type_uses.cpp \
 		source/val/validate_type.cpp
diff --git a/BUILD.gn b/BUILD.gn
index 9e9f6e5..ac75cba 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -531,6 +531,7 @@
     "source/val/validate_non_uniform.cpp",
     "source/val/validate_primitives.cpp",
     "source/val/validate_ray_query.cpp",
+    "source/val/validate_ray_tracing.cpp",
     "source/val/validate_scopes.cpp",
     "source/val/validate_scopes.h",
     "source/val/validate_small_type_uses.cpp",
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index 1ceb78f..ab4578b 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -323,6 +323,7 @@
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_non_uniform.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_primitives.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ray_query.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ray_tracing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_scopes.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_small_type_uses.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index 55e9fd2..9a685f2 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -351,6 +351,7 @@
 
     if (auto error = LiteralsPass(*vstate, &instruction)) return error;
     if (auto error = RayQueryPass(*vstate, &instruction)) return error;
+    if (auto error = RayTracingPass(*vstate, &instruction)) return error;
   }
 
   // Validate the preconditions involving adjacent instructions. e.g. SpvOpPhi
diff --git a/source/val/validate.h b/source/val/validate.h
index 97d4683..85c32d3 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -200,6 +200,9 @@
 /// Validates correctness of ray query instructions.
 spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst);
 
+/// Validates correctness of ray tracing instructions.
+spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst);
+
 /// Calculates the reachability of basic blocks.
 void ReachabilityPass(ValidationState_t& _);
 
diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp
index 0220fcd..6c341cb 100644
--- a/source/val/validate_cfg.cpp
+++ b/source/val/validate_cfg.cpp
@@ -1081,12 +1081,12 @@
       if (opcode == SpvOpIgnoreIntersectionKHR) {
         _.current_function().RegisterExecutionModelLimitation(
             SpvExecutionModelAnyHitKHR,
-            "OpIgnoreIntersectionKHR requires AnyHit execution model");
+            "OpIgnoreIntersectionKHR requires AnyHitKHR execution model");
       }
       if (opcode == SpvOpTerminateRayKHR) {
         _.current_function().RegisterExecutionModelLimitation(
             SpvExecutionModelAnyHitKHR,
-            "OpTerminateRayKHR requires AnyHit execution model");
+            "OpTerminateRayKHR requires AnyHitKHR execution model");
       }
 
       break;
diff --git a/source/val/validate_ray_tracing.cpp b/source/val/validate_ray_tracing.cpp
new file mode 100644
index 0000000..81fa593
--- /dev/null
+++ b/source/val/validate_ray_tracing.cpp
@@ -0,0 +1,199 @@
+// Copyright (c) 2022 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Validates ray tracing instructions from SPV_KHR_ray_tracing
+
+#include "source/opcode.h"
+#include "source/val/instruction.h"
+#include "source/val/validate.h"
+#include "source/val/validation_state.h"
+
+namespace spvtools {
+namespace val {
+
+spv_result_t RayTracingPass(ValidationState_t& _, const Instruction* inst) {
+  const SpvOp opcode = inst->opcode();
+  const uint32_t result_type = inst->type_id();
+
+  switch (opcode) {
+    case SpvOpTraceRayKHR: {
+      _.function(inst->function()->id())
+          ->RegisterExecutionModelLimitation(
+              [](SpvExecutionModel model, std::string* message) {
+                if (model != SpvExecutionModelRayGenerationKHR &&
+                    model != SpvExecutionModelClosestHitKHR &&
+                    model != SpvExecutionModelMissKHR) {
+                  if (message) {
+                    *message =
+                        "OpTraceRayKHR requires RayGenerationKHR, "
+                        "ClosestHitKHR and MissKHR execution models";
+                  }
+                  return false;
+                }
+                return true;
+              });
+
+      if (_.GetIdOpcode(_.GetOperandTypeId(inst, 0)) !=
+          SpvOpTypeAccelerationStructureKHR) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Expected Acceleration Structure to be of type "
+                  "OpTypeAccelerationStructureKHR";
+      }
+
+      const uint32_t ray_flags = _.GetOperandTypeId(inst, 1);
+      if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Ray Flags must be a 32-bit int scalar";
+      }
+
+      const uint32_t cull_mask = _.GetOperandTypeId(inst, 2);
+      if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Cull Mask must be a 32-bit int scalar";
+      }
+
+      const uint32_t sbt_offset = _.GetOperandTypeId(inst, 3);
+      if (!_.IsIntScalarType(sbt_offset) || _.GetBitWidth(sbt_offset) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "SBT Offset must be a 32-bit int scalar";
+      }
+
+      const uint32_t sbt_stride = _.GetOperandTypeId(inst, 4);
+      if (!_.IsIntScalarType(sbt_stride) || _.GetBitWidth(sbt_stride) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "SBT Stride must be a 32-bit int scalar";
+      }
+
+      const uint32_t miss_index = _.GetOperandTypeId(inst, 5);
+      if (!_.IsIntScalarType(miss_index) || _.GetBitWidth(miss_index) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Miss Index must be a 32-bit int scalar";
+      }
+
+      const uint32_t ray_origin = _.GetOperandTypeId(inst, 6);
+      if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
+          _.GetBitWidth(ray_origin) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Ray Origin must be a 32-bit float 3-component vector";
+      }
+
+      const uint32_t ray_tmin = _.GetOperandTypeId(inst, 7);
+      if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Ray TMin must be a 32-bit float scalar";
+      }
+
+      const uint32_t ray_direction = _.GetOperandTypeId(inst, 8);
+      if (!_.IsFloatVectorType(ray_direction) ||
+          _.GetDimension(ray_direction) != 3 ||
+          _.GetBitWidth(ray_direction) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Ray Direction must be a 32-bit float 3-component vector";
+      }
+
+      const uint32_t ray_tmax = _.GetOperandTypeId(inst, 9);
+      if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Ray TMax must be a 32-bit float scalar";
+      }
+
+      const Instruction* payload = _.FindDef(inst->GetOperandAs<uint32_t>(10));
+      if (payload->opcode() != SpvOpVariable) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Payload must be the result of a OpVariable";
+      } else if (payload->word(3) != SpvStorageClassRayPayloadKHR &&
+                 payload->word(3) != SpvStorageClassIncomingRayPayloadKHR) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Payload must have storage class RayPayloadKHR or "
+                  "IncomingRayPayloadKHR";
+      }
+      break;
+    }
+
+    case SpvOpReportIntersectionKHR: {
+      _.function(inst->function()->id())
+          ->RegisterExecutionModelLimitation(
+              [](SpvExecutionModel model, std::string* message) {
+                if (model != SpvExecutionModelIntersectionKHR) {
+                  if (message) {
+                    *message =
+                        "OpReportIntersectionKHR requires IntersectionKHR "
+                        "execution model";
+                  }
+                  return false;
+                }
+                return true;
+              });
+
+      if (!_.IsBoolScalarType(result_type)) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "expected Result Type to be bool scalar type";
+      }
+
+      const uint32_t hit = _.GetOperandTypeId(inst, 2);
+      if (!_.IsFloatScalarType(hit) || _.GetBitWidth(hit) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Hit must be a 32-bit int scalar";
+      }
+
+      const uint32_t hit_kind = _.GetOperandTypeId(inst, 3);
+      if (!_.IsUnsignedIntScalarType(hit_kind) ||
+          _.GetBitWidth(hit_kind) != 32) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Hit Kind must be a 32-bit unsigned int scalar";
+      }
+      break;
+    }
+
+    case SpvOpExecuteCallableKHR: {
+      _.function(inst->function()->id())
+          ->RegisterExecutionModelLimitation([](SpvExecutionModel model,
+                                                std::string* message) {
+            if (model != SpvExecutionModelRayGenerationKHR &&
+                model != SpvExecutionModelClosestHitKHR &&
+                model != SpvExecutionModelMissKHR &&
+                model != SpvExecutionModelCallableKHR) {
+              if (message) {
+                *message =
+                    "OpExecuteCallableKHR requires RayGenerationKHR, "
+                    "ClosestHitKHR, MissKHR and CallableKHR execution models";
+              }
+              return false;
+            }
+            return true;
+          });
+
+      const auto callable_data = _.FindDef(inst->GetOperandAs<uint32_t>(1));
+      if (callable_data->opcode() != SpvOpVariable) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Callable Data must be the result of a OpVariable";
+      } else if (callable_data->word(3) != SpvStorageClassCallableDataKHR &&
+                 callable_data->word(3) !=
+                     SpvStorageClassIncomingCallableDataKHR) {
+        return _.diag(SPV_ERROR_INVALID_DATA, inst)
+               << "Callable Data must have storage class CallableDataKHR or "
+                  "IncomingCallableDataKHR";
+      }
+
+      break;
+    }
+
+    default:
+      break;
+  }
+
+  return SPV_SUCCESS;
+}
+}  // namespace val
+}  // namespace spvtools
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index d02807a..e73e0f4 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -91,6 +91,7 @@
 add_spvtools_unittest(TARGET val_rstuvw
   SRCS
        val_ray_query.cpp
+       val_ray_tracing.cpp
        val_small_type_uses_test.cpp
        val_ssa_test.cpp
        val_state_test.cpp
diff --git a/test/val/val_ray_tracing.cpp b/test/val/val_ray_tracing.cpp
new file mode 100644
index 0000000..9486777
--- /dev/null
+++ b/test/val/val_ray_tracing.cpp
@@ -0,0 +1,555 @@
+// Copyright (c) 2022 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests ray tracing instructions from SPV_KHR_ray_tracing.
+
+#include <sstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Values;
+
+using ValidateRayTracing = spvtest::ValidateBase<bool>;
+
+TEST_F(ValidateRayTracing, IgnoreIntersectionSuccess) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint AnyHitKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%main = OpFunction %void None %func
+%label = OpLabel
+OpIgnoreIntersectionKHR
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayTracing, IgnoreIntersectionExecutionModel) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint CallableKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%main = OpFunction %void None %func
+%label = OpLabel
+OpIgnoreIntersectionKHR
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpIgnoreIntersectionKHR requires AnyHitKHR execution model"));
+}
+
+TEST_F(ValidateRayTracing, TerminateRaySuccess) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint AnyHitKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%main = OpFunction %void None %func
+%label = OpLabel
+OpIgnoreIntersectionKHR
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayTracing, TerminateRayExecutionModel) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint MissKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%main = OpFunction %void None %func
+%label = OpLabel
+OpTerminateRayKHR
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpTerminateRayKHR requires AnyHitKHR execution model"));
+}
+
+TEST_F(ValidateRayTracing, ReportIntersectionRaySuccess) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint IntersectionKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float_1 = OpConstant %float 1
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%bool = OpTypeBool
+%main = OpFunction %void None %func
+%label = OpLabel
+%report = OpReportIntersectionKHR %bool %float_1 %uint_1
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayTracing, ReportIntersectionExecutionModel) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint MissKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float_1 = OpConstant %float 1
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%bool = OpTypeBool
+%main = OpFunction %void None %func
+%label = OpLabel
+%report = OpReportIntersectionKHR %bool %float_1 %uint_1
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr(
+          "OpReportIntersectionKHR requires IntersectionKHR execution model"));
+}
+
+TEST_F(ValidateRayTracing, ReportIntersectionReturnType) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint IntersectionKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float_1 = OpConstant %float 1
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%main = OpFunction %void None %func
+%label = OpLabel
+%report = OpReportIntersectionKHR %uint %float_1 %uint_1
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("expected Result Type to be bool scalar type"));
+}
+
+TEST_F(ValidateRayTracing, ReportIntersectionHit) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpCapability Float64
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint IntersectionKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%float64 = OpTypeFloat 64
+%float64_1 = OpConstant %float64 1
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%bool = OpTypeBool
+%main = OpFunction %void None %func
+%label = OpLabel
+%report = OpReportIntersectionKHR %bool %float64_1 %uint_1
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Hit must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, ReportIntersectionHitKind) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint IntersectionKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float_1 = OpConstant %float 1
+%sint = OpTypeInt 32 1
+%sint_1 = OpConstant %sint 1
+%bool = OpTypeBool
+%main = OpFunction %void None %func
+%label = OpLabel
+%report = OpReportIntersectionKHR %bool %float_1 %sint_1
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Hit Kind must be a 32-bit unsigned int scalar"));
+}
+
+TEST_F(ValidateRayTracing, ExecuteCallableSuccess) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint CallableKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%data_ptr = OpTypePointer CallableDataKHR %int
+%data = OpVariable %data_ptr CallableDataKHR
+%inData_ptr = OpTypePointer IncomingCallableDataKHR %int
+%inData = OpVariable %inData_ptr IncomingCallableDataKHR
+%main = OpFunction %void None %func
+%label = OpLabel
+OpExecuteCallableKHR %uint_0 %data
+OpExecuteCallableKHR %uint_0 %inData
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayTracing, ExecuteCallableExecutionModel) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint AnyHitKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%data_ptr = OpTypePointer CallableDataKHR %int
+%data = OpVariable %data_ptr CallableDataKHR
+%inData_ptr = OpTypePointer IncomingCallableDataKHR %int
+%inData = OpVariable %inData_ptr IncomingCallableDataKHR
+%main = OpFunction %void None %func
+%label = OpLabel
+OpExecuteCallableKHR %uint_0 %data
+OpExecuteCallableKHR %uint_0 %inData
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpExecuteCallableKHR requires RayGenerationKHR, "
+                "ClosestHitKHR, MissKHR and CallableKHR execution models"));
+}
+
+TEST_F(ValidateRayTracing, ExecuteCallableStorageClass) {
+  const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint RayGenerationKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%data_ptr = OpTypePointer RayPayloadKHR %int
+%data = OpVariable %data_ptr RayPayloadKHR
+%main = OpFunction %void None %func
+%label = OpLabel
+OpExecuteCallableKHR %uint_0 %data
+OpReturn
+OpFunctionEnd
+)";
+
+  CompileSuccessfully(body.c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Callable Data must have storage class CallableDataKHR "
+                        "or IncomingCallableDataKHR"));
+}
+
+std::string GenerateRayTraceCode(
+    const std::string& body,
+    const std::string execution_model = "RayGenerationKHR") {
+  std::ostringstream ss;
+  ss << R"(
+OpCapability RayTracingKHR
+OpCapability Float64
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint )"
+     << execution_model << R"( %main "main"
+OpDecorate %top_level_as DescriptorSet 0
+OpDecorate %top_level_as Binding 0
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%type_as = OpTypeAccelerationStructureKHR
+%as_uc_ptr = OpTypePointer UniformConstant %type_as
+%top_level_as = OpVariable %as_uc_ptr UniformConstant
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%float = OpTypeFloat 32
+%float64 = OpTypeFloat 64
+%f32vec3 = OpTypeVector %float 3
+%f32vec4 = OpTypeVector %float 4
+%float_0 = OpConstant %float 0
+%float64_0 = OpConstant %float64 0
+%v3composite = OpConstantComposite %f32vec3 %float_0 %float_0 %float_0
+%v4composite = OpConstantComposite %f32vec4 %float_0 %float_0 %float_0 %float_0
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%payload_ptr = OpTypePointer RayPayloadKHR %int
+%payload = OpVariable %payload_ptr RayPayloadKHR
+%callable_ptr = OpTypePointer CallableDataKHR %int
+%callable = OpVariable %callable_ptr CallableDataKHR
+%ptr_uint = OpTypePointer Private %uint
+%var_uint = OpVariable %ptr_uint Private
+%ptr_float = OpTypePointer Private %float
+%var_float = OpVariable %ptr_float Private
+%ptr_f32vec3 = OpTypePointer Private %f32vec3
+%var_f32vec3 = OpVariable %ptr_f32vec3 Private
+%main = OpFunction %void None %func
+%label = OpLabel
+)";
+
+  ss << body;
+
+  ss << R"(
+OpReturn
+OpFunctionEnd)";
+  return ss.str();
+}
+
+TEST_F(ValidateRayTracing, TraceRaySuccess) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+
+%_uint = OpLoad %uint %var_uint
+%_float = OpLoad %float %var_float
+%_f32vec3 = OpLoad %f32vec3 %var_f32vec3
+OpTraceRayKHR %as %_uint %_uint %_uint %_uint %_uint %_f32vec3 %_float %_f32vec3 %_float %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayTracing, TraceRayExecutionModel) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body, "CallableKHR").c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpTraceRayKHR requires RayGenerationKHR, "
+                        "ClosestHitKHR and MissKHR execution models"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayAccelerationStructure) {
+  const std::string body = R"(
+%_uint = OpLoad %uint %var_uint
+OpTraceRayKHR %_uint %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Expected Acceleration Structure to be of type "
+                        "OpTypeAccelerationStructureKHR"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayRayFlags) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %float_0 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Ray Flags must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayCullMask) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %float_0 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Cull Mask must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRaySbtOffest) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %float_0 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("SBT Offset must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRaySbtStride) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %float_0 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("SBT Stride must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayMissIndex) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %float_0 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Miss Index must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayRayOrigin) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %float_0 %float_0 %v3composite %float_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayRayTMin) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %uint_1 %v3composite %float_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Ray TMin must be a 32-bit float scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayRayDirection) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v4composite %float_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("Ray Direction must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayRayTMax) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float64_0 %payload
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Ray TMax must be a 32-bit float scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayPayload) {
+  const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %callable
+)";
+
+  CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+  EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("Payload must have storage class RayPayloadKHR or "
+                        "IncomingRayPayloadKHR"));
+}
+
+}  // namespace
+}  // namespace val
+}  // namespace spvtools