Refactoring secure_storage_manager

Refactoring secure file handling on secure_storage_manager

Bug: 253906012
Bug: 253926844
Bug: 258602662
Test: Build.py, run keymint TA unittests
Change-Id: I7b2e36df65176a8d7cce7925c739afdde4327619
diff --git a/secure_storage_manager.rs b/secure_storage_manager.rs
index de3b415..a1d421e 100644
--- a/secure_storage_manager.rs
+++ b/secure_storage_manager.rs
@@ -51,30 +51,70 @@
     format!("{}.{}", KM_ATTESTATION_KEY_CERT_PREFIX, suffix)
 }
 
-// TODO: refactor functions to have (Session, SecureFile) inside a single structure and implement
-//       drop on it to call close automatically.
-/// Opens a secure storage session and creates the requested file
-fn create_file(file_name: &str) -> Result<(Session, SecureFile), Error> {
-    let mut session = Session::new(Port::TamperProof, true).map_err(|e| {
-        km_err!(SecureHwCommunicationFailed, "couldn't create storage session: {:?}", e)
-    })?;
-    let file = session.open_file(file_name, OpenMode::Create).map_err(|e| {
-        km_err!(SecureHwCommunicationFailed, "couldn't create file {}: {:?}", file_name, e)
-    })?;
-    Ok((session, file))
+// `session` and `secure_file` are of type `Option` because close() takes self by value, so this was
+// needed for the `Drop` implementation. The intent though is that they should always be populated
+// on a OpenSecureFile object; which is `OpenSecureFile::new` behavior.
+struct OpenSecureFile {
+    session: Option<Session>,
+    secure_file: Option<SecureFile>,
+}
+
+impl OpenSecureFile {
+    /// Opens a secure storage session and creates the requested file
+    fn new(file_name: &str) -> Result<Self, Error> {
+        let mut session = Session::new(Port::TamperProof, true).map_err(|e| {
+            km_err!(SecureHwCommunicationFailed, "couldn't create storage session: {:?}", e)
+        })?;
+        let secure_file = session.open_file(file_name, OpenMode::Create).map_err(|e| {
+            km_err!(SecureHwCommunicationFailed, "couldn't create file {}: {:?}", file_name, e)
+        })?;
+        Ok(OpenSecureFile { session: Some(session), secure_file: Some(secure_file) })
+    }
+
+    /// Writes provided data in the previously opened file
+    fn write_all(&mut self, data: &[u8]) -> Result<(), Error> {
+        // Even though we are handling the case when secure_file and session are None, this is not
+        // expected; if an OpenSecureFile object exists its `secure_file` and `session` elements
+        // shall be populated.
+        let session =
+            self.session.as_mut().ok_or(km_err!(UnknownError, "session shouldn't ever be None"))?;
+        let file = self
+            .secure_file
+            .as_mut()
+            .ok_or(km_err!(UnknownError, "secure_file shouldn't ever be None"))?;
+        session.write_all(file, data).map_err(|e| {
+            km_err!(SecureHwCommunicationFailed, "failed to write data; received error: {:?}", e)
+        })
+    }
+
+    /// Close the session and file handlers by taking ownership and letting the value be dropped
+    #[cfg(test)]
+    fn close(self) {}
+}
+
+impl Drop for OpenSecureFile {
+    fn drop(&mut self) {
+        // Even though we are handling the case when secure_file and session are None, this is not
+        // expected; if an OpenSecureFile object exists its `secure_file` and `session` elements
+        // shall be populated.
+        if let Some(file) = self.secure_file.take() {
+            file.close();
+        }
+        if let Some(session) = self.session.take() {
+            session.close();
+        }
+    }
 }
 
 /// Creates an empty attestation IDs file
-fn create_attestation_id_file() -> Result<(Session, SecureFile), Error> {
-    create_file(KM_ATTESTATION_ID_FILENAME)
+fn create_attestation_id_file() -> Result<OpenSecureFile, Error> {
+    OpenSecureFile::new(KM_ATTESTATION_ID_FILENAME)
 }
 
 /// Creates and empty attestation key/certificates file for the given algorithm
-fn create_attestation_key_file(
-    algorithm: SigningAlgorithm,
-) -> Result<(Session, SecureFile), Error> {
+fn create_attestation_key_file(algorithm: SigningAlgorithm) -> Result<OpenSecureFile, Error> {
     let file_name = get_key_slot_file_name(algorithm);
-    create_file(&file_name)
+    OpenSecureFile::new(&file_name)
 }
 
 /// Creates a new attestation key/certificates file and saves the provided data there
@@ -82,12 +122,10 @@
     algorithm: SigningAlgorithm,
     key_data: &[u8],
 ) -> Result<(), Error> {
-    let (mut session, mut file) = create_attestation_key_file(algorithm)?;
-    session.write_all(&mut file, &key_data).map_err(|e| {
+    let mut file = create_attestation_key_file(algorithm)?;
+    file.write_all(&key_data).map_err(|e| {
         km_err!(SecureHwCommunicationFailed, "failed to provision attestation key file: {:?}", e)
     })?;
-    file.close();
-    session.close();
     Ok(())
 }
 
@@ -102,7 +140,7 @@
     manufacturer: &[u8],
     model: &[u8],
 ) -> Result<(), Error> {
-    let (mut session, mut file) = create_attestation_id_file()?;
+    let mut file = create_attestation_id_file()?;
 
     let mut attestation_ids = keymaster_attributes::AttestationIds::new();
 
@@ -135,12 +173,10 @@
         km_err!(SecureHwCommunicationFailed, "couldn't serialize attestationIds: {:?}", e)
     })?;
 
-    session.write_all(&mut file, &serialized_buffer).map_err(|e| {
+    file.write_all(&serialized_buffer).map_err(|e| {
         km_err!(SecureHwCommunicationFailed, "failed to provision attestation IDs file: {:?}", e)
     })?;
 
-    file.close();
-    session.close();
     Ok(())
 }
 
@@ -339,7 +375,7 @@
     }
 
     fn read_certificates_test(algorithm: SigningAlgorithm) {
-        let (mut session, mut file) =
+        let mut file =
             create_attestation_key_file(algorithm).expect("Couldn't create attestation key file");
         let mut key_cert = keymaster_attributes::AttestationKey::new();
         let certs_data = [[b'a'; 2048], [b'\0'; 2048], [b'c'; 2048]];
@@ -356,10 +392,8 @@
 
         let serialized_buffer = key_cert.write_to_bytes().expect("Couldn't serialize certs");
 
-        session.write_all(&mut file, &serialized_buffer).unwrap();
-
+        file.write_all(&serialized_buffer).unwrap();
         file.close();
-        session.close();
 
         let key_type = SigningKeyType { which: SigningKey::Batch, algo_hint: algorithm };
         let certs = get_cert_chain(key_type).expect("Couldn't get certificates from storage");
@@ -377,12 +411,12 @@
             raw_protobuf.extend_from_slice(&field_header);
             raw_protobuf.extend_from_slice(cert_data);
         }
-        let (mut session, mut file) =
-            create_attestation_key_file(algorithm).expect("Couldn't create attestation key file");
-        session.write_all(&mut file, &raw_protobuf).unwrap();
 
+        let mut file =
+            create_attestation_key_file(algorithm).expect("Couldn't create attestation key file");
+        file.write_all(&raw_protobuf).unwrap();
         file.close();
-        session.close();
+
         let certs_comp = get_cert_chain(key_type).expect("Couldn't get certificates from storage");
 
         expect_eq!(certs, certs_comp, "Retrieved certificates didn't match");
@@ -397,7 +431,7 @@
     }
 
     fn read_key_test(algorithm: SigningAlgorithm) {
-        let (mut session, mut file) =
+        let mut file =
             create_attestation_key_file(algorithm).expect("Couldn't create attestation key file");
 
         let mut key_cert = keymaster_attributes::AttestationKey::new();
@@ -410,9 +444,9 @@
         key_cert.set_key(test_key.to_vec());
 
         let serialized_buffer = key_cert.write_to_bytes().expect("Couldn't serialize key");
-        session.write_all(&mut file, &serialized_buffer).unwrap();
+
+        file.write_all(&serialized_buffer).unwrap();
         file.close();
-        session.close();
 
         let key_type = SigningKeyType { which: SigningKey::Batch, algo_hint: algorithm };
         let att_key = read_attestation_key(key_type).expect("Couldn't read key from storage");
@@ -428,11 +462,10 @@
         raw_protobuf.extend_from_slice(&key_header);
         raw_protobuf.extend_from_slice(&test_key);
 
-        let (mut session, mut file) =
+        let mut file =
             create_attestation_key_file(algorithm).expect("Couldn't create attestation key file");
-        session.write_all(&mut file, &raw_protobuf).unwrap();
+        file.write_all(&raw_protobuf).unwrap();
         file.close();
-        session.close();
 
         let att_key_comp = read_attestation_key(key_type).expect("Couldn't read key from storage");
 
@@ -449,8 +482,7 @@
 
     #[test]
     fn single_attestation_id_field() {
-        let (mut session, mut file) =
-            create_attestation_id_file().expect("Couldn't create attestation id file");
+        let mut file = create_attestation_id_file().expect("Couldn't create attestation id file");
 
         let mut attestation_ids = keymaster_attributes::AttestationIds::new();
         let brand = b"new brand";
@@ -460,10 +492,8 @@
         let serialized_buffer =
             attestation_ids.write_to_bytes().expect("Couldn't serialize attestationIds");
 
-        session.write_all(&mut file, &serialized_buffer).unwrap();
-
+        file.write_all(&serialized_buffer).unwrap();
         file.close();
-        session.close();
 
         let attestation_ids_info =
             read_attestation_ids().expect("Couldn't read attestation IDs from storage");
@@ -486,11 +516,11 @@
 
         // Now using a raw protobuf
         let raw_protobuf = [10, 9, 110, 101, 119, 32, 98, 114, 97, 110, 100];
-        let (mut session, mut file) =
-            create_attestation_id_file().expect("Couldn't create attestation id file");
-        session.write_all(&mut file, &raw_protobuf).unwrap();
+
+        let mut file = create_attestation_id_file().expect("Couldn't create attestation id file");
+        file.write_all(&raw_protobuf).unwrap();
         file.close();
-        session.close();
+
         let attestation_ids_comp = read_attestation_ids()
             .expect("Couldn't read comparison set of attestation IDs from storage");
 
@@ -506,9 +536,7 @@
 
     #[test]
     fn all_attestation_id_fields() {
-        let (mut session, mut file) =
-            create_attestation_id_file().expect("Couldn't create attestation id file");
-
+        let mut file = create_attestation_id_file().expect("Couldn't create attestation id file");
         let mut attestation_ids = keymaster_attributes::AttestationIds::new();
         let brand = b"unknown brand";
         let device = b"my brand new device";
@@ -531,10 +559,8 @@
         let serialized_buffer =
             attestation_ids.write_to_bytes().expect("Couldn't serialize attestationIds");
 
-        session.write_all(&mut file, &serialized_buffer).unwrap();
-
+        file.write_all(&serialized_buffer).unwrap();
         file.close();
-        session.close();
 
         let attestation_ids_info =
             read_attestation_ids().expect("Couldn't read attestation IDs from storage");
@@ -570,11 +596,9 @@
             32, 35, 36, 37, 37, 94, 66, 11, 119, 111, 114, 107, 105, 110, 103, 32, 111, 110, 101,
         ];
 
-        let (mut session, mut file) =
-            create_attestation_id_file().expect("Couldn't create attestation id file");
-        session.write_all(&mut file, &raw_protobuf).unwrap();
+        let mut file = create_attestation_id_file().expect("Couldn't create attestation id file");
+        file.write_all(&raw_protobuf).unwrap();
         file.close();
-        session.close();
 
         let attestation_ids_comp = read_attestation_ids()
             .expect("Couldn't read comparison set of attestation IDs from storage");
@@ -591,12 +615,11 @@
 
     #[test]
     fn delete_attestation_ids_file() {
-        let (mut session, mut file) =
-            create_attestation_id_file().expect("Couldn't create attestation id file");
+        let mut file = create_attestation_id_file().expect("Couldn't create attestation id file");
         let raw_protobuf = [10, 9, 110, 101, 119, 32, 98, 114, 97, 110, 100];
-        session.write_all(&mut file, &raw_protobuf).unwrap();
+        file.write_all(&raw_protobuf).unwrap();
         file.close();
-        session.close();
+
         expect!(check_attestation_id_file_exists(), "Couldn't create attestation IDs file");
         expect!(delete_attestation_ids().is_ok(), "Couldn't delete attestation IDs file");
         expect_eq!(