Merge "Impose more discipline on uses of ModelArgumentInfo"
diff --git a/nn/runtime/ExecutionBuilder.cpp b/nn/runtime/ExecutionBuilder.cpp
index 4ab401e..6e8d7e8 100644
--- a/nn/runtime/ExecutionBuilder.cpp
+++ b/nn/runtime/ExecutionBuilder.cpp
@@ -125,8 +125,15 @@
         return ANEURALNETWORKS_BAD_DATA;
     }
     uint32_t l = static_cast<uint32_t>(length);
-    return mInputs[index].setFromPointer(mModel->getInputOperand(index), type,
-                                         const_cast<void*>(buffer), l);
+    if (!mInputs[index].unspecified()) {
+        LOG(ERROR) << "ANeuralNetworksExecution_setInput called when an input has already been "
+                      "provided";
+        return ANEURALNETWORKS_BAD_STATE;
+    }
+    int n;
+    std::tie(n, mInputs[index]) = ModelArgumentInfo::createFromPointer(
+            mModel->getInputOperand(index), type, const_cast<void*>(buffer), l);
+    return n;
 }
 
 int ExecutionBuilder::setInputFromMemory(uint32_t index, const ANeuralNetworksOperandType* type,
@@ -162,8 +169,16 @@
     }
     // TODO validate the rest
     uint32_t poolIndex = mMemories.add(memory);
-    return mInputs[index].setFromMemory(mModel->getInputOperand(index), type, poolIndex, offset,
-                                        length);
+    if (!mInputs[index].unspecified()) {
+        LOG(ERROR)
+                << "ANeuralNetworksExecution_setInputFromMemory called when an input has already "
+                   "been provided";
+        return ANEURALNETWORKS_BAD_STATE;
+    }
+    int n;
+    std::tie(n, mInputs[index]) = ModelArgumentInfo::createFromMemory(
+            mModel->getInputOperand(index), type, poolIndex, offset, length);
+    return n;
 }
 
 int ExecutionBuilder::setOutput(uint32_t index, const ANeuralNetworksOperandType* type,
@@ -187,7 +202,15 @@
         return ANEURALNETWORKS_BAD_DATA;
     }
     uint32_t l = static_cast<uint32_t>(length);
-    return mOutputs[index].setFromPointer(mModel->getOutputOperand(index), type, buffer, l);
+    if (!mOutputs[index].unspecified()) {
+        LOG(ERROR) << "ANeuralNetworksExecution_setOutput called when an output has already been "
+                      "provided";
+        return ANEURALNETWORKS_BAD_STATE;
+    }
+    int n;
+    std::tie(n, mOutputs[index]) =
+            ModelArgumentInfo::createFromPointer(mModel->getOutputOperand(index), type, buffer, l);
+    return n;
 }
 
 int ExecutionBuilder::setOutputFromMemory(uint32_t index, const ANeuralNetworksOperandType* type,
@@ -223,8 +246,15 @@
     }
     // TODO validate the rest
     uint32_t poolIndex = mMemories.add(memory);
-    return mOutputs[index].setFromMemory(mModel->getOutputOperand(index), type, poolIndex, offset,
-                                         length);
+    if (!mOutputs[index].unspecified()) {
+        LOG(ERROR) << "ANeuralNetworksExecution_setOutputFromMemory called when an output has "
+                      "already been provided";
+        return ANEURALNETWORKS_BAD_STATE;
+    }
+    int n;
+    std::tie(n, mOutputs[index]) = ModelArgumentInfo::createFromMemory(
+            mModel->getOutputOperand(index), type, poolIndex, offset, length);
+    return n;
 }
 
 int ExecutionBuilder::setMeasureTiming(bool measure) {
@@ -374,15 +404,15 @@
                    << " " << count;
         return ANEURALNETWORKS_BAD_DATA;
     }
-    const auto& dims = mOutputs[index].dimensions;
+    const auto& dims = mOutputs[index].dimensions();
     if (dims.empty()) {
         LOG(ERROR) << "ANeuralNetworksExecution_getOutputOperandDimensions can not query "
                       "dimensions of a scalar";
         return ANEURALNETWORKS_BAD_DATA;
     }
     std::copy(dims.begin(), dims.end(), dimensions);
-    return mOutputs[index].isSufficient ? ANEURALNETWORKS_NO_ERROR
-                                        : ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE;
+    return mOutputs[index].isSufficient() ? ANEURALNETWORKS_NO_ERROR
+                                          : ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE;
 }
 
 int ExecutionBuilder::getOutputOperandRank(uint32_t index, uint32_t* rank) {
@@ -405,9 +435,9 @@
                    << count;
         return ANEURALNETWORKS_BAD_DATA;
     }
-    *rank = static_cast<uint32_t>(mOutputs[index].dimensions.size());
-    return mOutputs[index].isSufficient ? ANEURALNETWORKS_NO_ERROR
-                                        : ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE;
+    *rank = static_cast<uint32_t>(mOutputs[index].dimensions().size());
+    return mOutputs[index].isSufficient() ? ANEURALNETWORKS_NO_ERROR
+                                          : ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE;
 }
 
 // Attempt synchronous execution of full model on CPU.
@@ -705,21 +735,21 @@
     }
     const auto deadline = makeDeadline(mTimeoutDuration);
     for (auto& p : mInputs) {
-        if (p.state == ModelArgumentInfo::UNSPECIFIED) {
+        if (p.state() == ModelArgumentInfo::UNSPECIFIED) {
             LOG(ERROR) << "ANeuralNetworksExecution_startComputeWithDependencies"
                           " not all inputs specified";
             return ANEURALNETWORKS_BAD_DATA;
         }
     }
     for (auto& p : mOutputs) {
-        if (p.state == ModelArgumentInfo::UNSPECIFIED) {
+        if (p.state() == ModelArgumentInfo::UNSPECIFIED) {
             LOG(ERROR) << "ANeuralNetworksExecution_startComputeWithDependencies"
                           " not all outputs specified";
             return ANEURALNETWORKS_BAD_DATA;
         }
     }
     for (uint32_t i = 0; i < mOutputs.size(); i++) {
-        if (mOutputs[i].state != ModelArgumentInfo::HAS_NO_VALUE &&
+        if (mOutputs[i].state() != ModelArgumentInfo::HAS_NO_VALUE &&
             !checkDimensionInfo(mModel->getOutputOperand(i), nullptr,
                                 "ANeuralNetworksExecution_startComputeWithDependencies", false)) {
             LOG(ERROR) << "ANeuralNetworksExecution_startComputeWithDependencies"
@@ -762,18 +792,18 @@
         return ANEURALNETWORKS_BAD_STATE;
     }
     for (auto& p : mInputs) {
-        if (p.state == ModelArgumentInfo::UNSPECIFIED) {
+        if (p.state() == ModelArgumentInfo::UNSPECIFIED) {
             LOG(ERROR) << "ANeuralNetworksExecution_" << name() << " not all inputs specified";
             return ANEURALNETWORKS_BAD_DATA;
-        } else if (p.state == ModelArgumentInfo::MEMORY) {
-            const Memory* memory = mMemories[p.locationAndLength.poolIndex];
-            if (!memory->getValidator().validateInputDimensions(p.dimensions)) {
+        } else if (p.state() == ModelArgumentInfo::MEMORY) {
+            const Memory* memory = mMemories[p.locationAndLength().poolIndex];
+            if (!memory->getValidator().validateInputDimensions(p.dimensions())) {
                 return ANEURALNETWORKS_OP_FAILED;
             }
         }
     }
     for (auto& p : mOutputs) {
-        if (p.state == ModelArgumentInfo::UNSPECIFIED) {
+        if (p.state() == ModelArgumentInfo::UNSPECIFIED) {
             LOG(ERROR) << "ANeuralNetworksExecution_" << name() << " not all outputs specified";
             return ANEURALNETWORKS_BAD_DATA;
         }
@@ -835,7 +865,7 @@
     std::vector<OutputShape> outputShapes(mOutputs.size());
     std::transform(mOutputs.begin(), mOutputs.end(), outputShapes.begin(),
                    [](const auto& x) -> OutputShape {
-                       return {.dimensions = x.dimensions, .isSufficient = true};
+                       return {.dimensions = x.dimensions(), .isSufficient = true};
                    });
     return outputShapes;
 }
@@ -858,20 +888,20 @@
     NN_RET_CHECK_EQ(outputShapes.size(), mOutputs.size());
     for (uint32_t i = 0; i < outputShapes.size(); i++) {
         // Check if only unspecified dimensions or rank are overwritten.
-        NN_RET_CHECK(isUpdatable(mOutputs[i].dimensions, outputShapes[i].dimensions));
+        NN_RET_CHECK(isUpdatable(mOutputs[i].dimensions(), outputShapes[i].dimensions));
     }
     for (uint32_t i = 0; i < outputShapes.size(); i++) {
-        mOutputs[i].dimensions = outputShapes[i].dimensions;
-        mOutputs[i].isSufficient = outputShapes[i].isSufficient;
+        mOutputs[i].dimensions() = outputShapes[i].dimensions;
+        mOutputs[i].isSufficient() = outputShapes[i].isSufficient;
     }
     return true;
 }
 
 bool ExecutionBuilder::updateMemories() {
     for (const auto& output : mOutputs) {
-        if (output.state != ModelArgumentInfo::MEMORY) continue;
-        const Memory* memory = mMemories[output.locationAndLength.poolIndex];
-        NN_RET_CHECK(memory->getValidator().updateMetadata({.dimensions = output.dimensions}));
+        if (output.state() != ModelArgumentInfo::MEMORY) continue;
+        const Memory* memory = mMemories[output.locationAndLength().poolIndex];
+        NN_RET_CHECK(memory->getValidator().updateMetadata({.dimensions = output.dimensions()}));
     }
     return true;
 }
@@ -885,8 +915,8 @@
     }
     bool success = status == ErrorStatus::NONE;
     for (const auto& output : mOutputs) {
-        if (output.state != ModelArgumentInfo::MEMORY) continue;
-        const Memory* memory = mMemories[output.locationAndLength.poolIndex];
+        if (output.state() != ModelArgumentInfo::MEMORY) continue;
+        const Memory* memory = mMemories[output.locationAndLength().poolIndex];
         memory->getValidator().setInitialized(success);
     }
     return status;
@@ -940,7 +970,7 @@
 void StepExecutor::mapInputOrOutput(const ModelArgumentInfo& builderInputOrOutput,
                                     ModelArgumentInfo* executorInputOrOutput) {
     *executorInputOrOutput = builderInputOrOutput;
-    switch (executorInputOrOutput->state) {
+    switch (executorInputOrOutput->state()) {
         default:
             CHECK(false) << "unexpected ModelArgumentInfo::state";
             break;
@@ -949,10 +979,10 @@
         case ModelArgumentInfo::UNSPECIFIED:
             break;
         case ModelArgumentInfo::MEMORY: {
-            const uint32_t builderPoolIndex = builderInputOrOutput.locationAndLength.poolIndex;
+            const uint32_t builderPoolIndex = builderInputOrOutput.locationAndLength().poolIndex;
             const Memory* memory = mExecutionBuilder->mMemories[builderPoolIndex];
             const uint32_t executorPoolIndex = mMemories.add(memory);
-            executorInputOrOutput->locationAndLength.poolIndex = executorPoolIndex;
+            executorInputOrOutput->locationAndLength().poolIndex = executorPoolIndex;
             break;
         }
     }
@@ -967,22 +997,26 @@
 
     uint32_t poolIndex = mMemories.add(memory);
     uint32_t length = TypeManager::get()->getSizeOfData(inputOrOutputOperand);
-    return inputOrOutputInfo->setFromMemory(inputOrOutputOperand, /*type=*/nullptr, poolIndex,
-                                            offset, length);
+    CHECK(inputOrOutputInfo->unspecified());
+    int n;
+    std::tie(n, *inputOrOutputInfo) =
+            ModelArgumentInfo::createFromMemory(inputOrOutputOperand,
+                                                /*type=*/nullptr, poolIndex, offset, length);
+    return n;
 }
 
 static void logArguments(const char* kind, const std::vector<ModelArgumentInfo>& args) {
     for (unsigned i = 0; i < args.size(); i++) {
         const auto& arg = args[i];
         std::string prefix = kind + std::string("[") + std::to_string(i) + "] = ";
-        switch (arg.state) {
+        switch (arg.state()) {
             case ModelArgumentInfo::POINTER:
-                VLOG(EXECUTION) << prefix << "POINTER(" << SHOW_IF_DEBUG(arg.buffer) << ")";
+                VLOG(EXECUTION) << prefix << "POINTER(" << SHOW_IF_DEBUG(arg.buffer()) << ")";
                 break;
             case ModelArgumentInfo::MEMORY:
                 VLOG(EXECUTION) << prefix << "MEMORY("
-                                << "pool=" << arg.locationAndLength.poolIndex << ", "
-                                << "off=" << arg.locationAndLength.offset << ")";
+                                << "pool=" << arg.locationAndLength().poolIndex << ", "
+                                << "off=" << arg.locationAndLength().offset << ")";
                 break;
             case ModelArgumentInfo::HAS_NO_VALUE:
                 VLOG(EXECUTION) << prefix << "HAS_NO_VALUE";
@@ -991,7 +1025,7 @@
                 VLOG(EXECUTION) << prefix << "UNSPECIFIED";
                 break;
             default:
-                VLOG(EXECUTION) << prefix << "state(" << arg.state << ")";
+                VLOG(EXECUTION) << prefix << "state(" << arg.state() << ")";
                 break;
         }
     }
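
The net client-visible change from the ExecutionBuilder.cpp hunks above is that each input or output index may be bound at most once per execution: a second ANeuralNetworksExecution_set{Input,Output}[FromMemory] call on the same index now returns ANEURALNETWORKS_BAD_STATE instead of silently overwriting the earlier argument (matching the new TestValidation cases at the end of this change). A minimal sketch against the public NDK API, assuming `execution` is a valid execution whose input 0 is an 8-byte (two-float) tensor:

```cpp
#include <android/NeuralNetworks.h>

int bindInputZero(ANeuralNetworksExecution* execution, const float buffer[2]) {
    // First binding of index 0 succeeds as before.
    int n = ANeuralNetworksExecution_setInput(execution, 0, /*type=*/nullptr,
                                              buffer, 2 * sizeof(float));
    if (n != ANEURALNETWORKS_NO_ERROR) return n;

    // Re-binding the same index used to overwrite the argument in place;
    // with this change it is rejected.
    n = ANeuralNetworksExecution_setInput(execution, 0, /*type=*/nullptr,
                                          buffer, 2 * sizeof(float));
    return n;  // ANEURALNETWORKS_BAD_STATE
}
```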
diff --git a/nn/runtime/ExecutionPlan.cpp b/nn/runtime/ExecutionPlan.cpp
index f19065b..91ef6b1 100644
--- a/nn/runtime/ExecutionPlan.cpp
+++ b/nn/runtime/ExecutionPlan.cpp
@@ -1003,14 +1003,14 @@
 
 std::optional<ExecutionPlan::Buffer> ExecutionPlan::getBufferFromModelArgumentInfo(
         const ModelArgumentInfo& info, const ExecutionBuilder* executionBuilder) const {
-    switch (info.state) {
+    switch (info.state()) {
         case ModelArgumentInfo::POINTER: {
-            return Buffer(info.buffer, info.locationAndLength.length);
+            return Buffer(info.buffer(), info.length());
         } break;
         case ModelArgumentInfo::MEMORY: {
             if (std::optional<RunTimePoolInfo> poolInfo =
-                        executionBuilder->getRunTimePoolInfo(info.locationAndLength.poolIndex)) {
-                return Buffer(*poolInfo, info.locationAndLength.offset);
+                        executionBuilder->getRunTimePoolInfo(info.locationAndLength().poolIndex)) {
+                return Buffer(*poolInfo, info.locationAndLength().offset);
             } else {
                 LOG(ERROR) << "Unable to map operand memory pool";
                 return std::nullopt;
@@ -1021,7 +1021,7 @@
             return std::nullopt;
         } break;
         default: {
-            LOG(ERROR) << "Unexpected operand memory state: " << static_cast<int>(info.state);
+            LOG(ERROR) << "Unexpected operand memory state: " << static_cast<int>(info.state());
             return std::nullopt;
         } break;
     }
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index 1592a3d..96bdccc 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -282,14 +282,13 @@
     const uint32_t nextPoolIndex = memories->size();
     int64_t total = 0;
     for (const auto& info : args) {
-        if (info.state == ModelArgumentInfo::POINTER) {
-            const DataLocation& loc = info.locationAndLength;
+        if (info.state() == ModelArgumentInfo::POINTER) {
             // TODO Good enough alignment?
-            total += alignBytesNeeded(static_cast<uint32_t>(total), loc.length);
+            total += alignBytesNeeded(static_cast<uint32_t>(total), info.length());
             ptrArgsLocations.push_back({.poolIndex = nextPoolIndex,
                                         .offset = static_cast<uint32_t>(total),
-                                        .length = loc.length});
-            total += loc.length;
+                                        .length = info.length()});
+            total += info.length();
         }
     };
     if (total > 0xFFFFFFFF) {
@@ -348,10 +347,10 @@
     if (inputPtrArgsMemory != nullptr) {
         uint32_t ptrInputIndex = 0;
         for (const auto& info : inputs) {
-            if (info.state == ModelArgumentInfo::POINTER) {
+            if (info.state() == ModelArgumentInfo::POINTER) {
                 const DataLocation& loc = inputPtrArgsLocations[ptrInputIndex++];
                 uint8_t* const data = inputPtrArgsMemory->getPointer();
-                memcpy(data + loc.offset, info.buffer, loc.length);
+                memcpy(data + loc.offset, info.buffer(), loc.length);
             }
         }
     }
@@ -412,10 +411,10 @@
     if (outputPtrArgsMemory != nullptr) {
         uint32_t ptrOutputIndex = 0;
         for (const auto& info : outputs) {
-            if (info.state == ModelArgumentInfo::POINTER) {
+            if (info.state() == ModelArgumentInfo::POINTER) {
                 const DataLocation& loc = outputPtrArgsLocations[ptrOutputIndex++];
                 const uint8_t* const data = outputPtrArgsMemory->getPointer();
-                memcpy(info.buffer, data + loc.offset, loc.length);
+                memcpy(info.buffer(), data + loc.offset, loc.length);
             }
         }
     }
@@ -457,10 +456,10 @@
     if (inputPtrArgsMemory != nullptr) {
         uint32_t ptrInputIndex = 0;
         for (const auto& info : inputs) {
-            if (info.state == ModelArgumentInfo::POINTER) {
+            if (info.state() == ModelArgumentInfo::POINTER) {
                 const DataLocation& loc = inputPtrArgsLocations[ptrInputIndex++];
                 uint8_t* const data = inputPtrArgsMemory->getPointer();
-                memcpy(data + loc.offset, info.buffer, loc.length);
+                memcpy(data + loc.offset, info.buffer(), loc.length);
             }
         }
     }
@@ -528,10 +527,10 @@
         }
         uint32_t ptrOutputIndex = 0;
         for (const auto& info : outputs) {
-            if (info.state == ModelArgumentInfo::POINTER) {
+            if (info.state() == ModelArgumentInfo::POINTER) {
                 const DataLocation& loc = outputPtrArgsLocations[ptrOutputIndex++];
                 const uint8_t* const data = outputPtrArgsMemory->getPointer();
-                memcpy(info.buffer, data + loc.offset, loc.length);
+                memcpy(info.buffer(), data + loc.offset, loc.length);
             }
         }
     }
@@ -768,13 +767,13 @@
             [&requestPoolInfos](const std::vector<ModelArgumentInfo>& argumentInfos) {
                 std::vector<DataLocation> ptrArgsLocations;
                 for (const ModelArgumentInfo& argumentInfo : argumentInfos) {
-                    if (argumentInfo.state == ModelArgumentInfo::POINTER) {
+                    if (argumentInfo.state() == ModelArgumentInfo::POINTER) {
                         ptrArgsLocations.push_back(
                                 {.poolIndex = static_cast<uint32_t>(requestPoolInfos.size()),
                                  .offset = 0,
-                                 .length = argumentInfo.locationAndLength.length});
+                                 .length = argumentInfo.length()});
                         requestPoolInfos.emplace_back(RunTimePoolInfo::createFromExistingBuffer(
-                                static_cast<uint8_t*>(argumentInfo.buffer)));
+                                static_cast<uint8_t*>(argumentInfo.buffer())));
                     }
                 }
                 return ptrArgsLocations;
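
The Manager.cpp hunks only swap direct field access for the new accessors; the underlying scheme is unchanged: POINTER arguments are packed into one shared pool, each at an aligned offset recorded in a DataLocation, and memcpy'd into the pool before the request (inputs) and back out afterwards (outputs). An illustrative sketch of that packing, with a local `Location` struct and a fixed alignment standing in for `DataLocation` and `alignBytesNeeded()` (both are assumptions, not the runtime's exact policy):

```cpp
#include <cstdint>
#include <vector>

struct Location {
    uint32_t poolIndex;
    uint32_t offset;
    uint32_t length;
};

// Assign each pointer argument an aligned offset inside a single pool and
// report the total pool size; the caller then copies each argument to
// pool[offset] before execution and back after it.
std::vector<Location> packPointerArgs(const std::vector<uint32_t>& lengths,
                                      uint32_t poolIndex, uint32_t* poolSize) {
    constexpr uint32_t kAlign = 8;  // assumed alignment
    std::vector<Location> locations;
    uint32_t total = 0;
    for (uint32_t length : lengths) {
        total += (kAlign - total % kAlign) % kAlign;      // pad up to the next boundary
        locations.push_back({poolIndex, total, length});  // where this argument will live
        total += length;
    }
    *poolSize = total;
    return locations;
}
```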
diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h
index 0ece99b..4dac086 100644
--- a/nn/runtime/Manager.h
+++ b/nn/runtime/Manager.h
@@ -39,8 +39,8 @@
 class Device;
 class ExecutionBurstController;
 class MetaModel;
+class ModelArgumentInfo;
 class VersionedIPreparedModel;
-struct ModelArgumentInfo;
 
 // A unified interface for actual driver prepared model as well as the CPU.
 class PreparedModel {
diff --git a/nn/runtime/ModelArgumentInfo.cpp b/nn/runtime/ModelArgumentInfo.cpp
index f8ddbfe..cf24004 100644
--- a/nn/runtime/ModelArgumentInfo.cpp
+++ b/nn/runtime/ModelArgumentInfo.cpp
@@ -19,6 +19,7 @@
 #include "ModelArgumentInfo.h"
 
 #include <algorithm>
+#include <utility>
 #include <vector>
 
 #include "HalInterfaces.h"
@@ -31,61 +32,74 @@
 
 using namespace hal;
 
-int ModelArgumentInfo::setFromPointer(const Operand& operand,
-                                      const ANeuralNetworksOperandType* type, void* data,
-                                      uint32_t length) {
+static const std::pair<int, ModelArgumentInfo> kBadDataModelArgumentInfo{ANEURALNETWORKS_BAD_DATA,
+                                                                         {}};
+
+std::pair<int, ModelArgumentInfo> ModelArgumentInfo::createFromPointer(
+        const Operand& operand, const ANeuralNetworksOperandType* type, void* data,
+        uint32_t length) {
     if ((data == nullptr) != (length == 0)) {
         const char* dataPtrMsg = data ? "NOT_NULLPTR" : "NULLPTR";
         LOG(ERROR) << "Data pointer must be nullptr if and only if length is zero (data = "
                    << dataPtrMsg << ", length = " << length << ")";
-        return ANEURALNETWORKS_BAD_DATA;
+        return kBadDataModelArgumentInfo;
     }
+
+    ModelArgumentInfo ret;
     if (data == nullptr) {
-        state = ModelArgumentInfo::HAS_NO_VALUE;
+        ret.mState = ModelArgumentInfo::HAS_NO_VALUE;
     } else {
-        NN_RETURN_IF_ERROR(updateDimensionInfo(operand, type));
+        if (int n = ret.updateDimensionInfo(operand, type)) {
+            return {n, ModelArgumentInfo()};
+        }
         if (operand.type != OperandType::OEM) {
-            uint32_t neededLength = TypeManager::get()->getSizeOfData(operand.type, dimensions);
+            uint32_t neededLength =
+                    TypeManager::get()->getSizeOfData(operand.type, ret.mDimensions);
             if (neededLength != length && neededLength != 0) {
                 LOG(ERROR) << "Setting argument with invalid length: " << length
                            << ", expected length: " << neededLength;
-                return ANEURALNETWORKS_BAD_DATA;
+                return kBadDataModelArgumentInfo;
             }
         }
-        state = ModelArgumentInfo::POINTER;
+        ret.mState = ModelArgumentInfo::POINTER;
     }
-    buffer = data;
-    locationAndLength = {.poolIndex = 0, .offset = 0, .length = length};
-    return ANEURALNETWORKS_NO_ERROR;
+    ret.mBuffer = data;
+    ret.mLocationAndLength = {.poolIndex = 0, .offset = 0, .length = length};
+    return {ANEURALNETWORKS_NO_ERROR, ret};
 }
 
-int ModelArgumentInfo::setFromMemory(const Operand& operand, const ANeuralNetworksOperandType* type,
-                                     uint32_t poolIndex, uint32_t offset, uint32_t length) {
-    NN_RETURN_IF_ERROR(updateDimensionInfo(operand, type));
+std::pair<int, ModelArgumentInfo> ModelArgumentInfo::createFromMemory(
+        const Operand& operand, const ANeuralNetworksOperandType* type, uint32_t poolIndex,
+        uint32_t offset, uint32_t length) {
+    ModelArgumentInfo ret;
+    if (int n = ret.updateDimensionInfo(operand, type)) {
+        return {n, ModelArgumentInfo()};
+    }
     const bool isMemorySizeKnown = offset != 0 || length != 0;
     if (isMemorySizeKnown && operand.type != OperandType::OEM) {
-        const uint32_t neededLength = TypeManager::get()->getSizeOfData(operand.type, dimensions);
+        const uint32_t neededLength =
+                TypeManager::get()->getSizeOfData(operand.type, ret.mDimensions);
         if (neededLength != length && neededLength != 0) {
             LOG(ERROR) << "Setting argument with invalid length: " << length
                        << " (offset: " << offset << "), expected length: " << neededLength;
-            return ANEURALNETWORKS_BAD_DATA;
+            return kBadDataModelArgumentInfo;
         }
     }
 
-    state = ModelArgumentInfo::MEMORY;
-    locationAndLength = {.poolIndex = poolIndex, .offset = offset, .length = length};
-    buffer = nullptr;
-    return ANEURALNETWORKS_NO_ERROR;
+    ret.mState = ModelArgumentInfo::MEMORY;
+    ret.mLocationAndLength = {.poolIndex = poolIndex, .offset = offset, .length = length};
+    ret.mBuffer = nullptr;
+    return {ANEURALNETWORKS_NO_ERROR, ret};
 }
 
 int ModelArgumentInfo::updateDimensionInfo(const Operand& operand,
                                            const ANeuralNetworksOperandType* newType) {
     if (newType == nullptr) {
-        dimensions = operand.dimensions;
+        mDimensions = operand.dimensions;
     } else {
         const uint32_t count = newType->dimensionCount;
-        dimensions = hidl_vec<uint32_t>(count);
-        std::copy(&newType->dimensions[0], &newType->dimensions[count], dimensions.begin());
+        mDimensions = hidl_vec<uint32_t>(count);
+        std::copy(&newType->dimensions[0], &newType->dimensions[count], mDimensions.begin());
     }
     return ANEURALNETWORKS_NO_ERROR;
 }
@@ -98,12 +112,22 @@
     uint32_t ptrArgsIndex = 0;
     for (size_t i = 0; i < count; i++) {
         const auto& info = argumentInfos[i];
-        ioInfos[i] = {
-                .hasNoValue = info.state == ModelArgumentInfo::HAS_NO_VALUE,
-                .location = info.state == ModelArgumentInfo::POINTER
-                                    ? ptrArgsLocations[ptrArgsIndex++]
-                                    : info.locationAndLength,
-                .dimensions = info.dimensions,
+        switch (info.state()) {
+            case ModelArgumentInfo::POINTER:
+                ioInfos[i] = {.hasNoValue = false,
+                              .location = ptrArgsLocations[ptrArgsIndex++],
+                              .dimensions = info.dimensions()};
+                break;
+            case ModelArgumentInfo::MEMORY:
+                ioInfos[i] = {.hasNoValue = false,
+                              .location = info.locationAndLength(),
+                              .dimensions = info.dimensions()};
+                break;
+            case ModelArgumentInfo::HAS_NO_VALUE:
+                ioInfos[i] = {.hasNoValue = true};
+                break;
+            default:
+                CHECK(false);
         };
     }
     return ioInfos;
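
The createFrom* factories above replace the old two-step "default-construct, then setFrom*" pattern: callers receive a (status, value) pair and unpack it with std::tie, so a failed creation never leaves a half-initialized argument behind and the status code is propagated unchanged. A toy sketch of that convention with hypothetical `Widget`/`makeWidget` names (not part of the runtime):

```cpp
#include <tuple>
#include <utility>

struct Widget {
    int value = 0;
};

// Factory returns a status plus the constructed object; on error the payload
// is a default-constructed Widget that the caller should not use.
std::pair<int, Widget> makeWidget(int input) {
    if (input < 0) return {/*error status*/ 1, {}};
    return {/*no error*/ 0, Widget{input}};
}

int caller(int input, Widget* out) {
    int n;
    std::tie(n, *out) = makeWidget(input);  // mirrors std::tie(n, mInputs[index]) = create...
    return n;                               // propagate the status unchanged
}
```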
diff --git a/nn/runtime/ModelArgumentInfo.h b/nn/runtime/ModelArgumentInfo.h
index 61dfc1e..22dd34c 100644
--- a/nn/runtime/ModelArgumentInfo.h
+++ b/nn/runtime/ModelArgumentInfo.h
@@ -17,36 +17,93 @@
 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MODEL_ARGUMENT_INFO_H
 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MODEL_ARGUMENT_INFO_H
 
+#include <utility>
 #include <vector>
 
 #include "HalInterfaces.h"
 #include "NeuralNetworks.h"
+#include "Utils.h"
 
 namespace android {
 namespace nn {
 
 // TODO move length out of DataLocation
-struct ModelArgumentInfo {
+//
+// NOTE: The primary usage pattern is that a ModelArgumentInfo instance
+//       is not modified once it is created (unless it is reassigned to).
+//       There are a small number of use cases where it NEEDS to be modified,
+//       and we have a limited number of methods that support this.
+class ModelArgumentInfo {
+   public:
+    ModelArgumentInfo() {}
+
+    static std::pair<int, ModelArgumentInfo> createFromPointer(
+            const hal::Operand& operand, const ANeuralNetworksOperandType* type,
+            void* data /* nullptr means HAS_NO_VALUE */, uint32_t length);
+    static std::pair<int, ModelArgumentInfo> createFromMemory(
+            const hal::Operand& operand, const ANeuralNetworksOperandType* type, uint32_t poolIndex,
+            uint32_t offset, uint32_t length);
+
+    enum State { POINTER, MEMORY, HAS_NO_VALUE, UNSPECIFIED };
+
+    State state() const { return mState; }
+
+    bool unspecified() const { return mState == UNSPECIFIED; }
+
+    void* buffer() const {
+        CHECK_EQ(mState, POINTER);
+        return mBuffer;
+    }
+
+    const std::vector<uint32_t>& dimensions() const {
+        CHECK(mState == POINTER || mState == MEMORY);
+        return mDimensions;
+    }
+    std::vector<uint32_t>& dimensions() {
+        CHECK(mState == POINTER || mState == MEMORY);
+        return mDimensions;
+    }
+
+    bool isSufficient() const {
+        CHECK(mState == POINTER || mState == MEMORY);
+        return mIsSufficient;
+    }
+    bool& isSufficient() {
+        CHECK(mState == POINTER || mState == MEMORY);
+        return mIsSufficient;
+    }
+
+    uint32_t length() const {
+        CHECK(mState == POINTER || mState == MEMORY);
+        return mLocationAndLength.length;
+    }
+
+    const hal::DataLocation& locationAndLength() const {
+        CHECK_EQ(mState, MEMORY);
+        return mLocationAndLength;
+    }
+    hal::DataLocation& locationAndLength() {
+        CHECK_EQ(mState, MEMORY);
+        return mLocationAndLength;
+    }
+
+   private:
+    int updateDimensionInfo(const hal::Operand& operand, const ANeuralNetworksOperandType* newType);
+
     // Whether the argument was specified as being in a Memory, as a pointer,
     // has no value, or has not been specified.
     // If POINTER then:
-    //   locationAndLength.length is valid.
-    //   dimensions is valid.
-    //   buffer is valid
+    //   mLocationAndLength.length is valid.
+    //   mDimensions is valid.
+    //   mBuffer is valid.
     // If MEMORY then:
-    //   locationAndLength.{poolIndex, offset, length} is valid.
-    //   dimensions is valid.
-    enum { POINTER, MEMORY, HAS_NO_VALUE, UNSPECIFIED } state = UNSPECIFIED;
-    hal::DataLocation locationAndLength;
-    std::vector<uint32_t> dimensions;
-    void* buffer;
-    bool isSufficient = true;
-
-    int setFromPointer(const hal::Operand& operand, const ANeuralNetworksOperandType* type,
-                       void* buffer, uint32_t length);
-    int setFromMemory(const hal::Operand& operand, const ANeuralNetworksOperandType* type,
-                      uint32_t poolIndex, uint32_t offset, uint32_t length);
-    int updateDimensionInfo(const hal::Operand& operand, const ANeuralNetworksOperandType* newType);
+    //   mLocationAndLength.{poolIndex, offset, length} is valid.
+    //   mDimensions is valid.
+    State mState = UNSPECIFIED;            // fixed at creation
+    void* mBuffer = nullptr;               // fixed at creation
+    hal::DataLocation mLocationAndLength;  // can be updated after creation
+    std::vector<uint32_t> mDimensions;     // can be updated after creation
+    bool mIsSufficient = true;             // can be updated after creation
 };
 
 // Convert ModelArgumentInfo to HIDL RequestArgument. For pointer arguments, use the location
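
The new header trades struct-style field access for CHECK()-guarded getters, so reading a field that is meaningless for the current state aborts immediately instead of handing back stale data. A toy stand-in (not the real class) showing the shape of that discipline, using assert() in place of the runtime's CHECK():

```cpp
#include <cassert>

class Arg {  // hypothetical stand-in for ModelArgumentInfo
   public:
    enum State { POINTER, MEMORY, HAS_NO_VALUE, UNSPECIFIED };

    static Arg fromPointer(void* data) { return Arg(POINTER, data); }

    State state() const { return mState; }  // always safe to query
    bool unspecified() const { return mState == UNSPECIFIED; }

    void* buffer() const {          // meaningful only in the POINTER state
        assert(mState == POINTER);  // misuse is caught here, not downstream
        return mBuffer;
    }

   private:
    Arg(State state, void* buffer) : mState(state), mBuffer(buffer) {}
    State mState = UNSPECIFIED;
    void* mBuffer = nullptr;
};
```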
diff --git a/nn/runtime/test/TestValidation.cpp b/nn/runtime/test/TestValidation.cpp
index 9e1f910..1b3e6be 100644
--- a/nn/runtime/test/TestValidation.cpp
+++ b/nn/runtime/test/TestValidation.cpp
@@ -1201,6 +1201,12 @@
     EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, &kInvalidTensorType2, buffer,
                                                 sizeof(float)),
               ANEURALNETWORKS_BAD_DATA);
+
+    // Cannot do this twice.
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, nullptr, buffer, 8),
+              ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, nullptr, buffer, 8),
+              ANEURALNETWORKS_BAD_STATE);
 }
 
 TEST_F(ValidationTestExecution, SetOutput) {
@@ -1229,6 +1235,12 @@
     EXPECT_EQ(ANeuralNetworksExecution_setOutput(mExecution, 0, &kInvalidTensorType2, buffer,
                                                  sizeof(float)),
               ANEURALNETWORKS_BAD_DATA);
+
+    // Cannot do this twice.
+    EXPECT_EQ(ANeuralNetworksExecution_setOutput(mExecution, 0, nullptr, buffer, 8),
+              ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setOutput(mExecution, 0, nullptr, buffer, 8),
+              ANEURALNETWORKS_BAD_STATE);
 }
 
 TEST_F(ValidationTestExecution, SetInputFromMemory) {
@@ -1281,6 +1293,15 @@
                                                           memory, 0, sizeof(float)),
               ANEURALNETWORKS_BAD_DATA);
 
+    // Cannot do this twice.
+    EXPECT_EQ(ANeuralNetworksExecution_setInputFromMemory(mExecution, 0, nullptr, memory, 0, 8),
+              ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setInputFromMemory(mExecution, 0, nullptr, memory, 0, 8),
+              ANEURALNETWORKS_BAD_STATE);
+    char buffer[memorySize];
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, nullptr, buffer, 8),
+              ANEURALNETWORKS_BAD_STATE);
+
     // close memory
     close(memoryFd);
 }
@@ -1381,6 +1402,15 @@
                                                            memory, 0, sizeof(float)),
               ANEURALNETWORKS_BAD_DATA);
 
+    // Cannot do this twice.
+    EXPECT_EQ(ANeuralNetworksExecution_setOutputFromMemory(execution, 0, nullptr, memory, 0, 8),
+              ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setOutputFromMemory(execution, 0, nullptr, memory, 0, 8),
+              ANEURALNETWORKS_BAD_STATE);
+    char buffer[memorySize];
+    EXPECT_EQ(ANeuralNetworksExecution_setOutput(execution, 0, nullptr, buffer, 8),
+              ANEURALNETWORKS_BAD_STATE);
+
     // close memory
     close(memoryFd);
 }