blob: 2b2cd7b6c5204b7e957d6b14804236c44c6740e5 [file] [log] [blame]
"""test_check_labels.py"""
from typing import Any, List
from unittest import main, mock, TestCase
from check_labels import (
add_label_err_comment,
delete_all_label_err_comments,
main as check_labels_main,
)
from github_utils import GitHubComment
from label_utils import BOT_AUTHORS, LABEL_ERR_MSG_TITLE
from test_trymerge import mock_gh_get_info, mocked_gh_graphql
from trymerge import GitHubPR
def mock_parse_args() -> object:
    """Return a stand-in for the parsed CLI arguments of ``check_labels``.

    Only ``pr_num`` is set (to 76123); a fresh object is built on every call.
    """

    class Object:
        pr_num: int

    args = Object()
    args.pr_num = 76123
    return args
def mock_add_label_err_comment(pr: "GitHubPR") -> None:
    """No-op substitute for ``check_labels.add_label_err_comment``.

    Used as a ``side_effect`` so the patched mock records the call while
    nothing is actually posted.
    """
    del pr  # deliberately unused
    return None
def mock_delete_all_label_err_comments(pr: "GitHubPR") -> None:
    """No-op substitute for ``check_labels.delete_all_label_err_comments``.

    Used as a ``side_effect`` so the patched mock records the call while
    nothing is actually deleted.
    """
    del pr  # deliberately unused
    return None
def mock_get_comments() -> List[GitHubComment]:
    """Return two fake PR comments for the label-error comment tests.

    Index 0 is an ordinary comment (database_id=1, empty author).
    Index 1 carries the label-error title, authored by ``BOT_AUTHORS[1]``
    (database_id=2).
    """

    def _comment(body_text: str, author_login: str, database_id: int) -> GitHubComment:
        # Fields the label checks do not look at are left blank.
        return GitHubComment(
            body_text=body_text,
            created_at="",
            author_login=author_login,
            author_association="",
            editor_login=None,
            database_id=database_id,
            url="",
        )

    # The title is prefixed with " #" and stripped of backticks.
    # NOTE(review): presumably mirrors how the posted comment is rendered /
    # normalized by check_labels — confirm.
    label_err_body = " #" + LABEL_ERR_MSG_TITLE.replace("`", "")
    return [
        _comment("mock_body_text", "", 1),  # Case 1 - a non label err comment
        _comment(label_err_body, BOT_AUTHORS[1], 2),  # Case 2 - a label err comment
    ]
class TestCheckLabels(TestCase):
    """Tests for the label-error comment helpers and ``check_labels.main``.

    Mock arguments are injected bottom-up: the decorator closest to the
    ``def`` supplies the first mock parameter after ``self``. Parameters use a
    ``mocked_`` prefix so they do not shadow the module-level helper functions.
    """

    @mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
    @mock.patch("trymerge.GitHubPR.get_comments", return_value=[mock_get_comments()[0]])
    @mock.patch("check_labels.gh_post_pr_comment")
    def test_correctly_add_label_err_comment(
        self, mocked_post_comment: Any, mocked_get_comments: Any, mocked_graphql: Any
    ) -> None:
        "Test add label err comment when similar comments don't exist."
        pr = GitHubPR("pytorch", "pytorch", 75095)
        add_label_err_comment(pr)
        mocked_post_comment.assert_called_once()

    @mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
    @mock.patch("trymerge.GitHubPR.get_comments", return_value=[mock_get_comments()[1]])
    @mock.patch("check_labels.gh_post_pr_comment")
    def test_not_add_label_err_comment(
        self, mocked_post_comment: Any, mocked_get_comments: Any, mocked_graphql: Any
    ) -> None:
        "Test not add label err comment when similar comments exist."
        pr = GitHubPR("pytorch", "pytorch", 75095)
        add_label_err_comment(pr)
        mocked_post_comment.assert_not_called()

    @mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
    @mock.patch("trymerge.GitHubPR.get_comments", return_value=mock_get_comments())
    @mock.patch("check_labels.gh_delete_comment")
    def test_correctly_delete_all_label_err_comments(
        self, mocked_delete_comment: Any, mocked_get_comments: Any, mocked_graphql: Any
    ) -> None:
        "Test only delete label err comment."
        pr = GitHubPR("pytorch", "pytorch", 75095)
        delete_all_label_err_comments(pr)
        # Only the comment with database_id=2 (the label-error one) is removed.
        mocked_delete_comment.assert_called_once_with("pytorch", "pytorch", 2)

    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
    @mock.patch("check_labels.parse_args", return_value=mock_parse_args())
    @mock.patch("check_labels.has_required_labels", return_value=False)
    @mock.patch(
        "check_labels.delete_all_label_err_comments",
        side_effect=mock_delete_all_label_err_comments,
    )
    @mock.patch(
        "check_labels.add_label_err_comment", side_effect=mock_add_label_err_comment
    )
    def test_ci_comments_and_exit0_without_required_labels(
        self,
        mocked_add_label_err_comment: Any,
        mocked_delete_all_label_err_comments: Any,
        mocked_has_required_labels: Any,
        mocked_parse_args: Any,
        mocked_gh_get_info: Any,
    ) -> None:
        "Test main comments and exits with status 0 when required labels are missing."
        with self.assertRaises(SystemExit) as sys_exit:
            check_labels_main()
        # Compare the exit status itself rather than its string form.
        self.assertEqual(sys_exit.exception.code, 0)
        mocked_add_label_err_comment.assert_called_once()
        mocked_delete_all_label_err_comments.assert_not_called()

    @mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
    @mock.patch("check_labels.parse_args", return_value=mock_parse_args())
    @mock.patch("check_labels.has_required_labels", return_value=True)
    @mock.patch(
        "check_labels.delete_all_label_err_comments",
        side_effect=mock_delete_all_label_err_comments,
    )
    @mock.patch(
        "check_labels.add_label_err_comment", side_effect=mock_add_label_err_comment
    )
    def test_ci_exit0_with_required_labels(
        self,
        mocked_add_label_err_comment: Any,
        mocked_delete_all_label_err_comments: Any,
        mocked_has_required_labels: Any,
        mocked_parse_args: Any,
        mocked_gh_get_info: Any,
    ) -> None:
        "Test main clears old label err comments and exits with status 0 when labels are present."
        with self.assertRaises(SystemExit) as sys_exit:
            check_labels_main()
        # Compare the exit status itself rather than its string form.
        self.assertEqual(sys_exit.exception.code, 0)
        mocked_add_label_err_comment.assert_not_called()
        mocked_delete_all_label_err_comments.assert_called_once()
if __name__ == "__main__":
    # Discover and run the TestCase classes in this module via unittest's CLI.
    main(module="__main__")