Unify memory instruction validation style (#1934)

* Rename ValidateMemoryInstructions to MemoryPass
* Changed functions to take pointer to an instruction instead of
reference
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index 84ec193..012e9d7 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -313,8 +313,7 @@
     if (auto error = ModeSettingPass(*vstate, &instruction)) return error;
     if (auto error = TypePass(*vstate, &instruction)) return error;
     if (auto error = ConstantPass(*vstate, &instruction)) return error;
-    if (auto error = ValidateMemoryInstructions(*vstate, &instruction))
-      return error;
+    if (auto error = MemoryPass(*vstate, &instruction)) return error;
     if (auto error = FunctionPass(*vstate, &instruction)) return error;
     if (auto error = ImagePass(*vstate, &instruction)) return error;
     if (auto error = ConversionPass(*vstate, &instruction)) return error;
diff --git a/source/val/validate.h b/source/val/validate.h
index 4599c4a..518547f 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -91,8 +91,7 @@
 ///
 /// @param[in] _ the validation state of the module
 /// @return SPV_SUCCESS if no errors are found.
-spv_result_t ValidateMemoryInstructions(ValidationState_t& _,
-                                        const Instruction* inst);
+spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst);
 
 /// @brief Updates the immediate dominator for each of the block edges
 ///
diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp
index f1e10b3..3a9e1ae 100644
--- a/source/val/validate_memory.cpp
+++ b/source/val/validate_memory.cpp
@@ -180,18 +180,18 @@
   return false;
 }
 
-spv_result_t ValidateVariable(ValidationState_t& _, const Instruction& inst) {
-  auto result_type = _.FindDef(inst.type_id());
+spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
+  auto result_type = _.FindDef(inst->type_id());
   if (!result_type || result_type->opcode() != SpvOpTypePointer) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
-           << "OpVariable Result Type <id> '" << _.getIdName(inst.type_id())
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpVariable Result Type <id> '" << _.getIdName(inst->type_id())
            << "' is not a pointer type.";
   }
 
   const auto initializer_index = 3;
   const auto storage_class_index = 2;
-  if (initializer_index < inst.operands().size()) {
-    const auto initializer_id = inst.GetOperandAs<uint32_t>(initializer_index);
+  if (initializer_index < inst->operands().size()) {
+    const auto initializer_id = inst->GetOperandAs<uint32_t>(initializer_index);
     const auto initializer = _.FindDef(initializer_id);
     const auto is_module_scope_var =
         initializer && (initializer->opcode() == SpvOpVariable) &&
@@ -200,13 +200,13 @@
     const auto is_constant =
         initializer && spvOpcodeIsConstant(initializer->opcode());
     if (!initializer || !(is_constant || is_module_scope_var)) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "OpVariable Initializer <id> '" << _.getIdName(initializer_id)
              << "' is not a constant or module-scope variable.";
     }
   }
 
-  const auto storage_class = inst.GetOperandAs<uint32_t>(storage_class_index);
+  const auto storage_class = inst->GetOperandAs<uint32_t>(storage_class_index);
   if (storage_class != SpvStorageClassWorkgroup &&
       storage_class != SpvStorageClassCrossWorkgroup &&
       storage_class != SpvStorageClassPrivate &&
@@ -218,7 +218,7 @@
                                    storage_class == SpvStorageClassOutput;
     bool builtin = false;
     if (storage_input_or_output) {
-      for (const Decoration& decoration : _.id_decorations(inst.id())) {
+      for (const Decoration& decoration : _.id_decorations(inst->id())) {
         if (decoration.dec_type() == SpvDecorationBuiltIn) {
           builtin = true;
           break;
@@ -227,7 +227,7 @@
     }
     if (!(storage_input_or_output && builtin) &&
         ContainsInvalidBool(_, storage, storage_input_or_output)) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "If OpTypeBool is stored in conjunction with OpVariable, it "
              << "can only be used with non-externally visible shader Storage "
              << "Classes: Workgroup, CrossWorkgroup, Private, and Function";
@@ -236,11 +236,11 @@
   return SPV_SUCCESS;
 }
 
-spv_result_t ValidateLoad(ValidationState_t& _, const Instruction& inst) {
-  const auto result_type = _.FindDef(inst.type_id());
+spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
+  const auto result_type = _.FindDef(inst->type_id());
   if (!result_type) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
-           << "OpLoad Result Type <id> '" << _.getIdName(inst.type_id())
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
            << "' is not defined.";
   }
 
@@ -248,7 +248,7 @@
       _.features().variable_pointers ||
       _.features().variable_pointers_storage_buffer;
   const auto pointer_index = 2;
-  const auto pointer_id = inst.GetOperandAs<uint32_t>(pointer_index);
+  const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
   const auto pointer = _.FindDef(pointer_id);
   if (!pointer ||
       ((_.addressing_model() == SpvAddressingModelLogical) &&
@@ -256,22 +256,22 @@
          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
         (uses_variable_pointers &&
          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "OpLoad Pointer <id> '" << _.getIdName(pointer_id)
            << "' is not a logical pointer.";
   }
 
   const auto pointer_type = _.FindDef(pointer->type_id());
   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "OpLoad type for pointer <id> '" << _.getIdName(pointer_id)
            << "' is not a pointer type.";
   }
 
   const auto pointee_type = _.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
   if (!pointee_type || result_type->id() != pointee_type->id()) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
-           << "OpLoad Result Type <id> '" << _.getIdName(inst.type_id())
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
+           << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
            << "' does not match Pointer <id> '" << _.getIdName(pointer->id())
            << "'s type.";
   }
@@ -279,12 +279,12 @@
   return SPV_SUCCESS;
 }
 
-spv_result_t ValidateStore(ValidationState_t& _, const Instruction& inst) {
+spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
   const bool uses_variable_pointer =
       _.features().variable_pointers ||
       _.features().variable_pointers_storage_buffer;
   const auto pointer_index = 0;
-  const auto pointer_id = inst.GetOperandAs<uint32_t>(pointer_index);
+  const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
   const auto pointer = _.FindDef(pointer_id);
   if (!pointer ||
       (_.addressing_model() == SpvAddressingModelLogical &&
@@ -292,20 +292,20 @@
          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
         (uses_variable_pointer &&
          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
            << "' is not a logical pointer.";
   }
   const auto pointer_type = _.FindDef(pointer->type_id());
   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "OpStore type for pointer <id> '" << _.getIdName(pointer_id)
            << "' is not a pointer type.";
   }
   const auto type_id = pointer_type->GetOperandAs<uint32_t>(2);
   const auto type = _.FindDef(type_id);
   if (!type || SpvOpTypeVoid == type->opcode()) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
            << "'s type is void.";
   }
@@ -315,7 +315,7 @@
     uint32_t data_type;
     uint32_t storage_class;
     if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
              << "' is not pointer type";
     }
@@ -323,23 +323,23 @@
     if (storage_class == SpvStorageClassUniformConstant ||
         storage_class == SpvStorageClassInput ||
         storage_class == SpvStorageClassPushConstant) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
              << "' storage class is read-only";
     }
   }
 
   const auto object_index = 1;
-  const auto object_id = inst.GetOperandAs<uint32_t>(object_index);
+  const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
   const auto object = _.FindDef(object_id);
   if (!object || !object->type_id()) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "OpStore Object <id> '" << _.getIdName(object_id)
            << "' is not an object.";
   }
   const auto object_type = _.FindDef(object->type_id());
   if (!object_type || SpvOpTypeVoid == object_type->opcode()) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "OpStore Object <id> '" << _.getIdName(object_id)
            << "'s type is void.";
   }
@@ -347,7 +347,7 @@
   if (type->id() != object_type->id()) {
     if (!_.options()->relax_struct_store || type->opcode() != SpvOpTypeStruct ||
         object_type->opcode() != SpvOpTypeStruct) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
              << "'s type does not match Object <id> '"
              << _.getIdName(object->id()) << "'s type.";
@@ -355,7 +355,7 @@
 
     // TODO: Check for layout compatible matricies and arrays as well.
     if (!AreLayoutCompatibleStructs(_, type, object_type)) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
              << "'s layout does not match Object <id> '"
              << _.getIdName(object->id()) << "'s layout.";
@@ -364,21 +364,21 @@
   return SPV_SUCCESS;
 }
 
-spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction& inst) {
+spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {
   const auto target_index = 0;
-  const auto target_id = inst.GetOperandAs<uint32_t>(target_index);
+  const auto target_id = inst->GetOperandAs<uint32_t>(target_index);
   const auto target = _.FindDef(target_id);
   if (!target) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "Target operand <id> '" << _.getIdName(target_id)
            << "' is not defined.";
   }
 
   const auto source_index = 1;
-  const auto source_id = inst.GetOperandAs<uint32_t>(source_index);
+  const auto source_id = inst->GetOperandAs<uint32_t>(source_index);
   const auto source = _.FindDef(source_id);
   if (!source) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "Source operand <id> '" << _.getIdName(source_id)
            << "' is not defined.";
   }
@@ -386,7 +386,7 @@
   const auto target_pointer_type = _.FindDef(target->type_id());
   if (!target_pointer_type ||
       target_pointer_type->opcode() != SpvOpTypePointer) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "Target operand <id> '" << _.getIdName(target_id)
            << "' is not a pointer.";
   }
@@ -394,16 +394,16 @@
   const auto source_pointer_type = _.FindDef(source->type_id());
   if (!source_pointer_type ||
       source_pointer_type->opcode() != SpvOpTypePointer) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "Source operand <id> '" << _.getIdName(source_id)
            << "' is not a pointer.";
   }
 
-  if (inst.opcode() == SpvOpCopyMemory) {
+  if (inst->opcode() == SpvOpCopyMemory) {
     const auto target_type =
         _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
     if (!target_type || target_type->opcode() == SpvOpTypeVoid) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "Target operand <id> '" << _.getIdName(target_id)
              << "' cannot be a void pointer.";
     }
@@ -411,29 +411,29 @@
     const auto source_type =
         _.FindDef(source_pointer_type->GetOperandAs<uint32_t>(2));
     if (!source_type || source_type->opcode() == SpvOpTypeVoid) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "Source operand <id> '" << _.getIdName(source_id)
              << "' cannot be a void pointer.";
     }
 
     if (target_type->id() != source_type->id()) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "Target <id> '" << _.getIdName(source_id)
              << "'s type does not match Source <id> '"
              << _.getIdName(source_type->id()) << "'s type.";
     }
   } else {
-    const auto size_id = inst.GetOperandAs<uint32_t>(2);
+    const auto size_id = inst->GetOperandAs<uint32_t>(2);
     const auto size = _.FindDef(size_id);
     if (!size) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "Size operand <id> '" << _.getIdName(size_id)
              << "' is not defined.";
     }
 
     const auto size_type = _.FindDef(size->type_id());
     if (!_.IsIntScalarType(size_type->id())) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "Size operand <id> '" << _.getIdName(size_id)
              << "' must be a scalar integer type.";
     }
@@ -441,13 +441,13 @@
     bool is_zero = true;
     switch (size->opcode()) {
       case SpvOpConstantNull:
-        return _.diag(SPV_ERROR_INVALID_ID, &inst)
+        return _.diag(SPV_ERROR_INVALID_ID, inst)
                << "Size operand <id> '" << _.getIdName(size_id)
                << "' cannot be a constant zero.";
       case SpvOpConstant:
         if (size_type->word(3) == 1 &&
             size->word(size->words().size() - 1) & 0x80000000) {
-          return _.diag(SPV_ERROR_INVALID_ID, &inst)
+          return _.diag(SPV_ERROR_INVALID_ID, inst)
                  << "Size operand <id> '" << _.getIdName(size_id)
                  << "' cannot have the sign bit set to 1.";
         }
@@ -455,7 +455,7 @@
           is_zero &= (size->word(i) == 0);
         }
         if (is_zero) {
-          return _.diag(SPV_ERROR_INVALID_ID, &inst)
+          return _.diag(SPV_ERROR_INVALID_ID, inst)
                  << "Size operand <id> '" << _.getIdName(size_id)
                  << "' cannot be a constant zero.";
         }
@@ -469,16 +469,16 @@
 }
 
 spv_result_t ValidateAccessChain(ValidationState_t& _,
-                                 const Instruction& inst) {
+                                 const Instruction* inst) {
   std::string instr_name =
-      "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst.opcode())));
+      "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
 
   // The result type must be OpTypePointer.
-  auto result_type = _.FindDef(inst.type_id());
+  auto result_type = _.FindDef(inst->type_id());
   if (SpvOpTypePointer != result_type->opcode()) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "The Result Type of " << instr_name << " <id> '"
-           << _.getIdName(inst.id()) << "' must be OpTypePointer. Found Op"
+           << _.getIdName(inst->id()) << "' must be OpTypePointer. Found Op"
            << spvOpcodeString(static_cast<SpvOp>(result_type->opcode())) << ".";
   }
 
@@ -489,11 +489,11 @@
 
   // Base must be a pointer, pointing to the base of a composite object.
   const auto base_index = 2;
-  const auto base_id = inst.GetOperandAs<uint32_t>(base_index);
+  const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
   const auto base = _.FindDef(base_id);
   const auto base_type = _.FindDef(base->type_id());
   if (!base_type || SpvOpTypePointer != base_type->opcode()) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "The Base <id> '" << _.getIdName(base_id) << "' in " << instr_name
            << " instruction must be a pointer.";
   }
@@ -503,7 +503,7 @@
   auto result_type_storage_class = result_type->word(2);
   auto base_type_storage_class = base_type->word(2);
   if (result_type_storage_class != base_type_storage_class) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "The result pointer storage class and base "
               "pointer storage class in "
            << instr_name << " do not match.";
@@ -515,9 +515,9 @@
   // Check Universal Limit (SPIR-V Spec. Section 2.17).
   // The number of indexes passed to OpAccessChain may not exceed 255
   // The instruction includes 4 words + N words (for N indexes)
-  size_t num_indexes = inst.words().size() - 4;
-  if (inst.opcode() == SpvOpPtrAccessChain ||
-      inst.opcode() == SpvOpInBoundsPtrAccessChain) {
+  size_t num_indexes = inst->words().size() - 4;
+  if (inst->opcode() == SpvOpPtrAccessChain ||
+      inst->opcode() == SpvOpInBoundsPtrAccessChain) {
     // In pointer access chains, the element operand is required, but not
     // counted as an index.
     --num_indexes;
@@ -525,7 +525,7 @@
   const size_t num_indexes_limit =
       _.options()->universal_limits_.max_access_chain_indexes;
   if (num_indexes > num_indexes_limit) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << "The number of indexes in " << instr_name << " may not exceed "
            << num_indexes_limit << ". Found " << num_indexes << " indexes.";
   }
@@ -537,18 +537,18 @@
   // on. Once any non-composite type is reached, there must be no remaining
   // (unused) indexes.
   auto starting_index = 4;
-  if (inst.opcode() == SpvOpPtrAccessChain ||
-      inst.opcode() == SpvOpInBoundsPtrAccessChain) {
+  if (inst->opcode() == SpvOpPtrAccessChain ||
+      inst->opcode() == SpvOpInBoundsPtrAccessChain) {
     ++starting_index;
   }
-  for (size_t i = starting_index; i < inst.words().size(); ++i) {
-    const uint32_t cur_word = inst.words()[i];
+  for (size_t i = starting_index; i < inst->words().size(); ++i) {
+    const uint32_t cur_word = inst->words()[i];
     // Earlier ID checks ensure that cur_word definition exists.
     auto cur_word_instr = _.FindDef(cur_word);
     // The index must be a scalar integer type (See OpAccessChain in the Spec.)
     auto index_type = _.FindDef(cur_word_instr->type_id());
     if (!index_type || SpvOpTypeInt != index_type->opcode()) {
-      return _.diag(SPV_ERROR_INVALID_ID, &inst)
+      return _.diag(SPV_ERROR_INVALID_ID, inst)
              << "Indexes passed to " << instr_name
              << " must be of type integer.";
     }
@@ -607,7 +607,7 @@
   // At this point, we have fully walked down from the base using the indeces.
   // The type being pointed to should be the same as the result type.
   if (type_pointee->id() != result_type_pointee->id()) {
-    return _.diag(SPV_ERROR_INVALID_ID, &inst)
+    return _.diag(SPV_ERROR_INVALID_ID, inst)
            << instr_name << " result type (Op"
            << spvOpcodeString(static_cast<SpvOp>(result_type_pointee->opcode()))
            << ") does not match the type that results from indexing into the "
@@ -622,27 +622,26 @@
 
 }  // namespace
 
-spv_result_t ValidateMemoryInstructions(ValidationState_t& _,
-                                        const Instruction* inst) {
+spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
   switch (inst->opcode()) {
     case SpvOpVariable:
-      if (auto error = ValidateVariable(_, *inst)) return error;
+      if (auto error = ValidateVariable(_, inst)) return error;
       break;
     case SpvOpLoad:
-      if (auto error = ValidateLoad(_, *inst)) return error;
+      if (auto error = ValidateLoad(_, inst)) return error;
       break;
     case SpvOpStore:
-      if (auto error = ValidateStore(_, *inst)) return error;
+      if (auto error = ValidateStore(_, inst)) return error;
       break;
     case SpvOpCopyMemory:
     case SpvOpCopyMemorySized:
-      if (auto error = ValidateCopyMemory(_, *inst)) return error;
+      if (auto error = ValidateCopyMemory(_, inst)) return error;
       break;
     case SpvOpAccessChain:
     case SpvOpInBoundsAccessChain:
     case SpvOpPtrAccessChain:
     case SpvOpInBoundsPtrAccessChain:
-      if (auto error = ValidateAccessChain(_, *inst)) return error;
+      if (auto error = ValidateAccessChain(_, inst)) return error;
       break;
     case SpvOpImageTexelPointer:
     case SpvOpArrayLength: