init: Detach daemon only after sepolicy is loaded

The new sequence of operation would be:

1: Load sepolicy - Daemon will continue to be alive and serve any I/O request

2: After sepolicy loading is complete - Switch the device-mapper tables.

3: Kill the block device daemon launched in the first-stage init.

4: Re-launch the daemon with the correct selinux labels set.

5: Enforce the sepolicy

Bug: 240321741
Test: Full OTA on pixel
Signed-off-by: Akilesh Kailash <akailash@google.com>
Change-Id: Idd392f0f0aae7d93e546c0ec0762e6c07b6263e4
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h b/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
index 34c7baf..cdff06e 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
@@ -332,10 +332,13 @@
     // Helper function for second stage init to restorecon on the rollback indicator.
     static std::string GetGlobalRollbackIndicatorPath();
 
-    // Detach dm-user devices from the current snapuserd, and populate
-    // |snapuserd_argv| with the necessary arguments to restart snapuserd
-    // and reattach them.
-    bool DetachSnapuserdForSelinux(std::vector<std::string>* snapuserd_argv);
+    // Populate |snapuserd_argv| with the necessary arguments to restart snapuserd
+    // after loading selinux policy.
+    bool PrepareSnapuserdArgsForSelinux(std::vector<std::string>* snapuserd_argv);
+
+    // Detach dm-user devices from the first stage snapuserd. Load
+    // new dm-user tables after loading selinux policy.
+    bool DetachFirstStageSnapuserdForSelinux();
 
     // Perform the transition from the selinux stage of snapuserd into the
     // second-stage of snapuserd. This process involves re-creating the dm-user
diff --git a/fs_mgr/libsnapshot/snapshot.cpp b/fs_mgr/libsnapshot/snapshot.cpp
index 59abd6f..6fed09c 100644
--- a/fs_mgr/libsnapshot/snapshot.cpp
+++ b/fs_mgr/libsnapshot/snapshot.cpp
@@ -1746,13 +1746,6 @@
 
         auto misc_name = user_cow_name;
 
-        DmTable table;
-        table.Emplace<DmTargetUser>(0, target.spec.length, misc_name);
-        if (!dm_.LoadTableAndActivate(user_cow_name, table)) {
-            LOG(ERROR) << "Unable to swap tables for " << misc_name;
-            continue;
-        }
-
         std::string source_device_name;
         if (snapshot_status.old_partition_size() > 0) {
             source_device_name = GetSourceDeviceName(snapshot);
@@ -1780,13 +1773,6 @@
             continue;
         }
 
-        // Wait for ueventd to acknowledge and create the control device node.
-        std::string control_device = "/dev/dm-user/" + misc_name;
-        if (!WaitForDevice(control_device, 10s)) {
-            LOG(ERROR) << "dm-user control device no found:  " << misc_name;
-            continue;
-        }
-
         if (transition == InitTransition::SELINUX_DETACH) {
             if (!UpdateUsesUserSnapshots(lock.get())) {
                 auto message = misc_name + "," + cow_image_device + "," + source_device;
@@ -1804,6 +1790,20 @@
             continue;
         }
 
+        DmTable table;
+        table.Emplace<DmTargetUser>(0, target.spec.length, misc_name);
+        if (!dm_.LoadTableAndActivate(user_cow_name, table)) {
+            LOG(ERROR) << "Unable to swap tables for " << misc_name;
+            continue;
+        }
+
+        // Wait for ueventd to acknowledge and create the control device node.
+        std::string control_device = "/dev/dm-user/" + misc_name;
+        if (!WaitForDevice(control_device, 10s)) {
+            LOG(ERROR) << "dm-user control device no found:  " << misc_name;
+            continue;
+        }
+
         uint64_t base_sectors;
         if (!UpdateUsesUserSnapshots(lock.get())) {
             base_sectors =
@@ -4136,10 +4136,71 @@
     return status.state() != UpdateState::None && status.using_snapuserd();
 }
 
-bool SnapshotManager::DetachSnapuserdForSelinux(std::vector<std::string>* snapuserd_argv) {
+bool SnapshotManager::PrepareSnapuserdArgsForSelinux(std::vector<std::string>* snapuserd_argv) {
     return PerformInitTransition(InitTransition::SELINUX_DETACH, snapuserd_argv);
 }
 
+bool SnapshotManager::DetachFirstStageSnapuserdForSelinux() {
+    LOG(INFO) << "Detaching first stage snapuserd";
+
+    auto lock = LockExclusive();
+    if (!lock) return false;
+
+    std::vector<std::string> snapshots;
+    if (!ListSnapshots(lock.get(), &snapshots)) {
+        LOG(ERROR) << "Failed to list snapshots.";
+        return false;
+    }
+
+    size_t num_cows = 0;
+    size_t ok_cows = 0;
+    for (const auto& snapshot : snapshots) {
+        std::string user_cow_name = GetDmUserCowName(snapshot, GetSnapshotDriver(lock.get()));
+
+        if (dm_.GetState(user_cow_name) == DmDeviceState::INVALID) {
+            continue;
+        }
+
+        DeviceMapper::TargetInfo target;
+        if (!GetSingleTarget(user_cow_name, TableQuery::Table, &target)) {
+            continue;
+        }
+
+        auto target_type = DeviceMapper::GetTargetType(target.spec);
+        if (target_type != "user") {
+            LOG(ERROR) << "Unexpected target type for " << user_cow_name << ": " << target_type;
+            continue;
+        }
+
+        num_cows++;
+        auto misc_name = user_cow_name;
+
+        DmTable table;
+        table.Emplace<DmTargetUser>(0, target.spec.length, misc_name);
+        if (!dm_.LoadTableAndActivate(user_cow_name, table)) {
+            LOG(ERROR) << "Unable to swap tables for " << misc_name;
+            continue;
+        }
+
+        // Wait for ueventd to acknowledge and create the control device node.
+        std::string control_device = "/dev/dm-user/" + misc_name;
+        if (!WaitForDevice(control_device, 10s)) {
+            LOG(ERROR) << "dm-user control device no found:  " << misc_name;
+            continue;
+        }
+
+        ok_cows++;
+        LOG(INFO) << "control device is ready: " << control_device;
+    }
+
+    if (ok_cows != num_cows) {
+        LOG(ERROR) << "Could not transition all snapuserd consumers.";
+        return false;
+    }
+
+    return true;
+}
+
 bool SnapshotManager::PerformSecondStageInitTransition() {
     return PerformInitTransition(InitTransition::SECOND_STAGE);
 }
diff --git a/init/snapuserd_transition.cpp b/init/snapuserd_transition.cpp
index 5c821b0..6972f30 100644
--- a/init/snapuserd_transition.cpp
+++ b/init/snapuserd_transition.cpp
@@ -226,12 +226,9 @@
 
     argv_.emplace_back("snapuserd");
     argv_.emplace_back("-no_socket");
-    if (!sm_->DetachSnapuserdForSelinux(&argv_)) {
+    if (!sm_->PrepareSnapuserdArgsForSelinux(&argv_)) {
         LOG(FATAL) << "Could not perform selinux transition";
     }
-
-    // Make sure the process is gone so we don't have any selinux audits.
-    KillFirstStageSnapuserd(old_pid_);
 }
 
 void SnapuserdSelinuxHelper::FinishTransition() {
@@ -301,6 +298,12 @@
 }
 
 void SnapuserdSelinuxHelper::RelaunchFirstStageSnapuserd() {
+    if (!sm_->DetachFirstStageSnapuserdForSelinux()) {
+        LOG(FATAL) << "Could not perform selinux transition";
+    }
+
+    KillFirstStageSnapuserd(old_pid_);
+
     auto fd = GetRamdiskSnapuserdFd();
     if (!fd) {
         LOG(FATAL) << "Environment variable " << kSnapuserdFirstStageFdVar << " was not set!";