Merge changes I83162c52,I4b4a7775,I9b2cc8de,I0b63d2fd,I8fa2d271, ... into main

* changes:
  Fix import
  Always enable gradually rolled-out features for ADTE owners
  Add a member to owners
  Print full output when rolling output is enabled and the subprocess fails
  Use StringIO instead of a string list
  Erase the rolling output after the command finishes
  Only display rolling output on TTY terminals
  Support displaying TF subprocess progress in a rolling window
  Skip logging and hash calculation for 0 or 100 rollout percentage
diff --git a/Android.bp b/Android.bp
index 60c8327..67f7510 100644
--- a/Android.bp
+++ b/Android.bp
@@ -57,3 +57,11 @@
         canonical_path_from_root: false,
     },
 }
+
+filegroup {
+    name: "adte-owners-files",
+    srcs: [
+        "OWNERS_ADTE_TEAM",
+        "OWNERS",
+    ],
+}
diff --git a/OWNERS_ADTE_TEAM b/OWNERS_ADTE_TEAM
index 76e26bc..12ca660 100644
--- a/OWNERS_ADTE_TEAM
+++ b/OWNERS_ADTE_TEAM
@@ -2,6 +2,7 @@
 davidjames@google.com
 hwj@google.com
 hzalek@google.com
+ihcinihsdk@google.com
 kevindagostino@google.com
 liuyg@google.com
 lucafarsi@google.com
diff --git a/atest/Android.bp b/atest/Android.bp
index 7337b1d..c5fda1f 100644
--- a/atest/Android.bp
+++ b/atest/Android.bp
@@ -68,6 +68,7 @@
     defaults: ["atest_binary_defaults"],
     main: "atest_main.py",
     data: [
+        ":adte-owners-files",
         ":atest_flag_list_for_completion",
         ":atest_log_uploader",
     ],
diff --git a/atest/atest_utils.py b/atest/atest_utils.py
index a893d60..2bc306d 100644
--- a/atest/atest_utils.py
+++ b/atest/atest_utils.py
@@ -27,7 +27,9 @@
 import fnmatch
 import hashlib
 import html
-import importlib
+import importlib.resources
+import importlib.util
+import io
 import itertools
 import json
 import logging
@@ -40,6 +42,7 @@
 import shutil
 import subprocess
 import sys
+import threading
 from threading import Thread
 import traceback
 from typing import Any, Dict, IO, List, Set, Tuple
@@ -54,7 +57,7 @@
 from atest.metrics import metrics_utils
 from atest.tf_proto import test_record_pb2
 
-_BUILD_OUTPUT_ROLLING_LINES = 6
+DEFAULT_OUTPUT_ROLLING_LINES = 6
 _BASH_CLEAR_PREVIOUS_LINE_CODE = '\033[F\033[K'
 _BASH_RESET_CODE = '\033[0m'
 DIST_OUT_DIR = Path(
@@ -270,44 +273,158 @@
   return output
 
 
-def _stream_io_output(io_input: IO, io_output: IO, max_lines=None):
+def stream_io_output(
+    io_input: IO,
+    max_lines=None,
+    full_output_receiver: IO = None,
+    io_output: IO = None,
+    is_io_output_atty=None,
+):
   """Stream an IO output with max number of rolling lines to display if set.
 
   Args:
-    input: The file-like object to read the output from.
-    output: The file-like object to write the output to.
+    io_input: The file-like object to read the output from. Lines may be str
+      or bytes; bytes are decoded as UTF-8 with undecodable sequences replaced
+      so that malformed subprocess output cannot abort the stream. The object
+      is closed once the stream ends, even if reading it fails.
     max_lines: The maximum number of rolling lines to display. If None, all
       lines will be displayed.
+    full_output_receiver: Optional io to receive the full output.
+    io_output: The file-like object to write the output to. Defaults to the
+      original sys.stdout.
+    is_io_output_atty: Whether the io_output is a TTY. Defaults to whether
+      io_output supports colors. Rolling lines are only displayed on a TTY.
   """
-  print('\n----------------------------------------------------')
-  term_width, _ = get_terminal_size()
-  full_output = []
-  last_lines = None if not max_lines else deque(maxlen=max_lines)
-  last_number_of_lines = 0
-  for line in iter(io_input.readline, ''):
-    full_output.append(line)
-    line = line.rstrip()
-    if last_lines is None:
-      io_output.write(line)
-      io_output.write('\n')
-      io_output.flush()
-      continue
-    # Split the line if it's longer than the terminal width
-    wrapped_lines = (
-        [line]
-        if len(line) <= term_width
-        else [line[i : i + term_width] for i in range(0, len(line), term_width)]
-    )
-    last_lines.extend(wrapped_lines)
-    io_output.write(_BASH_CLEAR_PREVIOUS_LINE_CODE * last_number_of_lines)
-    io_output.write('\n'.join(last_lines))
-    io_output.write('\n')
-    io_output.flush()
-    last_number_of_lines = len(last_lines)
-  io_input.close()
-  io_output.write(_BASH_RESET_CODE)
-  io_output.flush()
-  print('----------------------------------------------------')
+  if io_output is None:
+    io_output = _original_sys_stdout
+  if is_io_output_atty is None:
+    is_io_output_atty = _has_colors(io_output)
+
+  def read_lines():
+    """Yields the decoded lines of io_input, copying them to the receiver."""
+    for raw_line in iter(io_input.readline, ''):
+      # Binary streams signal EOF with b'', which never equals the '' sentinel.
+      if not raw_line:
+        break
+      if isinstance(raw_line, bytes):
+        line = raw_line.decode('utf-8', errors='replace')
+      else:
+        line = raw_line
+      if full_output_receiver is not None:
+        full_output_receiver.write(line)
+      yield line
+
+  try:
+    if max_lines and is_io_output_atty:
+      _write_rolling_output(read_lines(), max_lines, io_output)
+    else:
+      for line in read_lines():
+        io_output.write(line)
+        io_output.flush()
+  finally:
+    io_input.close()
+
+
+def _write_rolling_output(lines, max_lines: int, io_output: IO):
+  """Displays lines in a rolling window of at most max_lines lines.
+
+  While the window is displayed, sys.stdout is replaced so that complete lines
+  printed by other code (possibly from other threads) are written above the
+  window instead of being erased by it. The window is erased and sys.stdout is
+  restored once the lines are exhausted, even if iterating them raises.
+
+  Args:
+    lines: An iterable of text lines to display.
+    max_lines: The maximum number of rolling lines to display.
+    io_output: The TTY file-like object to draw the window on.
+  """
+  term_width, _ = get_terminal_size()
+  last_lines = deque(maxlen=max_lines)
+  lock = threading.Lock()
+  is_rolling = True
+  # Number of lines the window currently occupies on screen. It starts at 0 so
+  # that the first draw does not erase output printed before the stream began.
+  drawn_line_count = 0
+
+  def reset_output():
+    nonlocal drawn_line_count
+    io_output.write(_BASH_CLEAR_PREVIOUS_LINE_CODE * drawn_line_count)
+    drawn_line_count = 0
+
+  def write_output(new_lines: list[str]):
+    nonlocal drawn_line_count
+    if not is_rolling:
+      return
+    last_lines.extend(new_lines)
+    window = ['========== Rolling subprocess output ==========']
+    window.extend(last_lines)
+    window.append('-----------------------------------------------')
+    io_output.write('\n'.join(window))
+    io_output.write('\n')
+    io_output.flush()
+    drawn_line_count = len(window)
+
+  original_stdout = sys.stdout
+
+  class SafeStdout:
+    """Stdout proxy that writes complete lines above the rolling window."""
+
+    def __init__(self):
+      self._buffers = []
+
+    def write(self, buf: str) -> int:
+      with lock:
+        self._buffers.append(buf)
+        # Partial writes (e.g. the text part of a print() call) are held
+        # until the line is complete so the window is redrawn once per line.
+        if buf.endswith('\n'):
+          self.write_pending()
+      return len(buf)
+
+    def write_pending(self) -> None:
+      """Writes buffered text above the window; the caller holds the lock."""
+      if not self._buffers:
+        return
+      reset_output()
+      original_stdout.write(''.join(self._buffers))
+      original_stdout.flush()
+      self._buffers.clear()
+      write_output([])
+
+    def flush(self) -> None:
+      original_stdout.flush()
+
+    def __getattr__(self, name):
+      # Delegate everything else (e.g. isatty, fileno, encoding) so code that
+      # inspects sys.stdout keeps working while it is replaced.
+      return getattr(original_stdout, name)
+
+  safe_stdout = SafeStdout()
+  sys.stdout = safe_stdout
+  try:
+    for line in lines:
+      line = line.rstrip()
+      # Split the line if it's longer than the terminal width
+      wrapped_lines = (
+          [line]
+          if len(line) <= term_width
+          else [
+              line[i : i + term_width]
+              for i in range(0, len(line), term_width)
+          ]
+      )
+      with lock:
+        reset_output()
+        write_output(wrapped_lines)
+  finally:
+    with lock:
+      reset_output()
+      is_rolling = False
+      # Emit any text that was printed without a trailing newline.
+      safe_stdout.write_pending()
+      io_output.write(_BASH_RESET_CODE)
+      io_output.flush()
+    sys.stdout = original_stdout
 
 
 def run_limited_output(
@@ -336,12 +453,18 @@
       start_new_session=start_new_session,
       text=True,
   ) as proc:
-    _stream_io_output(
-        proc.stdout, _original_sys_stdout, _BUILD_OUTPUT_ROLLING_LINES
+    full_output_receiver = io.StringIO()
+    stream_io_output(
+        proc.stdout,
+        DEFAULT_OUTPUT_ROLLING_LINES,
+        full_output_receiver,
+        _original_sys_stdout,
     )
     returncode = proc.wait()
     if returncode:
-      raise subprocess.CalledProcessError(returncode, cmd, full_output)
+      raise subprocess.CalledProcessError(
+          returncode, cmd, full_output_receiver.getvalue()
+      )
 
 
 def get_build_out_dir(*joinpaths) -> Path:
@@ -533,6 +656,11 @@
   return all((len(args.tests) == 1, args.tests[0][0] == ':'))
 
 
+def is_atty_terminal() -> bool:
+  """Check if the current process is running in a TTY."""
+  return getattr(_original_sys_stdout, 'isatty', lambda: False)()
+
+
 def _has_colors(stream):
   """Check the output stream is colorful.
 
diff --git a/atest/atest_utils_unittest.py b/atest/atest_utils_unittest.py
index 0b8d6a3..243b6b5 100755
--- a/atest/atest_utils_unittest.py
+++ b/atest/atest_utils_unittest.py
@@ -66,6 +66,7 @@
 ----------------------------
 """
 
+
 class StreamIoOutputTest(unittest.TestCase):
-  """Class that tests the _stream_io_output function."""
+  """Class that tests the stream_io_output function."""
 
@@ -76,9 +77,13 @@
     io_input.seek(0)
     io_output = StringIO()
 
-    atest_utils._stream_io_output(io_input, io_output, max_lines=None)
+    atest_utils.stream_io_output(
+        io_input, max_lines=None, io_output=io_output, is_io_output_atty=True
+    )
 
-    self.assertNotIn(atest_utils._BASH_CLEAR_PREVIOUS_LINE_CODE, io_output.getvalue())
+    self.assertNotIn(
+        atest_utils._BASH_CLEAR_PREVIOUS_LINE_CODE, io_output.getvalue()
+    )
 
   @mock.patch.object(atest_utils, 'get_terminal_size', return_value=(5, -1))
   def test_stream_io_output_wrap_long_lines(self, _):
@@ -88,7 +93,9 @@
     io_input.seek(0)
     io_output = StringIO()
 
-    atest_utils._stream_io_output(io_input, io_output, max_lines=10)
+    atest_utils.stream_io_output(
+        io_input, max_lines=10, io_output=io_output, is_io_output_atty=True
+    )
 
     self.assertIn('11111\n11111', io_output.getvalue())
 
@@ -100,10 +107,16 @@
     io_input.seek(0)
     io_output = StringIO()
 
-    atest_utils._stream_io_output(io_input, io_output, max_lines=2)
+    atest_utils.stream_io_output(
+        io_input, max_lines=2, io_output=io_output, is_io_output_atty=True
+    )
 
     self.assertIn(
-        atest_utils._BASH_CLEAR_PREVIOUS_LINE_CODE * 2 + '2\n3\n',
+        '2\n3\n',
+        io_output.getvalue(),
+    )
+    self.assertNotIn(
+        '1\n2\n3\n',
         io_output.getvalue(),
     )
 
@@ -115,10 +128,12 @@
     io_input.seek(0)
     io_output = StringIO()
 
-    atest_utils._stream_io_output(io_input, io_output, max_lines=4)
+    atest_utils.stream_io_output(
+        io_input, max_lines=4, io_output=io_output, is_io_output_atty=True
+    )
 
     self.assertIn(
-        atest_utils._BASH_CLEAR_PREVIOUS_LINE_CODE * 2 + '1\n2\n3\n',
+        '1\n2\n3\n',
         io_output.getvalue(),
     )
 
diff --git a/atest/rollout_control.py b/atest/rollout_control.py
index df146f5..9dd9f01 100644
--- a/atest/rollout_control.py
+++ b/atest/rollout_control.py
@@ -18,12 +18,44 @@
 import functools
 import getpass
 import hashlib
+import importlib.resources
 import logging
 import os
 
 from atest import atest_enum
 from atest.metrics import metrics
 
 
+# OWNERS files packaged with atest (see the adte-owners-files filegroup) whose
+# members are the default owners of every rollout controlled feature.
+_OWNERS_FILE_NAMES = ('OWNERS', 'OWNERS_ADTE_TEAM')
+
+
+@functools.cache
+def _get_project_owners() -> list[str]:
+  """Returns the usernames listed in the OWNERS files packaged with atest.
+
+  Only `<username>@google.com` entries are collected; trailing `#` comments
+  are ignored. Files that cannot be read are logged at debug level and
+  skipped, so an empty list is returned when none is available.
+  """
+  owners = []
+  for file_name in _OWNERS_FILE_NAMES:
+    try:
+      with importlib.resources.as_file(
+          importlib.resources.files('atest').joinpath(file_name)
+      ) as owners_file_path:
+        content = owners_file_path.read_text(encoding='utf-8')
+    except (ModuleNotFoundError, OSError) as e:
+      logging.debug('Unable to read %s: %s', file_name, e)
+      continue
+    for line in content.splitlines():
+      # Drop trailing comments and surrounding whitespace.
+      entry = line.split('#', 1)[0].strip()
+      if entry.endswith('@google.com'):
+        owners.append(entry.split('@', 1)[0])
+  return owners
+
+
 class RolloutControlledFeature:
   """Base class for Atest features under rollout control."""
 
@@ -33,6 +65,7 @@
       rollout_percentage: float,
       env_control_flag: str,
       feature_id: int = None,
+      owners: list[str] | None = None,
   ):
     """Initializes the object.
 
@@ -45,6 +78,8 @@
         disable.
       feature_id: The ID of the feature that is controlled by rollout control
         for metric collection purpose. Must be a positive integer.
+      owners: Usernames that get the feature regardless of rollout_percentage.
+        If None, the owners from the OWNERS files packaged with atest are used.
     """
     if rollout_percentage < 0 or rollout_percentage > 100:
       raise ValueError(
@@ -55,10 +90,13 @@
       raise ValueError(
           'Feature ID must be a positive integer. Got %s instead.' % feature_id
       )
+    if owners is None:
+      owners = _get_project_owners()
     self._name = name
     self._rollout_percentage = rollout_percentage
     self._env_control_flag = env_control_flag
     self._feature_id = feature_id
+    self._owners = owners
 
   def _check_env_control_flag(self) -> bool | None:
     """Checks the environment variable to override the feature enablement.
@@ -98,22 +136,36 @@
       )
       return override_flag_value
 
+    if self._rollout_percentage == 100:
+      return True
+
     if username is None:
-      username = getpass.getuser()
+      try:
+        username = getpass.getuser()
+      except (KeyError, OSError):
+        # getuser() raises when no login name can be found; treat it as an
+        # undetermined username below.
+        username = ''
 
     if not username:
-      logging.error(
+      logging.debug(
           'Unable to determine the username. Disabling the feature %s.',
           self._name,
       )
       return False
 
-    hash_object = hashlib.sha256()
-    hash_object.update((username + ' ' + self._name).encode('utf-8'))
+    # Owners get the feature regardless of the rollout percentage.
+    is_enabled = username in self._owners
 
-    is_enabled = (
-        int(hash_object.hexdigest(), 16) % 100 < self._rollout_percentage
-    )
+    if not is_enabled:
+      if self._rollout_percentage == 0:
+        return False
+
+      hash_object = hashlib.sha256()
+      hash_object.update((username + ' ' + self._name).encode('utf-8'))
+      is_enabled = (
+          int(hash_object.hexdigest(), 16) % 100 < self._rollout_percentage
+      )
 
     logging.debug(
         'Feature %s is %s for user %s.',
@@ -122,7 +174,7 @@
         username,
     )
 
-    if self._feature_id and 0 < self._rollout_percentage < 100:
+    if self._feature_id:
       metrics.LocalDetectEvent(
           detect_type=atest_enum.DetectType.ROLLOUT_CONTROLLED_FEATURE_ID,
           result=self._feature_id if is_enabled else -self._feature_id,
@@ -137,3 +189,10 @@
     env_control_flag='DISABLE_BAZEL_MODE_BY_DEFAULT',
     feature_id=1,
 )
+
+rolling_tf_subprocess_output = RolloutControlledFeature(
+    name='rolling_tf_subprocess_output',
+    rollout_percentage=0,
+    env_control_flag='ROLLING_TF_SUBPROCESS_OUTPUT',
+    feature_id=2,
+)
diff --git a/atest/rollout_control_unittest.py b/atest/rollout_control_unittest.py
index 05dc9b0..ca000d0 100644
--- a/atest/rollout_control_unittest.py
+++ b/atest/rollout_control_unittest.py
@@ -56,7 +56,7 @@
   def test_is_enabled_username_undetermined_returns_false(self):
     sut = rollout_control.RolloutControlledFeature(
         name='test_feature',
-        rollout_percentage=100,
+        rollout_percentage=99,
         env_control_flag='TEST_FEATURE',
     )
 
@@ -91,3 +91,14 @@
 
     with mock.patch.dict('os.environ', {'TEST_FEATURE': 'false'}):
       self.assertFalse(sut.is_enabled())
+
+  def test_is_enabled_is_owner_returns_true(self):
+    sut = rollout_control.RolloutControlledFeature(
+        name='test_feature',
+        rollout_percentage=0,
+        env_control_flag='TEST_FEATURE',
+        owners=['owner_name'],
+    )
+
+    self.assertFalse(sut.is_enabled('name'))
+    self.assertTrue(sut.is_enabled('owner_name'))
diff --git a/atest/test_runners/atest_tf_test_runner.py b/atest/test_runners/atest_tf_test_runner.py
index f227c4a..87e1046 100644
--- a/atest/test_runners/atest_tf_test_runner.py
+++ b/atest/test_runners/atest_tf_test_runner.py
@@ -31,6 +31,7 @@
 import select
 import shutil
 import socket
+import threading
 import time
 from typing import Any, Dict, List, Set, Tuple
 
@@ -40,6 +41,7 @@
 from atest import constants
 from atest import module_info
 from atest import result_reporter
+from atest import rollout_control
 from atest.atest_enum import DetectType, ExitCode
 from atest.coverage import coverage
 from atest.logstorage import logstorage_utils
@@ -399,12 +401,30 @@
     run_cmds = self.generate_run_commands(
         test_infos, extra_args, server.getsockname()[1]
     )
+    is_rolling_output = (
+        rollout_control.rolling_tf_subprocess_output.is_enabled()
+        and not extra_args.get(constants.VERBOSE, False)
+        and atest_utils.is_atty_terminal()
+    )
+
     logging.debug('Running test: %s', run_cmds[0])
     subproc = self.run(
         run_cmds[0],
         output_to_stdout=extra_args.get(constants.VERBOSE, False),
         env_vars=self.generate_env_vars(extra_args),
+        rolling_output_lines=is_rolling_output,
     )
+
+    if is_rolling_output:
+      # Daemon thread so that atest can still exit if the pipe never reaches
+      # EOF, e.g. when a process spawned by the command keeps it open.
+      threading.Thread(
+          target=atest_utils.stream_io_output,
+          args=(subproc.stdout, atest_utils.DEFAULT_OUTPUT_ROLLING_LINES),
+          name='tf-rolling-output',
+          daemon=True,
+      ).start()
+
     self.handle_subprocess(
         subproc,
         partial(self._start_monitor, server, subproc, reporter, extra_args),
diff --git a/atest/test_runners/test_runner_base.py b/atest/test_runners/test_runner_base.py
index 499cb1f..927960a 100644
--- a/atest/test_runners/test_runner_base.py
+++ b/atest/test_runners/test_runner_base.py
@@ -79,6 +79,7 @@
     """Init stuff for base class."""
     self.results_dir = results_dir
     self.test_log_file = None
+    self._subprocess_stdout = None
     if not self.NAME:
       raise atest_error.NoTestRunnerName('Class var NAME is not defined.')
     if not self.EXECUTABLE:
@@ -116,7 +117,13 @@
     """Checks whether this runner requires device update."""
     return False
 
-  def run(self, cmd, output_to_stdout=False, env_vars=None):
+  def run(
+      self,
+      cmd,
+      output_to_stdout=False,
+      env_vars=None,
+      rolling_output_lines=False,
+  ):
     """Shell out and execute command.
 
     Args:
@@ -127,20 +134,32 @@
         reporter to print the test results. Set to True to see the output of
         the cmd. This would be appropriate for verbose runs.
       env_vars: Environment variables passed to the subprocess.
+      rolling_output_lines: If True and output_to_stdout is False, the
+        subprocess output is sent to a pipe (proc.stdout) that the caller is
+        expected to consume, e.g. to display it in rolling lines.
     """
-    if not output_to_stdout:
-      self.test_log_file = tempfile.NamedTemporaryFile(
-          mode='w', dir=self.results_dir, delete=True
-      )
     logging.debug('Executing command: %s', cmd)
-    return subprocess.Popen(
+    if rolling_output_lines:
+      # The caller consumes the pipe, e.g. to display rolling output.
+      stdout = None if output_to_stdout else subprocess.PIPE
+    else:
+      if not output_to_stdout:
+        self.test_log_file = tempfile.NamedTemporaryFile(
+            mode='w', dir=self.results_dir, delete=True
+        )
+      stdout = self.test_log_file
+    proc = subprocess.Popen(
         cmd,
         start_new_session=True,
         shell=True,
         stderr=subprocess.STDOUT,
-        stdout=self.test_log_file,
+        stdout=stdout,
         env=env_vars,
     )
+    # proc.stdout is None unless stdout is a PIPE, so this also clears a pipe
+    # left over from a previous run before handle_subprocess reads it.
+    self._subprocess_stdout = proc.stdout
+    return proc
 
   # pylint: disable=broad-except
   def handle_subprocess(self, subproc, func):
@@ -165,11 +184,26 @@
         # we have to save it above.
         logging.debug('Subproc already terminated, skipping')
       finally:
-        if self.test_log_file:
+        full_output = ''
+        subprocess_stdout = self._subprocess_stdout
+        if subprocess_stdout is not None:
+          # A thread streaming the rolling output may have consumed and closed
+          # the pipe already, in which case there is nothing left to read.
+          if not subprocess_stdout.closed:
+            try:
+              full_output = subprocess_stdout.read()
+            except (OSError, ValueError):
+              # The pipe may get closed concurrently by that thread.
+              logging.debug('Unable to read subprocess output.', exc_info=1)
+          if isinstance(full_output, bytes):
+            # run() opens the pipe in binary mode.
+            full_output = full_output.decode('utf-8', errors='replace')
+        elif self.test_log_file:
           with open(self.test_log_file.name, 'r') as f:
-            intro_msg = 'Unexpected Issue. Raw Output:'
-            print(atest_utils.mark_red(intro_msg))
-            print(f.read())
+            full_output = f.read()
+        if full_output:
+          print(atest_utils.mark_red('Unexpected Issue. Raw Output:'))
+          print(full_output)
         # Ignore socket.recv() raising due to ctrl-c
         if not error.args or error.args[0] != errno.EINTR:
           raise error