Don't memory plan for inputs (#2155)
Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/2155
For a KV cache used as IO, this results in:
1. allocating the KV cache in the memory plan even though it is also allocated by the llama runner
2. performing an actual copy of the KV cache on every execution
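The change in the diff below avoids this by disabling input allocation in the memory planning pass. A minimal sketch of the relevant configuration, assuming the executorch APIs as of this revision:

```python
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass

# With alloc_graph_input=False, graph inputs (e.g. the KV cache) are not
# given space in the planned memory arena; the runner-owned buffers are
# passed in directly, so no duplicate allocation or copy is needed.
config = ExecutorchBackendConfig(
    memory_planning_pass=MemoryPlanningPass("greedy", alloc_graph_input=False),
)
```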
We should also make planning for inputs off by default; I can't imagine a case
where planning them does not result in extra copies. Planning for outputs is fine but
still dangerous: people may assume that holding a reference to an output tensor is
safe without realizing the underlying memory is shared and may be overwritten.
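The hazard with planned output memory can be illustrated in plain Python (this is an illustrative sketch of the aliasing problem, not the ExecuTorch API): a caller holds a view into the planned arena, and a subsequent execution silently overwrites what the caller thinks it saved.

```python
import array

# A stand-in for the memory-planned arena that backs output tensors.
arena = array.array("f", [0.0] * 4)
out = memoryview(arena)  # the "output tensor" is a view into the arena

arena[0] = 1.0        # first execution writes its result
first = out[0]        # caller reads 1.0 and keeps the view around
arena[0] = 2.0        # next execution reuses the same planned memory
# The caller's "saved" output now reads 2.0, not the 1.0 it observed.
```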
ghstack-source-id: 216889056
exported-using-ghexport
Validated that OSS CI is clean. Had to bypass because CI thinks it needs the internal linter to pass.
bypass-github-export-checks
Reviewed By: mergennachin
Differential Revision: D54161288
fbshipit-source-id: b5e7aa42d4a72e455550af5d7467f46f2a1017f8
diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py
index fd055c0..1b23b20 100644
--- a/examples/models/llama2/builder.py
+++ b/examples/models/llama2/builder.py
@@ -21,6 +21,8 @@
from executorch.exir import EdgeProgramManager
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
+
+from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from torch._export import capture_pre_autograd_graph
@@ -310,6 +312,9 @@
passes=[
QuantFusionPass(),
],
+ memory_planning_pass=MemoryPlanningPass(
+ "greedy", alloc_graph_input=False
+ ),
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)
)
diff --git a/runtime/executor/tensor_parser_aten.cpp b/runtime/executor/tensor_parser_aten.cpp
index e407d07..c310bb8 100644
--- a/runtime/executor/tensor_parser_aten.cpp
+++ b/runtime/executor/tensor_parser_aten.cpp
@@ -109,11 +109,6 @@
ET_LOG(Error, "getTensorDataPtr() failed: 0x%" PRIx32, data_ptr.error());
return data_ptr.error();
}
- ET_CHECK_OR_RETURN_ERROR(
- data_ptr.get() != nullptr,
- Internal,
- "Expected non-null data for tensor with shape dynamism %d",
- int(s_tensor->shape_dynamism()));
tensor.unsafeGetTensorImpl()->unsafe_storage().set_data_ptr(
at::DataPtr(data_ptr.get(), DeviceType::CPU));
}