| # -*- coding: utf-8 -*- |
| # Copyright 2015 Google Inc. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Unit tests for daisy chain wrapper class.""" |
| |
| from __future__ import absolute_import |
| |
| import os |
| import pkgutil |
| |
| import gslib.cloud_api |
| from gslib.daisy_chain_wrapper import DaisyChainWrapper |
| from gslib.storage_url import StorageUrlFromString |
| import gslib.tests.testcase as testcase |
| from gslib.util import TRANSFER_BUFFER_SIZE |
| |
| |
| _TEST_FILE = 'test.txt' |
| |
| |
| class TestDaisyChainWrapper(testcase.GsUtilUnitTestCase): |
| """Unit tests for the DaisyChainWrapper class.""" |
| |
| _temp_test_file = None |
| _dummy_url = StorageUrlFromString('gs://bucket/object') |
| |
| def setUp(self): |
| super(TestDaisyChainWrapper, self).setUp() |
| self.test_data_file = self._GetTestFile() |
| self.test_data_file_len = os.path.getsize(self.test_data_file) |
| |
| def _GetTestFile(self): |
| contents = pkgutil.get_data('gslib', 'tests/test_data/%s' % _TEST_FILE) |
| if not self._temp_test_file: |
| # Write to a temp file because pkgutil doesn't expose a stream interface. |
| self._temp_test_file = self.CreateTempFile( |
| file_name=_TEST_FILE, contents=contents) |
| return self._temp_test_file |
| |
| class MockDownloadCloudApi(gslib.cloud_api.CloudApi): |
| """Mock CloudApi that implements GetObjectMedia for testing.""" |
| |
| def __init__(self, write_values): |
| """Initialize the mock that will be used by the download thread. |
| |
| Args: |
| write_values: List of values that will be used for calls to write(), |
| in order, by the download thread. An Exception class may be part of |
| the list; if so, the Exception will be raised after previous |
| values are consumed. |
| """ |
| self._write_values = write_values |
| self.get_calls = 0 |
| |
| def GetObjectMedia(self, unused_bucket_name, unused_object_name, |
| download_stream, start_byte=0, end_byte=None, |
| **kwargs): |
| """Writes self._write_values to the download_stream.""" |
| # Writes from start_byte up to, but not including end_byte (if not None). |
| # Does not slice values; |
| # self._write_values must line up with start/end_byte. |
| self.get_calls += 1 |
| bytes_read = 0 |
| for write_value in self._write_values: |
| if bytes_read < start_byte: |
| bytes_read += len(write_value) |
| continue |
| if end_byte and bytes_read >= end_byte: |
| break |
| if isinstance(write_value, Exception): |
| raise write_value |
| download_stream.write(write_value) |
| bytes_read += len(write_value) |
| |
| def _WriteFromWrapperToFile(self, daisy_chain_wrapper, file_path): |
| """Writes all contents from the DaisyChainWrapper to the named file.""" |
| with open(file_path, 'wb') as upload_stream: |
| while True: |
| data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) |
| if not data: |
| break |
| upload_stream.write(data) |
| |
| def testDownloadSingleChunk(self): |
| """Tests a single call to GetObjectMedia.""" |
| write_values = [] |
| with open(self.test_data_file, 'rb') as stream: |
| while True: |
| data = stream.read(TRANSFER_BUFFER_SIZE) |
| if not data: |
| break |
| write_values.append(data) |
| upload_file = self.CreateTempFile() |
| # Test for a single call even if the chunk size is larger than the data. |
| for chunk_size in (self.test_data_file_len, self.test_data_file_len + 1): |
| mock_api = self.MockDownloadCloudApi(write_values) |
| daisy_chain_wrapper = DaisyChainWrapper( |
| self._dummy_url, self.test_data_file_len, mock_api, |
| download_chunk_size=chunk_size) |
| self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| # Since the chunk size is >= the file size, only a single GetObjectMedia |
| # call should be made. |
| self.assertEquals(mock_api.get_calls, 1) |
| with open(upload_file, 'rb') as upload_stream: |
| with open(self.test_data_file, 'rb') as download_stream: |
| self.assertEqual(upload_stream.read(), download_stream.read()) |
| |
| def testDownloadMultiChunk(self): |
| """Tests multiple calls to GetObjectMedia.""" |
| upload_file = self.CreateTempFile() |
| write_values = [] |
| with open(self.test_data_file, 'rb') as stream: |
| while True: |
| data = stream.read(TRANSFER_BUFFER_SIZE) |
| if not data: |
| break |
| write_values.append(data) |
| mock_api = self.MockDownloadCloudApi(write_values) |
| daisy_chain_wrapper = DaisyChainWrapper( |
| self._dummy_url, self.test_data_file_len, mock_api, |
| download_chunk_size=TRANSFER_BUFFER_SIZE) |
| self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| num_expected_calls = self.test_data_file_len / TRANSFER_BUFFER_SIZE |
| if self.test_data_file_len % TRANSFER_BUFFER_SIZE: |
| num_expected_calls += 1 |
| # Since the chunk size is < the file size, multiple calls to GetObjectMedia |
| # should be made. |
| self.assertEqual(mock_api.get_calls, num_expected_calls) |
| with open(upload_file, 'rb') as upload_stream: |
| with open(self.test_data_file, 'rb') as download_stream: |
| self.assertEqual(upload_stream.read(), download_stream.read()) |
| |
| def testDownloadWithZeroWrites(self): |
| """Tests 0-byte writes to the download stream from GetObjectMedia.""" |
| write_values = [] |
| with open(self.test_data_file, 'rb') as stream: |
| while True: |
| write_values.append(b'') |
| data = stream.read(TRANSFER_BUFFER_SIZE) |
| write_values.append(b'') |
| if not data: |
| break |
| write_values.append(data) |
| upload_file = self.CreateTempFile() |
| mock_api = self.MockDownloadCloudApi(write_values) |
| daisy_chain_wrapper = DaisyChainWrapper( |
| self._dummy_url, self.test_data_file_len, mock_api, |
| download_chunk_size=self.test_data_file_len) |
| self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| self.assertEquals(mock_api.get_calls, 1) |
| with open(upload_file, 'rb') as upload_stream: |
| with open(self.test_data_file, 'rb') as download_stream: |
| self.assertEqual(upload_stream.read(), download_stream.read()) |
| |
| def testDownloadWithPartialWrite(self): |
| """Tests unaligned writes to the download stream from GetObjectMedia.""" |
| with open(self.test_data_file, 'rb') as stream: |
| chunk = stream.read(TRANSFER_BUFFER_SIZE) |
| one_byte = chunk[0] |
| chunk_minus_one_byte = chunk[1:TRANSFER_BUFFER_SIZE] |
| half_chunk = chunk[0:TRANSFER_BUFFER_SIZE/2] |
| |
| write_values_dict = { |
| 'First byte first chunk unaligned': |
| (one_byte, chunk_minus_one_byte, chunk, chunk), |
| 'Last byte first chunk unaligned': |
| (chunk_minus_one_byte, chunk, chunk), |
| 'First byte second chunk unaligned': |
| (chunk, one_byte, chunk_minus_one_byte, chunk), |
| 'Last byte second chunk unaligned': |
| (chunk, chunk_minus_one_byte, one_byte, chunk), |
| 'First byte final chunk unaligned': |
| (chunk, chunk, one_byte, chunk_minus_one_byte), |
| 'Last byte final chunk unaligned': |
| (chunk, chunk, chunk_minus_one_byte, one_byte), |
| 'Half chunks': |
| (half_chunk, half_chunk, half_chunk), |
| 'Many unaligned': |
| (one_byte, half_chunk, one_byte, half_chunk, chunk, |
| chunk_minus_one_byte, chunk, one_byte, half_chunk, one_byte) |
| } |
| upload_file = self.CreateTempFile() |
| for case_name, write_values in write_values_dict.iteritems(): |
| expected_contents = b'' |
| for write_value in write_values: |
| expected_contents += write_value |
| mock_api = self.MockDownloadCloudApi(write_values) |
| daisy_chain_wrapper = DaisyChainWrapper( |
| self._dummy_url, len(expected_contents), mock_api, |
| download_chunk_size=self.test_data_file_len) |
| self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| with open(upload_file, 'rb') as upload_stream: |
| self.assertEqual(upload_stream.read(), expected_contents, |
| 'Uploaded file contents for case %s did not match' |
| % case_name) |
| |
| def testSeekAndReturn(self): |
| """Tests seeking to the end of the wrapper (simulates getting size).""" |
| write_values = [] |
| with open(self.test_data_file, 'rb') as stream: |
| while True: |
| data = stream.read(TRANSFER_BUFFER_SIZE) |
| if not data: |
| break |
| write_values.append(data) |
| upload_file = self.CreateTempFile() |
| mock_api = self.MockDownloadCloudApi(write_values) |
| daisy_chain_wrapper = DaisyChainWrapper( |
| self._dummy_url, self.test_data_file_len, mock_api, |
| download_chunk_size=self.test_data_file_len) |
| with open(upload_file, 'wb') as upload_stream: |
| current_position = 0 |
| daisy_chain_wrapper.seek(0, whence=os.SEEK_END) |
| daisy_chain_wrapper.seek(current_position) |
| while True: |
| data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) |
| current_position += len(data) |
| daisy_chain_wrapper.seek(0, whence=os.SEEK_END) |
| daisy_chain_wrapper.seek(current_position) |
| if not data: |
| break |
| upload_stream.write(data) |
| self.assertEquals(mock_api.get_calls, 1) |
| with open(upload_file, 'rb') as upload_stream: |
| with open(self.test_data_file, 'rb') as download_stream: |
| self.assertEqual(upload_stream.read(), download_stream.read()) |
| |
| def testRestartDownloadThread(self): |
| """Tests seek to non-stored position; this restarts the download thread.""" |
| write_values = [] |
| with open(self.test_data_file, 'rb') as stream: |
| while True: |
| data = stream.read(TRANSFER_BUFFER_SIZE) |
| if not data: |
| break |
| write_values.append(data) |
| upload_file = self.CreateTempFile() |
| mock_api = self.MockDownloadCloudApi(write_values) |
| daisy_chain_wrapper = DaisyChainWrapper( |
| self._dummy_url, self.test_data_file_len, mock_api, |
| download_chunk_size=self.test_data_file_len) |
| daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) |
| daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE) |
| daisy_chain_wrapper.seek(0) |
| self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| self.assertEquals(mock_api.get_calls, 2) |
| with open(upload_file, 'rb') as upload_stream: |
| with open(self.test_data_file, 'rb') as download_stream: |
| self.assertEqual(upload_stream.read(), download_stream.read()) |
| |
| def testDownloadThreadException(self): |
| """Tests that an exception is propagated via the upload thread.""" |
| |
| class DownloadException(Exception): |
| pass |
| |
| write_values = [b'a', b'b', |
| DownloadException('Download thread forces failure')] |
| upload_file = self.CreateTempFile() |
| mock_api = self.MockDownloadCloudApi(write_values) |
| daisy_chain_wrapper = DaisyChainWrapper( |
| self._dummy_url, self.test_data_file_len, mock_api, |
| download_chunk_size=self.test_data_file_len) |
| try: |
| self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file) |
| self.fail('Expected exception') |
| except DownloadException, e: |
| self.assertIn('Download thread forces failure', str(e)) |
| |
| def testInvalidSeek(self): |
| """Tests that seeking fails for unsupported seek arguments.""" |
| daisy_chain_wrapper = DaisyChainWrapper( |
| self._dummy_url, self.test_data_file_len, self.MockDownloadCloudApi([])) |
| try: |
| # SEEK_CUR is invalid. |
| daisy_chain_wrapper.seek(0, whence=os.SEEK_CUR) |
| self.fail('Expected exception') |
| except IOError, e: |
| self.assertIn('does not support seek mode', str(e)) |
| |
| try: |
| # Seeking from the end with an offset is invalid. |
| daisy_chain_wrapper.seek(1, whence=os.SEEK_END) |
| self.fail('Expected exception') |
| except IOError, e: |
| self.assertIn('Invalid seek during daisy chain', str(e)) |