blob: 1264196c96c54c16892eded89bd9ce43bb5f7c4d [file] [log] [blame]
# -*- coding: utf-8 -*-
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for daisy chain wrapper class."""
from __future__ import absolute_import
import os
import pkgutil
import gslib.cloud_api
from gslib.daisy_chain_wrapper import DaisyChainWrapper
from gslib.storage_url import StorageUrlFromString
import gslib.tests.testcase as testcase
from gslib.util import TRANSFER_BUFFER_SIZE
# Name of the data file, under gslib's tests/test_data directory, whose
# contents the tests below use as the simulated download payload.
_TEST_FILE = 'test.txt'
class TestDaisyChainWrapper(testcase.GsUtilUnitTestCase):
  """Unit tests for the DaisyChainWrapper class.

  Each test feeds a DaisyChainWrapper from a MockDownloadCloudApi, drains the
  wrapper into a temporary "upload" file and compares the result with the
  bytes that were expected to flow through.
  """

  _temp_test_file = None
  _dummy_url = StorageUrlFromString('gs://bucket/object')

  def setUp(self):
    """Materializes the test data file and records its size."""
    super(TestDaisyChainWrapper, self).setUp()
    self.test_data_file = self._GetTestFile()
    self.test_data_file_len = os.path.getsize(self.test_data_file)

  def _GetTestFile(self):
    """Returns the path of a temp file holding the packaged test data."""
    if not self._temp_test_file:
      # Write to a temp file because pkgutil doesn't expose a stream interface.
      # The package data is only read when the temp file is actually needed.
      contents = pkgutil.get_data('gslib', 'tests/test_data/%s' % _TEST_FILE)
      self._temp_test_file = self.CreateTempFile(
          file_name=_TEST_FILE, contents=contents)
    return self._temp_test_file

  def _ReadTestFileChunks(self):
    """Returns the test file contents as a list of buffer-sized chunks.

    Every chunk is TRANSFER_BUFFER_SIZE bytes long except possibly the last.
    """
    chunks = []
    with open(self.test_data_file, 'rb') as stream:
      while True:
        data = stream.read(TRANSFER_BUFFER_SIZE)
        if not data:
          break
        chunks.append(data)
    return chunks

  def _AssertMatchesTestFile(self, upload_file):
    """Asserts that upload_file has exactly the bytes of the test data file."""
    with open(upload_file, 'rb') as upload_stream:
      with open(self.test_data_file, 'rb') as download_stream:
        self.assertEqual(upload_stream.read(), download_stream.read())

  class MockDownloadCloudApi(gslib.cloud_api.CloudApi):
    """Mock CloudApi that implements GetObjectMedia for testing."""

    def __init__(self, write_values):
      """Initialize the mock that will be used by the download thread.

      Args:
        write_values: List of values that will be used for calls to write(),
            in order, by the download thread. An Exception instance may be
            part of the list; if so, the Exception will be raised after
            previous values are consumed.
      """
      # NOTE: the base-class initializer is not invoked here; the mock only
      # sets the two attributes that GetObjectMedia below relies on.
      self._write_values = write_values
      self.get_calls = 0

    def GetObjectMedia(self, unused_bucket_name, unused_object_name,
                       download_stream, start_byte=0, end_byte=None,
                       **kwargs):
      """Writes self._write_values to the download_stream.

      Args:
        unused_bucket_name: Ignored.
        unused_object_name: Ignored.
        download_stream: Stream whose write() receives each value.
        start_byte: Values are skipped until this many bytes were passed over.
        end_byte: If truthy, writing stops once this many bytes were covered.
        **kwargs: Ignored.

      Raises:
        Whatever Exception instance is encountered in the write values.
      """
      # Writes from start_byte up to, but not including end_byte (if not None).
      # Does not slice values;
      # self._write_values must line up with start/end_byte.
      self.get_calls += 1
      bytes_read = 0
      for write_value in self._write_values:
        is_error = isinstance(write_value, Exception)
        # An Exception has no length, so it must never reach the skip branch;
        # instead it falls through and is raised like any injected failure.
        if not is_error and bytes_read < start_byte:
          bytes_read += len(write_value)
          continue
        if end_byte and bytes_read >= end_byte:
          break
        if is_error:
          raise write_value
        download_stream.write(write_value)
        bytes_read += len(write_value)

  def _WriteFromWrapperToFile(self, daisy_chain_wrapper, file_path):
    """Writes all contents from the DaisyChainWrapper to the named file."""
    with open(file_path, 'wb') as upload_stream:
      while True:
        data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
        if not data:
          break
        upload_stream.write(data)

  def testDownloadSingleChunk(self):
    """Tests a single call to GetObjectMedia."""
    write_values = self._ReadTestFileChunks()
    upload_file = self.CreateTempFile()
    # Test for a single call even if the chunk size is larger than the data.
    for chunk_size in (self.test_data_file_len, self.test_data_file_len + 1):
      mock_api = self.MockDownloadCloudApi(write_values)
      daisy_chain_wrapper = DaisyChainWrapper(
          self._dummy_url, self.test_data_file_len, mock_api,
          download_chunk_size=chunk_size)
      self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
      # Since the chunk size is >= the file size, only a single GetObjectMedia
      # call should be made.
      self.assertEqual(mock_api.get_calls, 1)
      self._AssertMatchesTestFile(upload_file)

  def testDownloadMultiChunk(self):
    """Tests multiple calls to GetObjectMedia."""
    upload_file = self.CreateTempFile()
    write_values = self._ReadTestFileChunks()
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url, self.test_data_file_len, mock_api,
        download_chunk_size=TRANSFER_BUFFER_SIZE)
    self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
    # Floor division keeps the expected count an integer on Python 3 as well.
    num_expected_calls = self.test_data_file_len // TRANSFER_BUFFER_SIZE
    if self.test_data_file_len % TRANSFER_BUFFER_SIZE:
      num_expected_calls += 1
    # Since the chunk size is < the file size, multiple calls to GetObjectMedia
    # should be made.
    self.assertEqual(mock_api.get_calls, num_expected_calls)
    self._AssertMatchesTestFile(upload_file)

  def testDownloadWithZeroWrites(self):
    """Tests 0-byte writes to the download stream from GetObjectMedia."""
    write_values = []
    with open(self.test_data_file, 'rb') as stream:
      while True:
        # Surround every real chunk (and the final EOF read) with empty
        # writes.
        write_values.append(b'')
        data = stream.read(TRANSFER_BUFFER_SIZE)
        write_values.append(b'')
        if not data:
          break
        write_values.append(data)
    upload_file = self.CreateTempFile()
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url, self.test_data_file_len, mock_api,
        download_chunk_size=self.test_data_file_len)
    self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
    self.assertEqual(mock_api.get_calls, 1)
    self._AssertMatchesTestFile(upload_file)

  def testDownloadWithPartialWrite(self):
    """Tests unaligned writes to the download stream from GetObjectMedia."""
    with open(self.test_data_file, 'rb') as stream:
      chunk = stream.read(TRANSFER_BUFFER_SIZE)
    # Slice rather than index so this stays a byte string on Python 3, where
    # indexing a bytes object yields an int.
    one_byte = chunk[0:1]
    chunk_minus_one_byte = chunk[1:TRANSFER_BUFFER_SIZE]
    half_chunk = chunk[0:TRANSFER_BUFFER_SIZE // 2]

    write_values_dict = {
        'First byte first chunk unaligned':
            (one_byte, chunk_minus_one_byte, chunk, chunk),
        'Last byte first chunk unaligned':
            (chunk_minus_one_byte, chunk, chunk),
        'First byte second chunk unaligned':
            (chunk, one_byte, chunk_minus_one_byte, chunk),
        'Last byte second chunk unaligned':
            (chunk, chunk_minus_one_byte, one_byte, chunk),
        'First byte final chunk unaligned':
            (chunk, chunk, one_byte, chunk_minus_one_byte),
        'Last byte final chunk unaligned':
            (chunk, chunk, chunk_minus_one_byte, one_byte),
        'Half chunks':
            (half_chunk, half_chunk, half_chunk),
        'Many unaligned':
            (one_byte, half_chunk, one_byte, half_chunk, chunk,
             chunk_minus_one_byte, chunk, one_byte, half_chunk, one_byte)
    }
    upload_file = self.CreateTempFile()
    for case_name, write_values in write_values_dict.items():
      expected_contents = b''.join(write_values)
      mock_api = self.MockDownloadCloudApi(write_values)
      daisy_chain_wrapper = DaisyChainWrapper(
          self._dummy_url, len(expected_contents), mock_api,
          download_chunk_size=self.test_data_file_len)
      self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
      with open(upload_file, 'rb') as upload_stream:
        self.assertEqual(upload_stream.read(), expected_contents,
                         'Uploaded file contents for case %s did not match'
                         % case_name)

  def testSeekAndReturn(self):
    """Tests seeking to the end of the wrapper (simulates getting size)."""
    write_values = self._ReadTestFileChunks()
    upload_file = self.CreateTempFile()
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url, self.test_data_file_len, mock_api,
        download_chunk_size=self.test_data_file_len)
    with open(upload_file, 'wb') as upload_stream:
      current_position = 0
      # Seek to the end and straight back before the first read and after
      # every read.
      daisy_chain_wrapper.seek(0, whence=os.SEEK_END)
      daisy_chain_wrapper.seek(current_position)
      while True:
        data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
        current_position += len(data)
        daisy_chain_wrapper.seek(0, whence=os.SEEK_END)
        daisy_chain_wrapper.seek(current_position)
        if not data:
          break
        upload_stream.write(data)
    self.assertEqual(mock_api.get_calls, 1)
    self._AssertMatchesTestFile(upload_file)

  def testRestartDownloadThread(self):
    """Tests seek to non-stored position; this restarts the download thread."""
    write_values = self._ReadTestFileChunks()
    upload_file = self.CreateTempFile()
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url, self.test_data_file_len, mock_api,
        download_chunk_size=self.test_data_file_len)
    daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
    daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
    daisy_chain_wrapper.seek(0)
    self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
    # One call for the initial reads plus one for the restart after seek(0).
    self.assertEqual(mock_api.get_calls, 2)
    self._AssertMatchesTestFile(upload_file)

  def testDownloadThreadException(self):
    """Tests that an exception is propagated via the upload thread."""

    class DownloadException(Exception):
      pass

    write_values = [b'a', b'b',
                    DownloadException('Download thread forces failure')]
    upload_file = self.CreateTempFile()
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url, self.test_data_file_len, mock_api,
        download_chunk_size=self.test_data_file_len)
    try:
      self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
      self.fail('Expected exception')
    except DownloadException as e:
      self.assertIn('Download thread forces failure', str(e))

  def testInvalidSeek(self):
    """Tests that seeking fails for unsupported seek arguments."""
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url, self.test_data_file_len,
        self.MockDownloadCloudApi([]))
    try:
      # SEEK_CUR is invalid.
      daisy_chain_wrapper.seek(0, whence=os.SEEK_CUR)
      self.fail('Expected exception')
    except IOError as e:
      self.assertIn('does not support seek mode', str(e))
    try:
      # Seeking from the end with an offset is invalid.
      daisy_chain_wrapper.seek(1, whence=os.SEEK_END)
      self.fail('Expected exception')
    except IOError as e:
      self.assertIn('Invalid seek during daisy chain', str(e))