| # Copyright (C) 2020 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
#     http://www.apache.org/licenses/LICENSE-2.0
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Test low-level channel stuff""" |
| |
| # pylint: disable=missing-docstring,unused-argument,redefined-outer-name |
| |
| from threading import Thread |
| import time |
| from pickle import dumps, HIGHEST_PROTOCOL |
| from itertools import repeat |
| import logging |
| from array import array |
| |
| import pytest |
| |
| from .channel import ( |
| Channel, |
| MAX_MSG_SIZE, |
| MessageSendContext, |
| PeerDiedException, |
| ) |
| from .reduction import ( |
| MessageTooLargeError, |
| payload_is_in_fd_marker, |
| ) |
| from .shm import ShmSharedMemorySendContextMixin |
| from .util_test import isolated_resman # pylint: disable=unused-import |
| from .util import fdid, unix_pipe |
| |
# Logger for this test module, named after the module itself.
log = logging.getLogger(name=__name__)
| |
def test_make_channel():
    """Channel.make_pair() yields two distinct channel endpoints."""
    channel1, channel2 = Channel.make_pair()
    # The other tests send on one end and receive on the other, so the two
    # endpoints must be separate objects.
    assert channel1 is not channel2
| |
def test_recv_eof():
    """recv() raises PeerDiedException once the other endpoint is closed."""
    receiver, sender = Channel.make_pair()
    sender.close()
    with pytest.raises(PeerDiedException):
        receiver.recv()
| |
def test_recv_nonblock():
    """A non-blocking recv() with nothing pending raises BlockingIOError."""
    # Keep the peer bound to a name so it stays alive for the whole test;
    # only the "no data yet" condition is under test here, not EOF.
    receiver, _sender = Channel.make_pair()
    with pytest.raises(BlockingIOError):
        receiver.recv(block=False)
| |
def test_recv_block():
    """A blocking recv() waits until the peer sends, then returns the value."""
    channel1, channel2 = Channel.make_pair()
    # A list (rather than a None-initialised variable) distinguishes
    # "nothing received yet" from "received a falsy value".
    received = []

    def _do_receive():
        received.append(channel1.recv())

    recv_thread = Thread(target=_do_receive, daemon=True)
    recv_thread.start()
    # Give the receiver time to run; it must still be parked in recv()
    # because nothing has been sent yet.
    time.sleep(0.1)
    assert recv_thread.is_alive()
    assert not received
    channel2.send("blarg")
    # Bound the wait so a lost message fails the test instead of hanging it.
    recv_thread.join(timeout=10)
    assert not recv_thread.is_alive()
    assert received == ["blarg"]
| |
def _make_array(nr_bytes):
    """Return an ``array("b")`` holding NR_BYTES elements, each equal to 1.

    Typecode "b" stores one byte per element, so the element count is also
    the size of the array's raw data in bytes.
    """
    ones = repeat(1, nr_bytes)
    return array("b", ones)
| |
def _make_huge_obj():
    """Return a 1 MiB byte array used as a payload too large for one message."""
    one_mib = 1 << 20
    return _make_array(one_mib)
| |
class TestSendContext(ShmSharedMemorySendContextMixin, MessageSendContext):
    """Send context combining MessageSendContext with the shm mixin.

    The large-object tests pass this to send() for payloads whose pickle
    exceeds MAX_MSG_SIZE.
    """
    # The "Test" prefix matches pytest's default class-collection pattern;
    # this is a helper, not a test class, so opt out of collection (which
    # would otherwise warn if a base class defines __init__).
    __test__ = False
| |
def test_send_large_object(isolated_resman):
    """Objects whose pickle exceeds MAX_MSG_SIZE round-trip via TestSendContext."""
    receiver, sender = Channel.make_pair()
    huge_obj = _make_huge_obj()
    # Sanity check: the payload really is too big for a single message.
    # The temporary pickle is discarded right away to keep peak memory low.
    assert len(dumps(huge_obj, HIGHEST_PROTOCOL)) > MAX_MSG_SIZE
    sender.send(huge_obj, TestSendContext)
    assert receiver.recv() == huge_obj
| |
def test_send_large_object_marker(isolated_resman):
    """Sending payload_is_in_fd_marker itself returns that very object."""
    receiver, sender = Channel.make_pair()
    sender.send(payload_is_in_fd_marker, TestSendContext)
    received = receiver.recv()
    assert received is payload_is_in_fd_marker
| |
def test_send_large_object_multiple_fds(isolated_resman):
    """File objects inside a large payload arrive as new descriptors.

    Each received descriptor must have the same fdid() as the one sent but a
    different descriptor number; the bulk data must compare equal.
    """
    receiver, sender = Channel.make_pair()
    first_fd, second_fd = unix_pipe()
    bulk = _make_huge_obj()
    sender.send((first_fd, second_fd, bulk), TestSendContext)
    got_first, got_second, got_bulk = receiver.recv()
    for sent, got in ((first_fd, got_first), (second_fd, got_second)):
        assert fdid(sent.fileno()) == fdid(got.fileno())
        assert sent.fileno() != got.fileno()
    assert got_bulk == bulk
| |
def _make_exact_pickle_size(target_sz):
    """Make an array that pickles (HIGHEST_PROTOCOL) to exactly TARGET_SZ bytes.

    The array uses typecode "b" (one byte per element), so its pickle is the
    element count plus a protocol overhead.  The overhead is first calibrated
    on a 10000-element array.  It can still shift by a few bytes as the
    payload grows (e.g. pickle framing treats large bytes payloads
    differently), so the element count is re-adjusted until the size matches.

    Raises AssertionError if TARGET_SZ is smaller than the calibration
    array's pickle, or if no exact size can be reached.
    """
    nr_elements = 10000
    pickled_sz = len(dumps(_make_array(nr_elements), HIGHEST_PROTOCOL))
    assert pickled_sz <= target_sz
    # Each correction only has to absorb the small overhead change caused by
    # the previous one, so a handful of rounds is plenty.
    for _ in range(8):
        nr_elements += target_sz - pickled_sz
        candidate = _make_array(nr_elements)
        pickled_sz = len(dumps(candidate, HIGHEST_PROTOCOL))
        if pickled_sz == target_sz:
            return candidate
    raise AssertionError(
        "could not build an array pickling to exactly %d bytes "
        "(last attempt: %d bytes)" % (target_sz, pickled_sz))
| |
def test_send_too_big_without_memfd():
    """With the default send context, a pickle one byte over MAX_MSG_SIZE is rejected."""
    _channel1, channel2 = Channel.make_pair()
    # Named so it does not shadow the module-level ``array`` type.
    too_big = _make_exact_pickle_size(MAX_MSG_SIZE + 1)
    with pytest.raises(MessageTooLargeError):
        channel2.send(too_big)
| |
def test_send_packet_limit():
    """A message whose pickle is exactly MAX_MSG_SIZE bytes still round-trips."""
    channel1, channel2 = Channel.make_pair()
    # Named so it does not shadow the module-level ``array`` type.
    limit_sized = _make_exact_pickle_size(MAX_MSG_SIZE)
    channel2.send(limit_sized)
    received = channel1.recv()
    assert received == limit_sized