[ATFT] Add code to prevent interleaved operation. Add lock to make sure that only one operation is allowed on one device at one time. Bug: b/77496482 Test: Unit tests, local test. Change-Id: Ibbddcee9a5c7554212b33796e25e25382f29861a
diff --git a/at-factory-tool/atft.py b/at-factory-tool/atft.py index 7b4fc09..a7f2b70 100644 --- a/at-factory-tool/atft.py +++ b/at-factory-tool/atft.py
@@ -820,8 +820,10 @@ # Indicate whether refresh is paused. If we could acquire this lock, this # means that the refresh is paused. We would pause the refresh during each # fastboot command since on Windows, a fastboot device would disappear from - # fastboot devices while a fastboot command is issued. - self.refresh_pause_lock = threading.Semaphore(value=0) + # fastboot devices while a fastboot command is issued. We use semaphore to + # allow layered pause and resume, unless the last layer is resumed, the + # refresh is in pause state. + self.refresh_pause_lock = threading.Semaphore(0) # 'fastboot devices' can only run sequentially, so we use this lock to check # if there's already a 'fastboot devices' command running. If so, we ignore @@ -1830,6 +1832,7 @@ self.refresh_timer.start() if self.refresh_pause_lock.acquire(False): + # Semaphore > 0, refresh is still paused. self.refresh_pause_lock.release() self._SendDeviceListedEvent() else: @@ -2176,6 +2179,9 @@ except DeviceNotFoundException as e: self._SendAlertEvent(self.ALERT_NO_ATFA) return + except FastbootFailure as e: + self._HandleException('E', e) + return callback = self._GetRegFile data = self.SaveFileArg(message, filename, callback) event = Event(self.save_file_event, value=data) @@ -2194,6 +2200,9 @@ except DeviceNotFoundException as e: self._SendAlertEvent(self.ALERT_NO_ATFA) return + except FastbootFailure as e: + self._HandleException('E', e) + return callback = self._GetAuditFile data = self.SaveFileArg(message, filename, callback) event = Event(self.save_file_event, value=data) @@ -2542,7 +2551,29 @@ evt = Event(self.print_event, wx.ID_ANY, msg) wx.QueueEvent(self, evt) - def _SendOperationStartEvent(self, operation, target=None): + def _StartOperation(self, operation, target): + if not target: + self.PauseRefresh() + return True + if target.operation_lock.acquire(False): + target.operation = operation + self._SendOperationStartEvent(operation, target) + self.PauseRefresh() + return True + + self._SendAlertEvent( + 'Target: ' + str(target) + ' is currently in another operation: ' + + target.operation + '. Please try again later') + return False + + def _EndOperation(self, target): + self.ResumeRefresh() + if not target: + return + target.operation = None + target.operation_lock.release() + + def _SendOperationStartEvent(self, operation, target): """Send an event to print an operation start message. Args: @@ -2831,7 +2862,7 @@ Whether the check succeed or not. """ operation = 'Check ATFA status' - self._SendOperationStartEvent(operation) + self._SendOperationStartEvent(operation, self.atft_manager.atfa_dev) self.PauseRefresh() try: @@ -2905,8 +2936,8 @@ """ operation = 'Fuse bootloader verified boot key' serial = target.serial_number - self._SendOperationStartEvent(operation, target) - self.PauseRefresh() + if not self._StartOperation(operation, target): + return try: self.atft_manager.FuseVbootKey(target) @@ -2951,7 +2982,7 @@ self._HandleException('E', e, operation) return finally: - self.ResumeRefresh() + self._EndOperation(target) # Wait until callback finishes. After the callback, reboot_lock would be @@ -3019,8 +3050,8 @@ target: The target device DeviceInfo object. """ operation = 'Fuse permanent attributes' - self._SendOperationStartEvent(operation, target) - self.PauseRefresh() + if not self._StartOperation(operation, target): + return try: self.atft_manager.FusePermAttr(target) @@ -3031,7 +3062,7 @@ self._HandleException('E', e, operation) return finally: - self.ResumeRefresh() + self._EndOperation(target) self._SendOperationSucceedEvent(operation, target) @@ -3067,8 +3098,8 @@ target: The target device DeviceInfo object. """ operation = 'Lock android verified boot' - self._SendOperationStartEvent(operation, target) - self.PauseRefresh() + if not self._StartOperation(operation, target): + return try: self.atft_manager.LockAvb(target) @@ -3076,7 +3107,7 @@ self._HandleException('E', e, operation) return finally: - self.ResumeRefresh() + self._EndOperation(target) self._SendOperationSucceedEvent(operation, target) @@ -3111,8 +3142,8 @@ """Reboot ATFA device. """ operation = 'Reboot ATFA device' - self._SendOperationStartEvent(operation) - self.PauseRefresh() + if not self._StartOperation(operation, self.atft_manager.atfa_dev): + return try: self.atft_manager.RebootATFA() @@ -3124,7 +3155,7 @@ self._HandleException('E', e, operation, self.atft_manager.atfa_dev) return finally: - self.ResumeRefresh() + self._EndOperation(self.atft_manager.atfa_dev) self._SendOperationSucceedEvent(operation) @@ -3132,8 +3163,8 @@ """Shutdown ATFA device. """ operation = 'Shutdown ATFA device' - self._SendOperationStartEvent(operation) - self.PauseRefresh() + if not self._StartOperation(operation, self.atft_manager.atfa_dev): + return try: self.atft_manager.ShutdownATFA() @@ -3145,7 +3176,7 @@ self._HandleException('E', e, operation, self.atft_manager.atfa_dev) return finally: - self.ResumeRefresh() + self._EndOperation(self.atft_manager.atfa_dev) self._SendOperationSucceedEvent(operation) @@ -3185,8 +3216,8 @@ target: The target to be provisioned. """ operation = 'Attestation Key Provisioning' - self._SendOperationStartEvent(operation, target) - self.PauseRefresh() + if not self._StartOperation(operation, target): + return try: self.atft_manager.Provision(target) @@ -3200,7 +3231,7 @@ self._UpdateKeysLeftInATFA() return finally: - self.ResumeRefresh() + self._EndOperation(target) self._SendOperationSucceedEvent(operation, target) self._CheckLowKeyAlert() @@ -3254,8 +3285,8 @@ pathname: The path name to the key bundle file. """ operation = 'ATFA device store and process key bundle' - self._SendOperationStartEvent(operation) - self.PauseRefresh() + if not self._StartOperation(operation, self.atft_manager.atfa_dev): + return try: self.atft_manager.atfa_dev.Download(pathname) self.atft_manager.ProcessATFAKey() @@ -3274,7 +3305,7 @@ self.ALERT_PROCESS_KEY_FAILURE + e.msg.encode('utf-8')) return finally: - self.ResumeRefresh() + self._EndOperation(self.atft_manager.atfa_dev) def _UpdateATFACallback(self, pathname): self._CreateThread(self._UpdateATFA, pathname) @@ -3286,12 +3317,11 @@ pathname: The path name to the key bundle file. """ operation = 'Update ATFA device' - self._SendOperationStartEvent(operation) - self.PauseRefresh() + if not self._StartOperation(operation, self.atft_manager.atfa_dev): + return try: self.atft_manager.atfa_dev.Download(pathname) self.atft_manager.UpdateATFA() - self._SendOperationSucceedEvent(operation) except DeviceNotFoundException as e: e.SetMsg('No Available ATFA!') self._HandleException('W', e, operation) @@ -3302,14 +3332,16 @@ self.ALERT_UPDATE_FAILURE + e.msg.encode('utf-8')) return finally: - self.ResumeRefresh() + self._EndOperation(self.atft_manager.atfa_dev) + + self._SendOperationSucceedEvent(operation) def _PurgeKey(self): """Purge the key for the selected product in the ATFA device. """ operation = 'ATFA purge key' - self._SendOperationStartEvent(operation) - self.PauseRefresh() + if not self._StartOperation(operation, self.atft_manager.atfa_dev): + return try: self.atft_manager.PurgeATFAKey() self._SendOperationSucceedEvent(operation) @@ -3327,7 +3359,7 @@ self.ALERT_PURGE_KEY_FAILURE + e.msg.encode('utf-8')) return finally: - self.ResumeRefresh() + self._EndOperation(self.atft_manager.atfa_dev) def _GetRegFile(self, filepath): self._CreateThread(self._GetFileFromATFA, filepath, 'reg') @@ -3354,16 +3386,14 @@ # Should not reach here. return operation = 'ATFA device prepare and download ' + file_type + ' file' - self._SendOperationStartEvent(operation) - self.PauseRefresh() - filepath = filepath.encode('utf-8') + if not self._StartOperation(operation, self.atft_manager.atfa_dev): + return try: + filepath = filepath.encode('utf-8') write_file = open(filepath, 'w+') write_file.close() self.atft_manager.PrepareFile(file_type) self.atft_manager.atfa_dev.Upload(filepath) - self._SendOperationSucceedEvent(operation) - self._SendAlertEvent(alert_message + filepath) except DeviceNotFoundException as e: e.SetMsg('No Available ATFA!') self._HandleException('W', e, operation) @@ -3379,7 +3409,10 @@ alert_cannot_get_file_message + e.msg.encode('utf-8')) return finally: - self.ResumeRefresh() + self._EndOperation(self.atft_manager.atfa_dev) + + self._SendOperationSucceedEvent(operation) + self._SendAlertEvent(alert_message + filepath) def _GetSelectedSerials(self): """Get the list of selected serial numbers in the device list.
diff --git a/at-factory-tool/atft_unittest.py b/at-factory-tool/atft_unittest.py index 69c13c5..d56ff76 100644 --- a/at-factory-tool/atft_unittest.py +++ b/at-factory-tool/atft_unittest.py
@@ -68,6 +68,8 @@ self.provision_status = provision_status self.provision_state = ProvisionState() self.time_set = False + self.operation_lock = MagicMock() + self.operation = None def __eq__(self, other): return (self.serial_number == other.serial_number and
diff --git a/at-factory-tool/atftman.py b/at-factory-tool/atftman.py index 7839abf..92fc3e5 100644 --- a/at-factory-tool/atftman.py +++ b/at-factory-tool/atftman.py
@@ -153,6 +153,10 @@ # The number of attestation keys left for the selected product. This # attribute is only meaning for ATFA device. self.keys_left = None + # Only one operation is allowed on one device at one time. + self.operation_lock = threading.Lock() + # Current operation. + self.operation = None def Copy(self): return DeviceInfo(None, self.serial_number, self.location,