blob: 59e94c6afdf72fb43b83bae591f7065f86df6b9f [file] [log] [blame]
# Copyright 2024, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes used to handle banners."""
from __future__ import annotations
from datetime import date
import json
import logging
from pathlib import Path
from typing import Any, Callable
from atest import atest_utils
from atest import constants
class BannerHistory:
  """A persisted record of when banners were last shown.

  The history is a JSON object stored in ``banner.json`` inside the config
  directory given to ``create``.
  """

  # Key in the history JSON object that holds the last date banners were shown.
  _LAST_BANNER_PROMPT_DATE = 'last_banner_prompt_date'

  @staticmethod
  def create(config_dir: Path) -> BannerHistory:
    """Load the banner history kept in ``config_dir``.

    Creates ``config_dir`` (including parents) and an empty ``banner.json``
    when they do not exist yet. A history file that is empty is treated as an
    empty history. A file that is not valid JSON, or that does not hold a JSON
    object, is reported via ``atest_utils.print_and_log_error`` and also
    treated as an empty history.

    Args:
      config_dir: Directory that holds (or will hold) ``banner.json``.

    Returns:
      A BannerHistory backed by ``config_dir/banner.json``.
    """
    config_dir.mkdir(parents=True, exist_ok=True)
    history_file = config_dir.joinpath('banner.json')
    if not history_file.exists():
      history_file.touch()
      return BannerHistory(history_file, {})
    return BannerHistory(history_file, BannerHistory._load(history_file))

  @staticmethod
  def _load(history_file: Path) -> dict:
    """Parse ``history_file`` into a dict, falling back to ``{}`` on bad data."""
    content = history_file.read_text()
    # `create` touches a fresh, empty file; that means "no history yet", not
    # a corrupt file, so don't report it as a format error.
    if not content.strip():
      return {}
    try:
      history = json.loads(content)
    except json.JSONDecodeError as e:
      atest_utils.print_and_log_error(
          'Banner history json file is in a bad format: %s', e
      )
      return {}
    if not isinstance(history, dict):
      # Valid JSON but not an object (e.g. a list); `.get` would fail later.
      atest_utils.print_and_log_error(
          'Banner history json file is in a bad format: %s',
          'expected a JSON object but got %s' % type(history).__name__,
      )
      return {}
    return history

  def __init__(self, history_file: Path, history: dict):
    """Wrap an already-loaded history dict.

    Args:
      history_file: File that ``set_last_banner_prompt_date`` writes to.
      history: The current history contents; mutated in place on updates.
    """
    self._history_file = history_file
    self._history = history

  def get_last_banner_prompt_date(self) -> str:
    """Return the last date banners were shown, or '' if none is recorded."""
    return self._history.get(BannerHistory._LAST_BANNER_PROMPT_DATE, '')

  def set_last_banner_prompt_date(self, date: str):
    """Record ``date`` as the last date banners were shown.

    The whole history dict is serialized to JSON and rewrites the history
    file. (``date`` shadows the module-level ``datetime.date`` import only
    inside this method; the name is kept for keyword-argument callers.)
    """
    self._history[BannerHistory._LAST_BANNER_PROMPT_DATE] = date
    self._history_file.write_text(json.dumps(self._history))
class BannerPrinter:
  """Collects banner messages and prints them at most once per day."""

  @staticmethod
  def create() -> BannerPrinter:
    """Create a printer whose history lives in atest's config folder."""
    return BannerPrinter(atest_utils.get_config_folder())

  def __init__(self, config_dir: Path):
    """Create an empty printer.

    Args:
      config_dir: Directory passed to ``BannerHistory.create`` when printing.
    """
    self._messages = []
    self._config_dir = config_dir

  def register(self, message: str):
    """Queue ``message`` to be shown by a later ``print`` call."""
    self._messages.append(message)

  def print(
      self,
      print_func: Callable[[str], Any] | None = None,
      date_supplier: Callable[[], str] | None = None,
  ):
    """Print all registered banners unless banners were already shown today.

    Nothing happens (and no history file is touched) when no message has been
    registered. Otherwise every registered message is printed and
    ``date_supplier()`` is stored as the last date banners were shown.

    Args:
      print_func: Called once per registered message. Defaults to printing
        the message in magenta via ``atest_utils.colorful_print``.
      date_supplier: Returns "today" as a string, compared against the stored
        history. Defaults to ``str(date.today())``.
    """
    if not self._messages:
      return
    if print_func is None:
      print_func = lambda m: atest_utils.colorful_print(m, constants.MAGENTA)
    if date_supplier is None:
      date_supplier = lambda: str(date.today())

    today = date_supplier()
    history = BannerHistory.create(self._config_dir)
    if history.get_last_banner_prompt_date() == today:
      return
    for message in self._messages:
      print_func(message)
    # Recorded only after all messages were printed successfully.
    history.set_last_banner_prompt_date(today)