Remove validate_datarules.cpp (#2911)

* Checks moved into individual opcode validation
  * removes duplicated checks
* Add check that forward pointer points to struct

diff --git a/Android.mk b/Android.mk
index 2eb4368..cb7062c 100644
--- a/Android.mk
+++ b/Android.mk
@@ -49,7 +49,6 @@
 		source/val/validate_composites.cpp \
 		source/val/validate_constants.cpp \
 		source/val/validate_conversion.cpp \
-		source/val/validate_datarules.cpp \
 		source/val/validate_debug.cpp \
 		source/val/validate_decorations.cpp \
 		source/val/validate_derivatives.cpp \
diff --git a/BUILD.gn b/BUILD.gn
index b1ed713..a8bd134 100644
--- a/BUILD.gn
+++ b/BUILD.gn
@@ -410,7 +410,6 @@
     "source/val/validate_composites.cpp",
     "source/val/validate_constants.cpp",
     "source/val/validate_conversion.cpp",
-    "source/val/validate_datarules.cpp",
     "source/val/validate_debug.cpp",
     "source/val/validate_decorations.cpp",
     "source/val/validate_derivatives.cpp",
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index 37348e9..a083995 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -283,7 +283,6 @@
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_composites.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_constants.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_conversion.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_datarules.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_debug.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_decorations.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/val/validate_derivatives.cpp
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index e4123a4..7f4b0dc 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -343,7 +343,6 @@
     }
 
     if (auto error = CapabilityPass(*vstate, &instruction)) return error;
-    if (auto error = DataRulesPass(*vstate, &instruction)) return error;
     if (auto error = ModuleLayoutPass(*vstate, &instruction)) return error;
     if (auto error = CfgPass(*vstate, &instruction)) return error;
     if (auto error = InstructionPass(*vstate, &instruction)) return error;
@@ -352,6 +351,9 @@
     {
       Instruction* inst = const_cast<Instruction*>(&instruction);
       vstate->RegisterInstruction(inst);
+      if (inst->opcode() == SpvOpTypeForwardPointer) {
+        vstate->RegisterForwardPointer(inst->GetOperandAs<uint32_t>(0));
+      }
     }
   }
 
diff --git a/source/val/validate.h b/source/val/validate.h
index da3d0b8..31a775b 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -123,11 +123,6 @@
 /// Performs Id and SSA validation of a module
 spv_result_t IdPass(ValidationState_t& _, Instruction* inst);
 
-/// Performs validation of the Data Rules subsection of 2.16.1 Universal
-/// Validation Rules.
-/// TODO(ehsann): add more comments here as more validation code is added.
-spv_result_t DataRulesPass(ValidationState_t& _, const Instruction* inst);
-
 /// Performs instruction validation.
 spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst);
 
diff --git a/source/val/validate_builtins.cpp b/source/val/validate_builtins.cpp
index 42497cf..7623d49 100644
--- a/source/val/validate_builtins.cpp
+++ b/source/val/validate_builtins.cpp
@@ -2504,20 +2504,20 @@
       switch (execution_model) {
         case SpvExecutionModelGeometry:
         case SpvExecutionModelFragment:
-        case SpvExecutionModelMeshNV: {
+        case SpvExecutionModelMeshNV:
           // Ok.
           break;
-          case SpvExecutionModelVertex:
-          case SpvExecutionModelTessellationEvaluation:
-            if (!_.HasCapability(SpvCapabilityShaderViewportIndexLayerEXT)) {
-              return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
-                     << "Using BuiltIn "
-                     << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
-                                                      decoration.params()[0])
-                     << " in Vertex or Tessellation execution model requires "
-                        "the ShaderViewportIndexLayerEXT capability.";
-            }
-            break;
+        case SpvExecutionModelVertex:
+        case SpvExecutionModelTessellationEvaluation: {
+          if (!_.HasCapability(SpvCapabilityShaderViewportIndexLayerEXT)) {
+            return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
+                   << "Using BuiltIn "
+                   << _.grammar().lookupOperandName(SPV_OPERAND_TYPE_BUILT_IN,
+                                                    decoration.params()[0])
+                   << " in Vertex or Tessellation execution model requires "
+                      "the ShaderViewportIndexLayerEXT capability.";
+          }
+          break;
         }
         default: {
           return _.diag(SPV_ERROR_INVALID_DATA, &referenced_from_inst)
diff --git a/source/val/validate_constants.cpp b/source/val/validate_constants.cpp
index 565518b..95ec168 100644
--- a/source/val/validate_constants.cpp
+++ b/source/val/validate_constants.cpp
@@ -342,6 +342,21 @@
   return SPV_SUCCESS;
 }
 
+// Validates that OpSpecConstant specializes to either int or float type.
+spv_result_t ValidateSpecConstant(ValidationState_t& _,
+                                  const Instruction* inst) {
+  // Operand 0 is the <id> of the type that we're specializing to.
+  auto type_id = inst->GetOperandAs<const uint32_t>(0);
+  auto type_instruction = _.FindDef(type_id);
+  auto type_opcode = type_instruction->opcode();
+  if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
+                                                   "must be an integer or "
+                                                   "floating-point number.";
+  }
+  return SPV_SUCCESS;
+}
+
 spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
                                     const Instruction* inst) {
   const auto op = inst->GetOperandAs<SpvOp>(2);
@@ -422,6 +437,9 @@
     case SpvOpConstantNull:
       if (auto error = ValidateConstantNull(_, inst)) return error;
       break;
+    case SpvOpSpecConstant:
+      if (auto error = ValidateSpecConstant(_, inst)) return error;
+      break;
     case SpvOpSpecConstantOp:
       if (auto error = ValidateSpecConstantOp(_, inst)) return error;
       break;
diff --git a/source/val/validate_datarules.cpp b/source/val/validate_datarules.cpp
deleted file mode 100644
index 826eb8d..0000000
--- a/source/val/validate_datarules.cpp
+++ /dev/null
@@ -1,286 +0,0 @@
-// Copyright (c) 2016 Google Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Ensures Data Rules are followed according to the specifications.
-
-#include "source/val/validate.h"
-
-#include <cassert>
-#include <sstream>
-#include <string>
-
-#include "source/diagnostic.h"
-#include "source/opcode.h"
-#include "source/operand.h"
-#include "source/val/instruction.h"
-#include "source/val/validation_state.h"
-
-namespace spvtools {
-namespace val {
-namespace {
-
-// Validates that the number of components in the vector is valid.
-// Vector types can only be parameterized as having 2, 3, or 4 components.
-// If the Vector16 capability is added, 8 and 16 components are also allowed.
-spv_result_t ValidateVecNumComponents(ValidationState_t& _,
-                                      const Instruction* inst) {
-  // Operand 2 specifies the number of components in the vector.
-  auto num_components = inst->GetOperandAs<const uint32_t>(2);
-  if (num_components == 2 || num_components == 3 || num_components == 4) {
-    return SPV_SUCCESS;
-  }
-  if (num_components == 8 || num_components == 16) {
-    if (_.HasCapability(SpvCapabilityVector16)) {
-      return SPV_SUCCESS;
-    }
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Having " << num_components << " components for "
-           << spvOpcodeString(inst->opcode())
-           << " requires the Vector16 capability";
-  }
-  return _.diag(SPV_ERROR_INVALID_DATA, inst)
-         << "Illegal number of components (" << num_components << ") for "
-         << spvOpcodeString(inst->opcode());
-}
-
-// Validates that the number of bits specifed for a float type is valid.
-// Scalar floating-point types can be parameterized only with 32-bits.
-// Float16 capability allows using a 16-bit OpTypeFloat.
-// Float16Buffer capability allows creation of a 16-bit OpTypeFloat.
-// Float64 capability allows using a 64-bit OpTypeFloat.
-spv_result_t ValidateFloatSize(ValidationState_t& _, const Instruction* inst) {
-  // Operand 1 is the number of bits for this float
-  auto num_bits = inst->GetOperandAs<const uint32_t>(1);
-  if (num_bits == 32) {
-    return SPV_SUCCESS;
-  }
-  if (num_bits == 16) {
-    if (_.features().declare_float16_type) {
-      return SPV_SUCCESS;
-    }
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Using a 16-bit floating point "
-           << "type requires the Float16 or Float16Buffer capability,"
-              " or an extension that explicitly enables 16-bit floating point.";
-  }
-  if (num_bits == 64) {
-    if (_.HasCapability(SpvCapabilityFloat64)) {
-      return SPV_SUCCESS;
-    }
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Using a 64-bit floating point "
-           << "type requires the Float64 capability.";
-  }
-  return _.diag(SPV_ERROR_INVALID_DATA, inst)
-         << "Invalid number of bits (" << num_bits << ") used for OpTypeFloat.";
-}
-
-// Validates that the number of bits specified for an Int type is valid.
-// Scalar integer types can be parameterized only with 32-bits.
-// Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
-// integers, respectively.
-spv_result_t ValidateIntSize(ValidationState_t& _, const Instruction* inst) {
-  // Operand 1 is the number of bits for this integer.
-  auto num_bits = inst->GetOperandAs<const uint32_t>(1);
-  if (num_bits == 32) {
-    return SPV_SUCCESS;
-  }
-  if (num_bits == 8) {
-    if (_.features().declare_int8_type) {
-      return SPV_SUCCESS;
-    }
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Using an 8-bit integer type requires the Int8 capability,"
-              " or an extension that explicitly enables 8-bit integers.";
-  }
-  if (num_bits == 16) {
-    if (_.features().declare_int16_type) {
-      return SPV_SUCCESS;
-    }
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Using a 16-bit integer type requires the Int16 capability,"
-              " or an extension that explicitly enables 16-bit integers.";
-  }
-  if (num_bits == 64) {
-    if (_.HasCapability(SpvCapabilityInt64)) {
-      return SPV_SUCCESS;
-    }
-    return _.diag(SPV_ERROR_INVALID_DATA, inst)
-           << "Using a 64-bit integer type requires the Int64 capability.";
-  }
-  return _.diag(SPV_ERROR_INVALID_DATA, inst)
-         << "Invalid number of bits (" << num_bits << ") used for OpTypeInt.";
-}
-
-// Validates that the matrix is parameterized with floating-point types.
-spv_result_t ValidateMatrixColumnType(ValidationState_t& _,
-                                      const Instruction* inst) {
-  // Find the component type of matrix columns (must be vector).
-  // Operand 1 is the <id> of the type specified for matrix columns.
-  auto type_id = inst->GetOperandAs<const uint32_t>(1);
-  auto col_type_instr = _.FindDef(type_id);
-  if (col_type_instr->opcode() != SpvOpTypeVector) {
-    return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "Columns in a matrix must be of type vector.";
-  }
-
-  // Trace back once more to find out the type of components in the vector.
-  // Operand 1 is the <id> of the type of data in the vector.
-  auto comp_type_id =
-      col_type_instr->words()[col_type_instr->operands()[1].offset];
-  auto comp_type_instruction = _.FindDef(comp_type_id);
-  if (comp_type_instruction->opcode() != SpvOpTypeFloat) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
-                                                   "parameterized with "
-                                                   "floating-point types.";
-  }
-  return SPV_SUCCESS;
-}
-
-// Validates that the matrix has 2,3, or 4 columns.
-spv_result_t ValidateMatrixNumCols(ValidationState_t& _,
-                                   const Instruction* inst) {
-  // Operand 2 is the number of columns in the matrix.
-  auto num_cols = inst->GetOperandAs<const uint32_t>(2);
-  if (num_cols != 2 && num_cols != 3 && num_cols != 4) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
-                                                   "parameterized as having "
-                                                   "only 2, 3, or 4 columns.";
-  }
-  return SPV_SUCCESS;
-}
-
-// Validates that OpSpecConstant specializes to either int or float type.
-spv_result_t ValidateSpecConstNumerical(ValidationState_t& _,
-                                        const Instruction* inst) {
-  // Operand 0 is the <id> of the type that we're specializing to.
-  auto type_id = inst->GetOperandAs<const uint32_t>(0);
-  auto type_instruction = _.FindDef(type_id);
-  auto type_opcode = type_instruction->opcode();
-  if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) {
-    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
-                                                   "must be an integer or "
-                                                   "floating-point number.";
-  }
-  return SPV_SUCCESS;
-}
-
-// Validates that OpSpecConstantTrue and OpSpecConstantFalse specialize to bool.
-spv_result_t ValidateSpecConstBoolean(ValidationState_t& _,
-                                      const Instruction* inst) {
-  // Find out the type that we're specializing to.
-  auto type_instruction = _.FindDef(inst->type_id());
-  if (type_instruction->opcode() != SpvOpTypeBool) {
-    return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "Specialization constant must be a boolean type.";
-  }
-  return SPV_SUCCESS;
-}
-
-// Records the <id> of the forward pointer to be used for validation.
-spv_result_t ValidateForwardPointer(ValidationState_t& _,
-                                    const Instruction* inst) {
-  // Record the <id> (which is operand 0) to ensure it's used properly.
-  // OpTypeStruct can only include undefined pointers that are
-  // previously declared as a ForwardPointer
-  return (_.RegisterForwardPointer(inst->GetOperandAs<uint32_t>(0)));
-}
-
-// Validates that any undefined component of the struct is a forward pointer.
-// It is valid to declare a forward pointer, and use its <id> as one of the
-// components of a struct.
-spv_result_t ValidateStruct(ValidationState_t& _, const Instruction* inst) {
-  // Struct components are operands 1, 2, etc.
-  for (unsigned i = 1; i < inst->operands().size(); i++) {
-    auto type_id = inst->GetOperandAs<const uint32_t>(i);
-    auto type_instruction = _.FindDef(type_id);
-    if (type_instruction == nullptr && !_.IsForwardPointer(type_id)) {
-      return _.diag(SPV_ERROR_INVALID_ID, inst)
-             << "Forward reference operands in an OpTypeStruct must first be "
-                "declared using OpTypeForwardPointer.";
-    }
-  }
-  return SPV_SUCCESS;
-}
-
-// Validates that any undefined type of the array is a forward pointer.
-// It is valid to declare a forward pointer, and use its <id> as the element
-// type of the array.
-spv_result_t ValidateArray(ValidationState_t& _, const Instruction* inst) {
-  auto element_type_id = inst->GetOperandAs<const uint32_t>(1);
-  auto element_type_instruction = _.FindDef(element_type_id);
-  if (element_type_instruction == nullptr &&
-      !_.IsForwardPointer(element_type_id)) {
-    return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "Forward reference operands in an OpTypeArray must first be "
-              "declared using OpTypeForwardPointer.";
-  }
-  return SPV_SUCCESS;
-}
-
-}  // namespace
-
-// Validates that Data Rules are followed according to the specifications.
-// (Data Rules subsection of 2.16.1 Universal Validation Rules)
-spv_result_t DataRulesPass(ValidationState_t& _, const Instruction* inst) {
-  switch (inst->opcode()) {
-    case SpvOpTypeVector: {
-      if (auto error = ValidateVecNumComponents(_, inst)) return error;
-      break;
-    }
-    case SpvOpTypeFloat: {
-      if (auto error = ValidateFloatSize(_, inst)) return error;
-      break;
-    }
-    case SpvOpTypeInt: {
-      if (auto error = ValidateIntSize(_, inst)) return error;
-      break;
-    }
-    case SpvOpTypeMatrix: {
-      if (auto error = ValidateMatrixColumnType(_, inst)) return error;
-      if (auto error = ValidateMatrixNumCols(_, inst)) return error;
-      break;
-    }
-    // TODO(ehsan): Add OpSpecConstantComposite validation code.
-    // TODO(ehsan): Add OpSpecConstantOp validation code (if any).
-    case SpvOpSpecConstant: {
-      if (auto error = ValidateSpecConstNumerical(_, inst)) return error;
-      break;
-    }
-    case SpvOpSpecConstantFalse:
-    case SpvOpSpecConstantTrue: {
-      if (auto error = ValidateSpecConstBoolean(_, inst)) return error;
-      break;
-    }
-    case SpvOpTypeForwardPointer: {
-      if (auto error = ValidateForwardPointer(_, inst)) return error;
-      break;
-    }
-    case SpvOpTypeStruct: {
-      if (auto error = ValidateStruct(_, inst)) return error;
-      break;
-    }
-    case SpvOpTypeArray: {
-      if (auto error = ValidateArray(_, inst)) return error;
-      break;
-    }
-    // TODO(ehsan): add more data rules validation here.
-    default: { break; }
-  }
-
-  return SPV_SUCCESS;
-}
-
-}  // namespace val
-}  // namespace spvtools
diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp
index afc0656..d3872da 100644
--- a/source/val/validate_type.cpp
+++ b/source/val/validate_type.cpp
@@ -67,6 +67,39 @@
 }
 
 spv_result_t ValidateTypeInt(ValidationState_t& _, const Instruction* inst) {
+  // Validates that the number of bits specified for an Int type is valid.
+  // Scalar integer types can be parameterized only with 32-bits.
+  // Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
+  // integers, respectively.
+  auto num_bits = inst->GetOperandAs<const uint32_t>(1);
+  if (num_bits != 32) {
+    if (num_bits == 8) {
+      if (_.features().declare_int8_type) {
+        return SPV_SUCCESS;
+      }
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Using an 8-bit integer type requires the Int8 capability,"
+                " or an extension that explicitly enables 8-bit integers.";
+    } else if (num_bits == 16) {
+      if (_.features().declare_int16_type) {
+        return SPV_SUCCESS;
+      }
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Using a 16-bit integer type requires the Int16 capability,"
+                " or an extension that explicitly enables 16-bit integers.";
+    } else if (num_bits == 64) {
+      if (_.HasCapability(SpvCapabilityInt64)) {
+        return SPV_SUCCESS;
+      }
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Using a 64-bit integer type requires the Int64 capability.";
+    } else {
+      return _.diag(SPV_ERROR_INVALID_DATA, inst)
+             << "Invalid number of bits (" << num_bits
+             << ") used for OpTypeInt.";
+    }
+  }
+
   const auto signedness_index = 2;
   const auto signedness = inst->GetOperandAs<uint32_t>(signedness_index);
   if (signedness != 0 && signedness != 1) {
@@ -76,6 +109,36 @@
   return SPV_SUCCESS;
 }
 
+spv_result_t ValidateTypeFloat(ValidationState_t& _, const Instruction* inst) {
+  // Validates that the number of bits specified for an Int type is valid.
+  // Scalar integer types can be parameterized only with 32-bits.
+  // Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
+  // integers, respectively.
+  auto num_bits = inst->GetOperandAs<const uint32_t>(1);
+  if (num_bits == 32) {
+    return SPV_SUCCESS;
+  }
+  if (num_bits == 16) {
+    if (_.features().declare_float16_type) {
+      return SPV_SUCCESS;
+    }
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Using a 16-bit floating point "
+           << "type requires the Float16 or Float16Buffer capability,"
+              " or an extension that explicitly enables 16-bit floating point.";
+  }
+  if (num_bits == 64) {
+    if (_.HasCapability(SpvCapabilityFloat64)) {
+      return SPV_SUCCESS;
+    }
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Using a 64-bit floating point "
+           << "type requires the Float64 capability.";
+  }
+  return _.diag(SPV_ERROR_INVALID_DATA, inst)
+         << "Invalid number of bits (" << num_bits << ") used for OpTypeFloat.";
+}
+
 spv_result_t ValidateTypeVector(ValidationState_t& _, const Instruction* inst) {
   const auto component_index = 1;
   const auto component_id = inst->GetOperandAs<uint32_t>(component_index);
@@ -85,6 +148,27 @@
            << "OpTypeVector Component Type <id> '" << _.getIdName(component_id)
            << "' is not a scalar type.";
   }
+
+  // Validates that the number of components in the vector is valid.
+  // Vector types can only be parameterized as having 2, 3, or 4 components.
+  // If the Vector16 capability is added, 8 and 16 components are also allowed.
+  auto num_components = inst->GetOperandAs<const uint32_t>(2);
+  if (num_components == 2 || num_components == 3 || num_components == 4) {
+    return SPV_SUCCESS;
+  } else if (num_components == 8 || num_components == 16) {
+    if (_.HasCapability(SpvCapabilityVector16)) {
+      return SPV_SUCCESS;
+    }
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Having " << num_components << " components for "
+           << spvOpcodeString(inst->opcode())
+           << " requires the Vector16 capability";
+  } else {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst)
+           << "Illegal number of components (" << num_components << ") for "
+           << spvOpcodeString(inst->opcode());
+  }
+
   return SPV_SUCCESS;
 }
 
@@ -94,9 +178,27 @@
   const auto column_type = _.FindDef(column_type_id);
   if (!column_type || SpvOpTypeVector != column_type->opcode()) {
     return _.diag(SPV_ERROR_INVALID_ID, inst)
-           << "OpTypeMatrix Column Type <id> '" << _.getIdName(column_type_id)
-           << "' is not a vector.";
+           << "Columns in a matrix must be of type vector.";
   }
+
+  // Trace back once more to find out the type of components in the vector.
+  // Operand 1 is the <id> of the type of data in the vector.
+  const auto comp_type_id = column_type->GetOperandAs<uint32_t>(1);
+  auto comp_type_instruction = _.FindDef(comp_type_id);
+  if (comp_type_instruction->opcode() != SpvOpTypeFloat) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
+                                                   "parameterized with "
+                                                   "floating-point types.";
+  }
+
+  // Validates that the matrix has 2,3, or 4 columns.
+  auto num_cols = inst->GetOperandAs<const uint32_t>(2);
+  if (num_cols != 2 && num_cols != 3 && num_cols != 4) {
+    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
+                                                   "parameterized as having "
+                                                   "only 2, 3, or 4 columns.";
+  }
+
   return SPV_SUCCESS;
 }
 
@@ -224,6 +326,11 @@
   for (size_t member_type_index = 1;
        member_type_index < inst->operands().size(); ++member_type_index) {
     auto member_type_id = inst->GetOperandAs<uint32_t>(member_type_index);
+    if (member_type_id == inst->id()) {
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
+             << "Structure members may not be self references";
+    }
+
     auto member_type = _.FindDef(member_type_id);
     if (!member_type || !spvOpcodeGeneratesType(member_type->opcode())) {
       return _.diag(SPV_ERROR_INVALID_ID, inst)
@@ -245,22 +352,6 @@
              << " contains structure <id> " << _.getIdName(member_type_id)
              << ".";
     }
-    if (_.IsForwardPointer(member_type_id)) {
-      // If we're dealing with a forward pointer:
-      // Find out the type that the pointer is pointing to (must be struct)
-      // word 3 is the <id> of the type being pointed to.
-      auto type_pointing_to = _.FindDef(member_type->words()[3]);
-      if (type_pointing_to && type_pointing_to->opcode() != SpvOpTypeStruct) {
-        // Forward declared operands of a struct may only point to a struct.
-        return _.diag(SPV_ERROR_INVALID_ID, inst)
-               << "A forward reference operand in an OpTypeStruct must be an "
-                  "OpTypePointer that points to an OpTypeStruct. "
-                  "Found OpTypePointer that points to Op"
-               << spvOpcodeString(
-                      static_cast<SpvOp>(type_pointing_to->opcode()))
-               << ".";
-      }
-    }
 
     if (spvIsVulkanOrWebGPUEnv(_.context()->target_env) &&
         member_type->opcode() == SpvOpTypeRuntimeArray) {
@@ -356,7 +447,6 @@
   }
   return SPV_SUCCESS;
 }
-}  // namespace
 
 spv_result_t ValidateTypeFunction(ValidationState_t& _,
                                   const Instruction* inst) {
@@ -425,6 +515,13 @@
            << "pointer definition.";
   }
 
+  const auto pointee_type_id = pointer_type_inst->GetOperandAs<uint32_t>(2);
+  const auto pointee_type = _.FindDef(pointee_type_id);
+  if (!pointee_type || pointee_type->opcode() != SpvOpTypeStruct) {
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "Forward pointers must point to a structure";
+  }
+
   return SPV_SUCCESS;
 }
 
@@ -474,6 +571,7 @@
 
   return SPV_SUCCESS;
 }
+}  // namespace
 
 spv_result_t TypePass(ValidationState_t& _, const Instruction* inst) {
   if (!spvOpcodeGeneratesType(inst->opcode()) &&
@@ -487,6 +585,9 @@
     case SpvOpTypeInt:
       if (auto error = ValidateTypeInt(_, inst)) return error;
       break;
+    case SpvOpTypeFloat:
+      if (auto error = ValidateTypeFloat(_, inst)) return error;
+      break;
     case SpvOpTypeVector:
       if (auto error = ValidateTypeVector(_, inst)) return error;
       break;
diff --git a/test/val/val_data_test.cpp b/test/val/val_data_test.cpp
index 4690f97..0178fa7 100644
--- a/test/val/val_data_test.cpp
+++ b/test/val/val_data_test.cpp
@@ -629,15 +629,26 @@
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
 }
 
-TEST_F(ValidateData, specialize_boolean_to_int) {
+TEST_F(ValidateData, specialize_boolean_true_to_int) {
   std::string str = header + R"(
 %2 = OpTypeInt 32 1
-%3 = OpSpecConstantTrue %2
+%3 = OpSpecConstantTrue %2)";
+  CompileSuccessfully(str.c_str());
+  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+  EXPECT_THAT(getDiagnosticString(),
+              HasSubstr("OpSpecConstantTrue Result Type <id> '1[%int]' is not "
+                        "a boolean type"));
+}
+
+TEST_F(ValidateData, specialize_boolean_false_to_int) {
+  std::string str = header + R"(
+%2 = OpTypeInt 32 1
 %4 = OpSpecConstantFalse %2)";
   CompileSuccessfully(str.c_str());
   ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Specialization constant must be a boolean"));
+              HasSubstr("OpSpecConstantFalse Result Type <id> '1[%int]' is not "
+                        "a boolean type"));
 }
 
 TEST_F(ValidateData, missing_forward_pointer_decl) {
@@ -647,8 +658,9 @@
 )";
   CompileSuccessfully(str.c_str());
   ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
-  EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("must first be declared using OpTypeForwardPointer"));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("The following forward referenced IDs have not been defined:"));
 }
 
 TEST_F(ValidateData, missing_forward_pointer_decl_self_reference) {
@@ -659,7 +671,7 @@
   CompileSuccessfully(str.c_str());
   ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("must first be declared using OpTypeForwardPointer"));
+              HasSubstr("Structure members may not be self references"));
 }
 
 TEST_F(ValidateData, forward_pointer_missing_definition) {
@@ -698,9 +710,7 @@
   CompileSuccessfully(str.c_str());
   ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("A forward reference operand in an OpTypeStruct must "
-                        "be an OpTypePointer that points to an OpTypeStruct. "
-                        "Found OpTypePointer that points to OpTypeInt."));
+              HasSubstr("Forward pointers must point to a structure"));
 }
 
 TEST_F(ValidateData, struct_forward_pointer_good) {
@@ -934,23 +944,6 @@
                         "OpTypeStruct %_runtimearr_uint %uint\n"));
 }
 
-TEST_F(ValidateData, invalid_forward_reference_in_array) {
-  std::string str = R"(
-               OpCapability Shader
-               OpCapability Linkage
-               OpMemoryModel Logical GLSL450
-       %uint = OpTypeInt 32 0
-%uint_1 = OpConstant %uint 1
-%_arr_3_uint_1 = OpTypeArray %_arr_3_uint_1 %uint_1
-)";
-
-  CompileSuccessfully(str.c_str(), SPV_ENV_UNIVERSAL_1_3);
-  ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_UNIVERSAL_1_3));
-  EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Forward reference operands in an OpTypeArray must "
-                        "first be declared using OpTypeForwardPointer."));
-}
-
 }  // namespace
 }  // namespace val
 }  // namespace spvtools
diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp
index 4c2a73b..43e3197 100644
--- a/test/val/val_id_test.cpp
+++ b/test/val/val_id_test.cpp
@@ -1476,7 +1476,8 @@
   CompileSuccessfully(spirv.c_str());
   EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
   EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Specialization constant must be a boolean type."));
+              HasSubstr("OpSpecConstantTrue Result Type <id> '1[%void]' is not "
+                        "a boolean type"));
 }
 
 TEST_F(ValidateIdWithMessage, OpSpecConstantFalseGood) {
@@ -1492,8 +1493,10 @@
 %2 = OpSpecConstantFalse %1)";
   CompileSuccessfully(spirv.c_str());
   EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
-  EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Specialization constant must be a boolean type."));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("OpSpecConstantFalse Result Type <id> '1[%void]' is not "
+                "a boolean type"));
 }
 
 TEST_F(ValidateIdWithMessage, OpSpecConstantGood) {
@@ -5175,9 +5178,9 @@
 )";
   CompileSuccessfully(spirv.c_str());
   EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
-  EXPECT_THAT(getDiagnosticString(),
-              HasSubstr("Forward reference operands in an OpTypeStruct must "
-                        "first be declared using OpTypeForwardPointer."));
+  EXPECT_THAT(
+      getDiagnosticString(),
+      HasSubstr("The following forward referenced IDs have not been defined"));
 }
 
 TEST_F(ValidateIdWithMessage, UndefinedIdScope) {
diff --git a/test/val/val_ssa_test.cpp b/test/val/val_ssa_test.cpp
index e10abd7..035c710 100644
--- a/test/val/val_ssa_test.cpp
+++ b/test/val/val_ssa_test.cpp
@@ -1415,7 +1415,8 @@
                OpName %intptrt "intptrt"
                OpTypeForwardPointer %intptrt UniformConstant
        %uint = OpTypeInt 32 0
-    %intptrt = OpTypePointer UniformConstant %uint
+     %struct = OpTypeStruct %uint
+    %intptrt = OpTypePointer UniformConstant %struct
 )";
 
   CompileSuccessfully(str);
diff --git a/test/val/val_type_unique_test.cpp b/test/val/val_type_unique_test.cpp
index 67ceadd..45a4d50 100644
--- a/test/val/val_type_unique_test.cpp
+++ b/test/val/val_type_unique_test.cpp
@@ -210,9 +210,11 @@
 OpTypeForwardPointer %ptr Generic
 OpTypeForwardPointer %ptr2 Generic
 %intt = OpTypeInt 32 0
+%int_struct = OpTypeStruct %intt
 %floatt = OpTypeFloat 32
-%ptr = OpTypePointer Generic %intt
-%ptr2 = OpTypePointer Generic %floatt
+%ptr = OpTypePointer Generic %int_struct
+%float_struct = OpTypeStruct %floatt
+%ptr2 = OpTypePointer Generic %float_struct
 )";
   CompileSuccessfully(str.c_str());
   ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());