[ATFT] Support reprovisioning.

Support key reprovisioning. Also fix a bug that would cause serial
number sometime to disappear.

Bug: b/73350868
Test: Unit test and local test on Linux.
Change-Id: Ia9fb0d4b9415f84650361edc088f94003059baf9
diff --git a/at-factory-tool/atft.py b/at-factory-tool/atft.py
index a7f2b70..56276e6 100644
--- a/at-factory-tool/atft.py
+++ b/at-factory-tool/atft.py
@@ -858,9 +858,8 @@
     if not self.log.log_dir_file:
       self._SendAlertEvent(self.ALERT_FAIL_TO_CREATE_LOG)
 
-    self.StartRefreshingDevices()
-
     self.ShowStartScreen()
+    self.StartRefreshingDevices()
 
   def CreateAtftManager(self):
     """Create an AtftManager object.
@@ -1230,6 +1229,12 @@
     self.ALERT_WRONG_ORIG_PASSWORD = [
         'Wrong Original Password!!!',
         '原密码错误!!!'][index]
+    self.ALERT_REPROVISION = [
+        lambda device:
+            'The device ' + str(device) + ' already has attestation key, '
+            'do you want to reprovision a new key?',
+        lambda device:
+            '设备' + str(device) + '中已经有一个密钥,是否覆盖?'][index]
 
     self.STATUS_MAPPED = ['Mapped', '已关联位置'][index]
     self.STATUS_NOT_MAPPED = ['Not mapped', '未关联位置'][index]
@@ -1718,7 +1723,7 @@
       target_devs_output_serial = wx.StaticText(
           target_devs_output_title, wx.ID_ANY, '')
       target_devs_output_serial.SetForegroundColour(self.COLOR_BLACK)
-      target_devs_output_serial.SetMinSize((180 * scale, 0))
+      target_devs_output_serial.SetMinSize((180 * scale, 15))
       target_devs_output_title_sizer.Add(
           target_devs_output_serial, 0, wx.TOP, 18)
       serial_font = wx.Font(
@@ -2002,6 +2007,23 @@
     if self._GetCachedATFAKeysLeft() == 0:
       self._SendAlertEvent(self.ALERT_PROV_NO_KEYS)
       return
+    for serial in selected_serials:
+      target_dev = self.atft_manager.GetTargetDevice(serial)
+      if not target_dev:
+        continue
+      status = target_dev.provision_status
+      if (TEST_MODE):
+        target_dev.provision_status = ProvisionStatus.WAITING
+      elif (
+          target_dev.provision_state.bootloader_locked and
+          target_dev.provision_state.avb_perm_attr_set and
+          target_dev.provision_state.avb_locked):
+        if (target_dev.provision_state.provisioned and
+            not self._ShowWarning(self.ALERT_REPROVISION(target_dev))):
+          continue
+        target_dev.provision_status = ProvisionStatus.WAITING
+      else:
+        self._SendAlertEvent(self.ALERT_PROV_PROVED)
     self._CreateThread(self._ManualProvision, selected_serials)
 
   def OnCheckATFAStatus(self, event):
@@ -2698,7 +2720,6 @@
             serial_text = (
                 self.FIELD_SERIAL_NUMBER + ': ' + str(serial_number))
             status = target_dev.provision_status
-
       self._ShowTargetDevice(i, serial_number, serial_text, status)
 
   def _ShowTargetDevice(self, i, serial_number, serial_text, status):
@@ -3189,25 +3210,12 @@
     # Reset alert_shown
     self.first_key_alert_shown = False
     self.second_key_alert_shown = False
-    pending_targets = []
     for serial in selected_serials:
       target_dev = self.atft_manager.GetTargetDevice(serial)
       if not target_dev:
         continue
-      pending_targets.append(target_dev)
-      status = target_dev.provision_status
-      if (TEST_MODE or (
-          target_dev.provision_state.bootloader_locked and
-          target_dev.provision_state.avb_perm_attr_set and
-          target_dev.provision_state.avb_locked and
-          not target_dev.provision_state.provisioned
-        )):
-        target_dev.provision_status = ProvisionStatus.WAITING
-      else:
-        self._SendAlertEvent(self.ALERT_PROV_PROVED)
-    for target in pending_targets:
-      if target.provision_status == ProvisionStatus.WAITING:
-        self._ProvisionTarget(target)
+      if target_dev.provision_status == ProvisionStatus.WAITING:
+        self._ProvisionTarget(target_dev)
 
   def _ProvisionTarget(self, target):
     """Provision the attestation key into the specific target.
diff --git a/at-factory-tool/atft_unittest.py b/at-factory-tool/atft_unittest.py
index d56ff76..09bf015 100644
--- a/at-factory-tool/atft_unittest.py
+++ b/at-factory-tool/atft_unittest.py
@@ -1189,10 +1189,60 @@
     self.device_map[self.TEST_SERIAL2] = test_dev2
     self.device_map[self.TEST_SERIAL3] = test_dev3
     serials = [self.TEST_SERIAL1, self.TEST_SERIAL2, self.TEST_SERIAL3]
-    mock_atft._ManualProvision(serials)
+    mock_atft._GetSelectedSerials = MagicMock()
+    mock_atft._GetSelectedSerials.return_value = serials
+    mock_atft.atft_manager.atfa_dev = MagicMock()
+    mock_atft._ShowWarning = MagicMock()
+    mock_atft._ShowWarning.return_value = False
+    mock_atft.OnManualProvision(None)
     calls = [call(test_dev1), call(test_dev2)]
     mock_atft.atft_manager.Provision.assert_has_calls(calls)
 
+  def testManualProvisionReprovision(self):
+    mock_atft = MockAtft()
+    mock_atft.PauseRefresh = MagicMock()
+    mock_atft.ResumeRefresh = MagicMock()
+    mock_atft._SendStartMessageEvent = MagicMock()
+    mock_atft._SendSucceedMessageEvent = MagicMock()
+    mock_atft._HandleException = MagicMock()
+    mock_atft.atft_manager.Provision = MagicMock()
+    mock_atft._SendAlertEvent = MagicMock()
+    mock_atft._CheckLowKeyAlert = MagicMock()
+    mock_atft.atft_manager.GetTargetDevice.side_effect = (
+        self.MockGetTargetDevice)
+    test_dev1 = TestDeviceInfo(self.TEST_SERIAL1, self.TEST_LOCATION1,
+                               ProvisionStatus.PROVISION_FAILED)
+    test_dev1.provision_state.bootloader_locked = True
+    test_dev1.provision_state.avb_perm_attr_set = True
+    test_dev1.provision_state.avb_locked = True
+    test_dev2 = TestDeviceInfo(self.TEST_SERIAL2, self.TEST_LOCATION2,
+                               ProvisionStatus.PROVISION_SUCCESS)
+    test_dev2.provision_state.bootloader_locked = True
+    test_dev2.provision_state.avb_perm_attr_set = True
+    test_dev2.provision_state.avb_locked = True
+    test_dev2.provision_state.provisioned = True
+    self.device_map[self.TEST_SERIAL1] = test_dev1
+    self.device_map[self.TEST_SERIAL2] = test_dev2
+    serials = [self.TEST_SERIAL1, self.TEST_SERIAL2]
+    mock_atft._GetSelectedSerials = MagicMock()
+    mock_atft._GetSelectedSerials.return_value = serials
+    mock_atft.atft_manager.atfa_dev = MagicMock()
+    mock_atft._ShowWarning = MagicMock()
+    # User click No for reprovision.
+    mock_atft._ShowWarning.return_value = False
+    mock_atft.OnManualProvision(None)
+    mock_atft.atft_manager.Provision.assert_called_once_with(test_dev1)
+    mock_atft._ShowWarning.assert_called_once()
+
+    # User click yes.
+    mock_atft._ShowWarning.reset_mock()
+    mock_atft.atft_manager.Provision.reset_mock()
+    mock_atft._ShowWarning.return_value = True
+    mock_atft.OnManualProvision(None)
+    calls = [call(test_dev1), call(test_dev2)]
+    mock_atft.atft_manager.Provision.assert_has_calls(calls)
+    mock_atft._ShowWarning.assert_called_once()
+
   def testManualProvisionExceptions(self):
     mock_atft = MockAtft()
     mock_atft.PauseRefresh = MagicMock()
@@ -1222,9 +1272,13 @@
     self.device_map[self.TEST_SERIAL2] = test_dev2
     self.device_map[self.TEST_SERIAL3] = test_dev3
     serials = [self.TEST_SERIAL1, self.TEST_SERIAL2, self.TEST_SERIAL3]
+    mock_atft._GetSelectedSerials = MagicMock()
+    mock_atft._GetSelectedSerials.return_value = serials
+    mock_atft.atft_manager.atfa_dev = MagicMock()
+    mock_atft._ShowWarning = MagicMock()
     mock_atft.atft_manager.Provision.side_effect = (
         fastboot_exceptions.FastbootFailure(''))
-    mock_atft._ManualProvision(serials)
+    mock_atft.OnManualProvision(None)
     self.assertEqual(2, mock_atft._HandleException.call_count)
     test_dev1 = TestDeviceInfo(self.TEST_SERIAL1, self.TEST_LOCATION1,
                                ProvisionStatus.PROVISION_FAILED)
@@ -1243,10 +1297,11 @@
     self.device_map[self.TEST_SERIAL2] = test_dev2
     self.device_map[self.TEST_SERIAL3] = test_dev3
     serials = [self.TEST_SERIAL1, self.TEST_SERIAL2, self.TEST_SERIAL3]
+    mock_atft._GetSelectedSerials.return_value = serials
     mock_atft._HandleException.reset_mock()
     mock_atft.atft_manager.Provision.side_effect = (
         fastboot_exceptions.DeviceNotFoundException())
-    mock_atft._ManualProvision(serials)
+    mock_atft.OnManualProvision(None)
     self.assertEqual(2, mock_atft._HandleException.call_count)
 
   # Test atft._ProcessKey