blob: 5af032f540a30c7bd0af319d38f9eca2e9006fb8 [file] [log] [blame]
#
# Copyright (C) 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for ndk.workqueue."""
import multiprocessing
import os
from queue import Queue
import signal
import sys
from threading import Event
import time
from types import FrameType
from typing import Optional
import unittest

from ndk.workqueue import DummyWorkQueue, TaskError, Worker, WorkQueue
def put(_worker: Worker, i: int) -> int:
    """Returns the passed argument.

    Args:
        _worker: The worker running the task (unused).
        i: The value to return.

    Returns:
        i, unchanged.
    """
    return i
class Functor:
    """Callable object that yields the integer it was constructed with.

    Used to check that the work queue accepts callable objects as tasks, not
    only plain functions.
    """

    def __init__(self, value: int) -> None:
        self.value = value

    def __call__(self, _worker: Worker) -> int:
        """Ignores the worker and returns the stored value."""
        stored = self.value
        return stored
def block_on_event(_worker: Worker, event: Event) -> None:
    """Task that does not return until the given event has been set.

    Args:
        _worker: The worker running the task (unused).
        event: The event to wait on.
    """
    event.wait()
def update_status(worker: Worker, ready_event: Event, finish_event: Event,
                  new_status: str) -> None:
    """Publishes a new worker status, then blocks until released.

    Args:
        worker: Worker whose ``status`` attribute is overwritten.
        ready_event: Set right after the status is written, so the caller
            knows it can inspect the new status.
        finish_event: Waited on before the task returns.
        new_status: Status string to assign.
    """
    worker.status = new_status
    # Order matters: signal readiness only after the status is visible, and
    # keep the task running until the caller has checked it.
    ready_event.set()
    finish_event.wait()
def sigterm_handler(_signum: int, _trace: Optional[FrameType]) -> None:
    """Signal handler that turns the signal into a SystemExit exception.

    Raising SystemExit, instead of letting the signal kill the process, lets
    ``finally`` blocks in the interrupted code run.

    Args:
        _signum: The signal number (unused).
        _trace: The interrupted stack frame. Python may pass None here, so the
            annotation is Optional to match the handler type signal.signal
            expects.

    Raises:
        SystemExit: Always.
    """
    sys.exit()
def sleep_until_sigterm(pid_queue: Queue) -> None:
    """Sleeps indefinitely; when interrupted, reports this process's PID.

    Installs sigterm_handler for SIGTERM so that the signal raises SystemExit.
    That unwinds through the ``finally`` clause below, which puts this
    process's PID on pid_queue before the exception propagates.

    Args:
        pid_queue: Queue that receives this process's PID on exit.
    """
    signal.signal(signal.SIGTERM, sigterm_handler)
    try:
        # signal.pause() does not exist on Windows, so poll with long sleeps.
        while True:
            time.sleep(60)
    finally:
        own_pid = os.getpid()
        pid_queue.put(own_pid)
def spawn_child(_worker: Worker, pid_queue: Queue) -> None:
    """Forks, then has both resulting processes report their PIDs and sleep.

    Used to check that terminate() also reaches descendants of worker
    processes. Each of the two processes puts its PID on pid_queue and then
    sleeps until signalled. When SIGTERM arrives, each puts its PID on
    pid_queue a second time, so the caller can confirm both were signalled.

    Args:
        _worker: The worker running the task (unused).
        pid_queue: Queue that receives each process's PID twice.
    """
    os.fork()
    # From here on, this code runs in both the parent and the forked child.
    current_pid = os.getpid()
    pid_queue.put(current_pid)
    sleep_until_sigterm(pid_queue)
def raise_error(_worker: Optional[Worker] = None) -> None:
    """Task that always fails, used to check that errors reach the caller.

    Work queue tasks are called with the running worker as their first
    argument (compare put and Functor.__call__), so this function must accept
    it. Without the parameter, the call itself would fail with a TypeError and
    the RuntimeError below would never be raised. The default value keeps
    direct ``raise_error()`` calls working.

    Args:
        _worker: The worker running the task (unused).

    Raises:
        RuntimeError: Always, with the message 'Error in child'.
    """
    raise RuntimeError('Error in child')
class WorkQueueTest(unittest.TestCase):
    """Tests for WorkQueue.

    Every test terminates and joins its WorkQueue in a ``finally`` block, so a
    failed assertion cannot leave worker processes running. Each
    multiprocessing.Manager is used as a context manager, so its server
    process is shut down when the test finishes.
    """

    def _assert_results(self, workqueue: WorkQueue, *expected: int) -> None:
        """Fetches len(expected) results and checks they match, in any order."""
        remaining = list(expected)
        while remaining:
            result = workqueue.get_result()
            self.assertIn(result, remaining)
            remaining.remove(result)

    def test_put_func(self) -> None:
        """Test that we can pass a function to the queue and get results."""
        workqueue = WorkQueue(4)
        try:
            workqueue.add_task(put, 1)
            workqueue.add_task(put, 2)
            self._assert_results(workqueue, 1, 2)
        finally:
            workqueue.terminate()
            workqueue.join()

    def test_put_functor(self) -> None:
        """Test that we can pass a functor to the queue and get results."""
        workqueue = WorkQueue(4)
        try:
            workqueue.add_task(Functor(1))
            workqueue.add_task(Functor(2))
            self._assert_results(workqueue, 1, 2)
        finally:
            workqueue.terminate()
            workqueue.join()

    def test_finished(self) -> None:
        """Tests that finished() returns the correct result."""
        workqueue = WorkQueue(4)
        with multiprocessing.Manager() as manager:
            try:
                self.assertTrue(workqueue.finished())
                event = manager.Event()
                workqueue.add_task(block_on_event, event)
                self.assertFalse(workqueue.finished())
                event.set()
                workqueue.get_result()
                self.assertTrue(workqueue.finished())
            finally:
                # Stop the workers before the manager (which owns the event
                # a worker may be blocked on) is shut down.
                workqueue.terminate()
                workqueue.join()
        self.assertTrue(workqueue.finished())

    def test_status(self) -> None:
        """Tests that worker status can be accessed from the parent."""
        workqueue = WorkQueue(1)
        with multiprocessing.Manager() as manager:
            try:
                ready_event = manager.Event()
                finish_event = manager.Event()
                self.assertEqual(Worker.IDLE_STATUS,
                                 workqueue.workers[0].status)
                workqueue.add_task(update_status, ready_event, finish_event,
                                   'working')
                ready_event.wait()
                self.assertEqual('working', workqueue.workers[0].status)
                finish_event.set()
                workqueue.get_result()
                self.assertEqual(Worker.IDLE_STATUS,
                                 workqueue.workers[0].status)
            finally:
                workqueue.terminate()
                workqueue.join()

    def test_subprocesses_killed(self) -> None:
        """Tests that terminate() kills descendents of worker processes."""
        workqueue = WorkQueue(4)
        with multiprocessing.Manager() as manager:
            queue = manager.Queue()
            try:
                workqueue.add_task(spawn_child, queue)
                # spawn_child forks, so both the worker and its child report.
                spawned_pids = [queue.get(), queue.get()]
            finally:
                workqueue.terminate()
                workqueue.join()
            # Each process reports its PID a second time once it is signalled.
            # The queue is read before the manager shuts down.
            killed_pids = [queue.get(), queue.get()]
            self.assertCountEqual(spawned_pids, killed_pids)

    def test_subprocess_exception(self) -> None:
        """Tests that exceptions raised in the task are re-raised."""
        workqueue = WorkQueue()
        try:
            workqueue.add_task(raise_error)
            with self.assertRaises(TaskError):
                workqueue.get_result()
        finally:
            workqueue.terminate()
            workqueue.join()
class DummyWorkQueueTest(unittest.TestCase):
    """Tests for DummyWorkQueue.

    Every test calls terminate() and join() in a ``finally`` block, matching
    WorkQueueTest, so a failed assertion does not skip cleanup.
    """

    def _assert_results(self, workqueue: DummyWorkQueue,
                        *expected: int) -> None:
        """Fetches len(expected) results and checks they match, in any order."""
        remaining = list(expected)
        while remaining:
            result = workqueue.get_result()
            self.assertIn(result, remaining)
            remaining.remove(result)

    def test_put_func(self) -> None:
        """Test that we can pass a function to the queue and get results."""
        workqueue = DummyWorkQueue()
        try:
            workqueue.add_task(put, 1)
            workqueue.add_task(put, 2)
            self._assert_results(workqueue, 1, 2)
        finally:
            workqueue.terminate()
            workqueue.join()

    def test_put_functor(self) -> None:
        """Test that we can pass a functor to the queue and get results."""
        workqueue = DummyWorkQueue()
        try:
            workqueue.add_task(Functor(1))
            workqueue.add_task(Functor(2))
            self._assert_results(workqueue, 1, 2)
        finally:
            workqueue.terminate()
            workqueue.join()

    def test_finished(self) -> None:
        """Tests that finished() returns the correct result."""
        # This previously constructed a WorkQueue, so DummyWorkQueue.finished()
        # was never exercised by this class.
        workqueue = DummyWorkQueue()
        try:
            self.assertTrue(workqueue.finished())
            workqueue.add_task(put, 1)
            self.assertFalse(workqueue.finished())
            workqueue.get_result()
            self.assertTrue(workqueue.finished())
        finally:
            workqueue.terminate()
            workqueue.join()
        self.assertTrue(workqueue.finished())

    def test_subprocess_exception(self) -> None:
        """Tests that exceptions raised in the task are re-raised."""
        workqueue = DummyWorkQueue()
        try:
            workqueue.add_task(raise_error)
            with self.assertRaises(TaskError):
                workqueue.get_result()
        finally:
            workqueue.terminate()
            workqueue.join()