blob: a0139f52cb6bfcd201a6ab9fbe38f519c65dcd82 [file] [log] [blame]
# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test low-level channel stuff"""
# pylint: disable=missing-docstring,unused-argument,redefined-outer-name
from threading import Thread
import time
from pickle import dumps, HIGHEST_PROTOCOL
from itertools import repeat
import logging
from array import array
import pytest
from .channel import (
Channel,
MAX_MSG_SIZE,
MessageSendContext,
PeerDiedException,
)
from .reduction import (
MessageTooLargeError,
payload_is_in_fd_marker,
)
from .shm import ShmSharedMemorySendContextMixin
from .util_test import isolated_resman # pylint: disable=unused-import
from .util import fdid, unix_pipe
# Logger named after this module.
log = logging.getLogger(name=__name__)
def test_make_channel():
    """Smoke test: building a connected channel pair must not raise."""
    _channel1, _channel2 = Channel.make_pair()
def test_recv_eof():
    """recv() raises PeerDiedException once the other end has been closed."""
    receiver, sender = Channel.make_pair()
    sender.close()
    with pytest.raises(PeerDiedException):
        receiver.recv()
def test_recv_nonblock():
    """A non-blocking recv() with nothing pending raises BlockingIOError."""
    # Keep the peer bound for the whole test so the channel stays open.
    receiver, _sender = Channel.make_pair()
    with pytest.raises(BlockingIOError):
        receiver.recv(block=False)
def test_recv_block():
    """A blocking recv() waits until the peer sends, then returns the message.

    The receive runs on a helper thread so the test can check that it is
    still blocked before anything has been sent.
    """
    channel1, channel2 = Channel.make_pair()
    received_value = None

    def _do_receive():
        nonlocal received_value
        received_value = channel1.recv()

    recv_thread = Thread(target=_do_receive, daemon=True)
    recv_thread.start()
    # Give the receiver time to enter recv(). With nothing sent it must
    # still be blocked, not finished and not dead from an exception.
    time.sleep(0.1)
    assert recv_thread.is_alive()
    assert received_value is None
    channel2.send("blarg")
    # Bounded wait: if recv() never wakes up, fail instead of hanging the
    # whole test run on an unbounded join().
    recv_thread.join(timeout=10)
    assert not recv_thread.is_alive(), "recv() did not return after send()"
    assert received_value == "blarg"
def _make_array(nr_bytes):
return array("b", repeat(1, nr_bytes))
def _make_huge_obj():
    """Return a 1 MiB byte array (1 << 20 elements of typecode "b").

    The large-object tests use it as a payload whose pickle exceeds
    MAX_MSG_SIZE.
    """
    one_mebibyte = 1 << 20
    return _make_array(one_mebibyte)
class TestSendContext(ShmSharedMemorySendContextMixin, MessageSendContext):
    """MessageSendContext combined with the shm send-context mixin.

    This module's large-object tests pass it to Channel.send(). Objects too
    big for a plain send (see MessageTooLargeError) then still go through.
    NOTE(review): this presumably works by moving the payload into a
    shared-memory fd, judging by the mixin's name and payload_is_in_fd_marker.
    Confirm against shm.py.
    """

    # The name starts with "Test", so pytest's default discovery would treat
    # this helper as a test class. Opt out explicitly.
    __test__ = False
def test_send_large_object(isolated_resman):
    """An object whose pickle exceeds MAX_MSG_SIZE round-trips intact
    when sent with TestSendContext."""
    receiver, sender = Channel.make_pair()
    payload = _make_huge_obj()
    # Precondition: the payload really is too big for a single message.
    # The pickled bytes are only measured and are freed immediately.
    assert len(dumps(payload, HIGHEST_PROTOCOL)) > MAX_MSG_SIZE
    sender.send(payload, TestSendContext)
    echoed = receiver.recv()
    assert echoed == payload
def test_send_large_object_marker(isolated_resman):
    """Sending the payload_is_in_fd_marker sentinel itself gives back the
    very same object (identity, not just equality)."""
    receiver, sender = Channel.make_pair()
    marker = payload_is_in_fd_marker
    sender.send(marker, TestSendContext)
    echoed = receiver.recv()
    assert echoed is marker
def test_send_large_object_multiple_fds(isolated_resman):
    """Two fds sent alongside a large object both arrive correctly.

    Each received fd must have the same ``fdid`` as the fd that was sent,
    but a different descriptor number. NOTE(review): fdid presumably
    identifies the underlying open file; confirm in util.py.
    """
    receiver, sender = Channel.make_pair()
    first_fd, second_fd = unix_pipe()
    huge_obj = _make_huge_obj()
    sender.send((first_fd, second_fd, huge_obj), TestSendContext)
    got_first_fd, got_second_fd, got_obj = receiver.recv()
    for sent_fd, got_fd in ((first_fd, got_first_fd),
                            (second_fd, got_second_fd)):
        assert fdid(sent_fd.fileno()) == fdid(got_fd.fileno())
        assert sent_fd.fileno() != got_fd.fileno()
    assert got_obj == huge_obj
def _make_exact_pickle_size(target_sz):
    """Make an array that pickles (at HIGHEST_PROTOCOL) to exactly target_sz bytes.

    The array length is first estimated by measuring the pickle overhead
    on a 10000-element probe. That estimate assumes each element adds
    exactly one pickled byte on top of a constant overhead. The overhead is
    not guaranteed constant across sizes: protocol 4+ framing, for example,
    can add a frame header once a bytes payload is large enough to be
    written outside the current frame. So the candidate is re-measured and
    its length corrected by the observed error a few times.

    Raises AssertionError if target_sz is smaller than the probe's pickle,
    or if no length was found that pickles to exactly target_sz.
    """
    small_array = _make_array(10000)
    pickled_small_array = dumps(small_array, HIGHEST_PROTOCOL)
    assert len(pickled_small_array) <= target_sz
    extra = target_sz - len(pickled_small_array)
    nr_items = extra + len(small_array)
    # The first iteration reproduces the original single-shot estimate.
    # Later ones only run if the overhead differed at this size.
    for _ in range(8):
        candidate = _make_array(nr_items)
        pickled_sz = len(dumps(candidate, HIGHEST_PROTOCOL))
        if pickled_sz == target_sz:
            return candidate
        nr_items += target_sz - pickled_sz
    raise AssertionError(
        "could not build an array that pickles to exactly %d bytes" % target_sz)
def test_send_too_big_without_memfd():
    """With no send context, send() raises MessageTooLargeError for an
    object whose pickle is one byte over MAX_MSG_SIZE."""
    _receiver, sender = Channel.make_pair()
    # Named so it does not shadow the imported ``array`` class.
    oversized = _make_exact_pickle_size(MAX_MSG_SIZE + 1)
    with pytest.raises(MessageTooLargeError):
        sender.send(oversized)
def test_send_packet_limit():
    """An object whose pickle is exactly MAX_MSG_SIZE bytes is sent and
    received intact without any special send context."""
    receiver, sender = Channel.make_pair()
    # Named so it does not shadow the imported ``array`` class.
    boundary_obj = _make_exact_pickle_size(MAX_MSG_SIZE)
    sender.send(boundary_obj)
    echoed = receiver.recv()
    assert boundary_obj == echoed