Merge "simpleperf: fix missing python test results."
diff --git a/simpleperf/scripts/test/do_test.py b/simpleperf/scripts/test/do_test.py
index 201d251..47d3cbb 100755
--- a/simpleperf/scripts/test/do_test.py
+++ b/simpleperf/scripts/test/do_test.py
@@ -207,6 +207,14 @@
     ok: bool
     duration: str
 
+    def __str__(self) -> str:
+        if self.ok:
+            s = 'OK'
+        else:
+            s = f'FAILED (at try_time {self.try_time})'
+        s += f' {self.duration}'
+        return s
+
 
 class TestProcess:
     """ Create a test process to run selected tests on a device. """
@@ -312,8 +320,6 @@
 
     def update(self, test_proc: TestProcess):
         if test_proc.name not in self.test_process_bars:
-            if not test_proc.alive:
-                return
             bar = tqdm(total=len(test_proc.tests),
                        desc=test_proc.name, ascii=' ##',
                        bar_format="{l_bar}{bar} | {n_fmt}/{total_fmt} [{elapsed}]")
@@ -325,8 +331,10 @@
         if add:
             bar.update(add)
             self.total_bar.update(add)
-        if not test_proc.alive:
-            bar.close()
+
+    def end_test_proc(self, test_proc: TestProcess):
+        if test_proc.name in self.test_process_bars:
+            self.test_process_bars[test_proc.name].close()
             del self.test_process_bars[test_proc.name]
 
     def end_tests(self):
@@ -336,40 +344,55 @@
 
 
 class TestSummary:
-    def __init__(self, test_count: int):
-        self.summary_fh = open('test_summary.txt', 'w')
-        self.failed_summary_fh = open('failed_test_summary.txt', 'w')
-        self.results: Dict[Tuple[str, str], TestResult] = {}
-        self.test_count = test_count
+    def __init__(
+            self, devices: List[Device],
+            device_tests: List[str],
+            repeat_count: int, host_tests: List[str]):
+        self.results: Dict[Tuple[str, str], Optional[TestResult]] = {}
+        for test in device_tests:
+            for device in devices:
+                for repeat_index in range(1, repeat_count + 1):
+                    self.results[(test, '%s_repeat_%d' % (device.name, repeat_index))] = None
+        for test in host_tests:
+            self.results[(test, 'host')] = None
+        self.write_summary()
+
+    @property
+    def test_count(self) -> int:
+        return len(self.results)
 
     @property
     def failed_test_count(self) -> int:
-        return self.test_count - sum(1 for result in self.results.values() if result.ok)
+        count = 0
+        for result in self.results.values():
+            if result is None or not result.ok:
+                count += 1
+        return count
 
     def update(self, test_proc: TestProcess):
+        if test_proc.device:
+            test_env = '%s_repeat_%d' % (test_proc.device.name, test_proc.repeat_index)
+        else:
+            test_env = 'host'
+        has_update = False
         for test, result in test_proc.test_results.items():
-            key = (test, '%s_try_%s' % (test_proc.name, result.try_time))
-            if key not in self.results:
+            key = (test, test_env)
+            if self.results[key] != result:
                 self.results[key] = result
-                self._write_result(key[0], key[1], result)
+                has_update = True
+        if has_update:
+            self.write_summary()
 
-    def _write_result(self, test_name: str, test_env: str, test_result: TestResult):
-        print(
-            '%s    %s    %s    %s' %
-            (test_name, test_env, 'OK' if test_result.ok else 'FAILED', test_result.duration),
-            file=self.summary_fh, flush=True)
-        if not test_result.ok:
-            print('%s    %s    FAILED    %s' % (test_name, test_env, test_result.duration),
-                  file=self.failed_summary_fh, flush=True)
-
-    def end_tests(self):
-        # Show sorted results after testing.
-        self.summary_fh.seek(0, 0)
-        self.failed_summary_fh.seek(0, 0)
-        for key in sorted(self.results.keys()):
-            self._write_result(key[0], key[1], self.results[key])
-        self.summary_fh.close()
-        self.failed_summary_fh.close()
+    def write_summary(self):
+        with open('test_summary.txt', 'w') as fh, \
+                open('failed_test_summary.txt', 'w') as failed_fh:
+            for key in sorted(self.results.keys()):
+                test_name, test_env = key
+                result = self.results[key]
+                message = f'{test_name}    {test_env}    {result}'
+                print(message, file=fh)
+                if not result or not result.ok:
+                    print(message, file=failed_fh)
 
 
 class TestManager:
@@ -418,7 +441,8 @@
         total_test_count = (len(device_tests) + len(device_serialized_tests)
                             ) * len(self.devices) * self.repeat_count + len(host_tests)
         self.progress_bar = ProgressBar(total_test_count)
-        self.test_summary = TestSummary(total_test_count)
+        self.test_summary = TestSummary(self.devices, device_tests + device_serialized_tests,
+                                        self.repeat_count, host_tests)
         if device_tests:
             self.run_device_tests(device_tests)
         if device_serialized_tests:
@@ -427,7 +451,6 @@
             self.run_host_tests(host_tests)
         self.progress_bar.end_tests()
         self.progress_bar = None
-        self.test_summary.end_tests()
 
     def run_device_tests(self, tests: List[str]):
         """ Tests can run in parallel on different devices. """
@@ -462,8 +485,13 @@
             # Process dead procs.
             for test_proc in dead_procs:
                 test_proc.join()
-                if not test_proc.finished and test_proc.restart():
-                    continue
+                if not test_proc.finished:
+                    if test_proc.restart():
+                        continue
+                    else:
+                        self.progress_bar.update(test_proc)
+                        self.test_summary.update(test_proc)
+                self.progress_bar.end_test_proc(test_proc)
                 test_procs.remove(test_proc)
                 if test_proc.repeat_index < repeat_count:
                     test_procs.append(