blob: 8b5ff28d13b6fd77ff66db359c35afddb91dffc6 [file] [log] [blame]
# Copyright 2010 Google Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish, dis-
# tribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the fol-
# lowing conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
"""
Provides basic mocks of core storage service classes, for unit testing:
ACL, Key, Bucket, Connection, and StorageUri. We implement a subset of
the interfaces defined in the real boto classes, but don't handle most
of the optional params (which we indicate with the constant "NOT_IMPL").
"""
import copy
import boto
import base64
import re
from hashlib import md5
from boto.utils import compute_md5
from boto.utils import find_matching_headers
from boto.utils import merge_headers_by_name
from boto.s3.prefix import Prefix
from boto.compat import six
# Default for optional parameters that the mocks below accept only so their
# signatures match the real boto classes; such arguments are ignored.
NOT_IMPL = None
class MockAcl(object):
    """Placeholder for a boto ACL object.

    The SAX parsing hooks do nothing and serialization always produces the
    same fixed XML string, regardless of what the ACL was built from.
    """

    def __init__(self, parent=NOT_IMPL):
        # Whatever is passed in (parent or ACL payload) is discarded.
        pass

    def startElement(self, name, attrs, connection):
        """SAX start-element callback; intentionally a no-op."""
        return None

    def endElement(self, name, value, connection):
        """SAX end-element callback; intentionally a no-op."""
        return None

    def to_xml(self):
        """Return the fixed placeholder XML used for every mock ACL."""
        placeholder_xml = '<mock_ACL_XML/>'
        return placeholder_xml
class MockKey(object):
    """In-memory stand-in for boto's Key class.

    The object's contents are held in ``self.data``.  The
    ``set_contents_*`` methods keep ``etag`` and ``size`` in step with it.
    Parameters defaulting to NOT_IMPL exist only for signature
    compatibility and are ignored.
    """

    def __init__(self, bucket=None, name=None):
        self.bucket = bucket
        self.name = name
        self.data = None
        self.etag = None
        self.size = None
        # True until open_read() starts a read; read() reopens when closed.
        self.closed = True
        self.content_encoding = None
        self.content_language = None
        self.content_type = None
        # Constant timestamp: every mock key reports this same value.
        self.last_modified = 'Wed, 06 Oct 2010 05:11:54 GMT'
        # Chunk size used when draining a stream in set_contents_from_stream.
        self.BufferSize = 8192

    def __repr__(self):
        if self.bucket:
            return '<MockKey: %s,%s>' % (self.bucket.name, self.name)
        else:
            return '<MockKey: %s>' % self.name

    def get_contents_as_string(self, headers=NOT_IMPL,
                               cb=NOT_IMPL, num_cb=NOT_IMPL,
                               torrent=NOT_IMPL,
                               version_id=NOT_IMPL):
        """Return the stored contents (``None`` if never set)."""
        return self.data

    def get_contents_to_file(self, fp, headers=NOT_IMPL,
                             cb=NOT_IMPL, num_cb=NOT_IMPL,
                             torrent=NOT_IMPL,
                             version_id=NOT_IMPL,
                             res_download_handler=NOT_IMPL):
        """Write the stored contents to the file-like object ``fp``."""
        fp.write(self.data)

    def get_file(self, fp, headers=NOT_IMPL, cb=NOT_IMPL, num_cb=NOT_IMPL,
                 torrent=NOT_IMPL, version_id=NOT_IMPL,
                 override_num_retries=NOT_IMPL):
        """Write the stored contents to the file-like object ``fp``."""
        fp.write(self.data)

    def _handle_headers(self, headers):
        """Record Content-Encoding/-Type/-Language values from ``headers``.

        Matching header names are located with
        boto.utils.find_matching_headers, and their values are combined
        with boto.utils.merge_headers_by_name.  Attributes whose header is
        absent are left unchanged.
        """
        if not headers:
            return
        if find_matching_headers('Content-Encoding', headers):
            self.content_encoding = merge_headers_by_name('Content-Encoding',
                                                          headers)
        if find_matching_headers('Content-Type', headers):
            self.content_type = merge_headers_by_name('Content-Type', headers)
        if find_matching_headers('Content-Language', headers):
            self.content_language = merge_headers_by_name('Content-Language',
                                                          headers)

    # Simplistic partial implementation for headers: Just supports range GETs
    # of flavor 'Range: bytes=xyz-'.
    def open_read(self, headers=None, query_args=NOT_IMPL,
                  override_num_retries=NOT_IMPL):
        """Start a read at offset 0 (or the Range offset) if not already open.

        Headers are only consulted when the key is closed, so a Range header
        passed to an already-open key is ignored.
        """
        if self.closed:
            self.read_pos = 0
            self.closed = False
            if headers and 'Range' in headers:
                match = re.match(r'bytes=([0-9]+)-$', headers['Range'])
                if match:
                    self.read_pos = int(match.group(1))

    def close(self, fast=NOT_IMPL):
        """Mark the key closed; the next read restarts from offset 0."""
        self.closed = True

    def read(self, size=0):
        """Return up to ``size`` items from the current read position.

        ``size == 0`` returns everything remaining.  An empty result closes
        the key.
        """
        self.open_read()
        if size == 0:
            data = self.data[self.read_pos:]
            self.read_pos = self.size
        else:
            data = self.data[self.read_pos:self.read_pos + size]
            self.read_pos += size
        if not data:
            self.close()
        return data

    def set_contents_from_file(self, fp, headers=None, replace=NOT_IMPL,
                               cb=NOT_IMPL, num_cb=NOT_IMPL,
                               policy=NOT_IMPL, md5=NOT_IMPL,
                               res_upload_handler=NOT_IMPL):
        """Replace the contents with everything readable from ``fp``.

        The etag, size and content headers are refreshed afterwards.
        """
        self.data = fp.read()
        self.set_etag()
        self.size = len(self.data)
        self._handle_headers(headers)

    def set_contents_from_stream(self, fp, headers=None, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL,
                                 policy=NOT_IMPL,
                                 reduced_redundancy=NOT_IMPL,
                                 query_args=NOT_IMPL, size=NOT_IMPL):
        """Replace the contents by draining ``fp`` in BufferSize chunks.

        Works for both text and binary streams.  An empty stream stores ''.
        """
        chunks = []
        chunk = fp.read(self.BufferSize)
        while chunk:
            chunks.append(chunk)
            chunk = fp.read(self.BufferSize)
        # Join with an empty value of the chunks' own type (str or bytes),
        # so binary streams no longer hit a str + bytes TypeError.
        self.data = chunks[0][:0].join(chunks) if chunks else ''
        self.set_etag()
        self.size = len(self.data)
        self._handle_headers(headers)

    def set_contents_from_string(self, s, headers=NOT_IMPL, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL,
                                 policy=NOT_IMPL, md5=NOT_IMPL,
                                 reduced_redundancy=NOT_IMPL):
        """Replace the contents with a copy of ``s``."""
        self.data = copy.copy(s)
        self.set_etag()
        self.size = len(s)
        self._handle_headers(headers)

    def set_contents_from_filename(self, filename, headers=None,
                                   replace=NOT_IMPL, cb=NOT_IMPL,
                                   num_cb=NOT_IMPL, policy=NOT_IMPL,
                                   md5=NOT_IMPL,
                                   res_upload_handler=NOT_IMPL):
        """Replace the contents with the bytes of the local file ``filename``."""
        # 'with' closes the file even if reading it raises.
        with open(filename, 'rb') as fp:
            self.set_contents_from_file(fp, headers, replace, cb, num_cb,
                                        policy, md5, res_upload_handler)

    def copy(self, dst_bucket_name, dst_key, metadata=NOT_IMPL,
             reduced_redundancy=NOT_IMPL, preserve_acl=NOT_IMPL):
        """Copy this key to ``dst_key`` in bucket ``dst_bucket_name``."""
        dst_bucket = self.bucket.connection.get_bucket(dst_bucket_name)
        return dst_bucket.copy_key(dst_key, self.bucket.name,
                                   self.name, metadata)

    @property
    def provider(self):
        """The owning connection's provider, or None if there is none."""
        provider = None
        if self.bucket and self.bucket.connection:
            provider = self.bucket.connection.provider
        return provider

    def set_etag(self):
        """
        Set etag attribute by generating hex MD5 checksum on current
        contents of mock key.
        """
        m = md5()
        if not isinstance(self.data, bytes):
            # Text contents are hashed as their UTF-8 encoding.
            m.update(self.data.encode('utf-8'))
        else:
            m.update(self.data)
        hex_md5 = m.hexdigest()
        self.etag = hex_md5

    def compute_md5(self, fp):
        """
        :type fp: file
        :param fp: File pointer to the file to MD5 hash. The file pointer
                   will be reset to the beginning of the file before the
                   method returns.
        :rtype: tuple
        :return: A tuple containing the hex digest version of the MD5 hash
                 as the first element and the base64 encoded version of the
                 plain digest as the second element.
        """
        tup = compute_md5(fp)
        # Returned values are MD5 hash, base64 encoded MD5 hash, and file size.
        # The internal implementation of compute_md5() needs to return the
        # file size but we don't want to return that value to the external
        # caller because it changes the class interface (i.e. it might
        # break some code) so we consume the third tuple value here and
        # return the remainder of the tuple to the caller, thereby preserving
        # the existing interface.
        self.size = tup[2]
        return tup[0:2]
class MockBucket(object):
    """In-memory stand-in for boto's Bucket class.

    Keys live in ``self.keys`` by name.  ACLs live in ``self.acls`` by key
    name, and the entry under the bucket's own name is the bucket ACL.
    """

    def __init__(self, connection=None, name=None, key_class=NOT_IMPL):
        self.name = name
        self.keys = {}
        # ACLs by key name; the entry keyed by the bucket name is the
        # bucket's own ACL.
        self.acls = {name: MockAcl()}
        # default object ACLs are one per bucket and not supported for keys
        self.def_acl = MockAcl()
        self.subresources = {}
        self.connection = connection
        self.logging = False

    def __repr__(self):
        return 'MockBucket: %s' % self.name

    def copy_key(self, new_key_name, src_bucket_name,
                 src_key_name, metadata=NOT_IMPL, src_version_id=NOT_IMPL,
                 storage_class=NOT_IMPL, preserve_acl=NOT_IMPL,
                 encrypt_key=NOT_IMPL, headers=NOT_IMPL, query_args=NOT_IMPL):
        """Copy ``src_bucket_name``/``src_key_name`` to ``new_key_name`` here.

        Data, ETag and the content-encoding/-type/-language attributes are
        copied from the source key.

        :raises boto.exception.StorageResponseError: if the source bucket or
            source key does not exist.
        """
        # Resolve the source before creating the destination.  Otherwise a
        # missing source leaves an empty key behind, and a copy onto itself
        # would read from the freshly created, empty replacement.
        src_key = self.connection.get_bucket(
            src_bucket_name).get_key(src_key_name)
        if src_key is None:
            raise boto.exception.StorageResponseError(404, 'NoSuchKey',
                                                      'Not Found')
        new_key = self.new_key(key_name=new_key_name)
        new_key.data = copy.copy(src_key.data)
        new_key.size = len(new_key.data)
        new_key.set_etag()
        new_key.content_encoding = src_key.content_encoding
        new_key.content_type = src_key.content_type
        new_key.content_language = src_key.content_language
        return new_key

    def disable_logging(self):
        self.logging = False

    def enable_logging(self, target_bucket_prefix):
        self.logging = True

    def get_logging_config(self):
        return {"Logging": {}}

    def get_versioning_status(self, headers=NOT_IMPL):
        """Versioning is not modelled; always reports False."""
        return False

    def get_acl(self, key_name='', headers=NOT_IMPL, version_id=NOT_IMPL):
        """Return the ACL of ``key_name``, or the bucket ACL if it is empty.

        Raises KeyError for a name with no recorded ACL.
        """
        if key_name:
            # Return ACL for the key.
            return self.acls[key_name]
        else:
            # Return ACL for the bucket.
            return self.acls[self.name]

    def get_def_acl(self, key_name=NOT_IMPL, headers=NOT_IMPL,
                    version_id=NOT_IMPL):
        # Return default ACL for the bucket.
        return self.def_acl

    def get_subresource(self, subresource, key_name=NOT_IMPL,
                        headers=NOT_IMPL, version_id=NOT_IMPL):
        """Return the stored subresource value, or '<Subresource/>' if unset."""
        if subresource in self.subresources:
            return self.subresources[subresource]
        else:
            return '<Subresource/>'

    def new_key(self, key_name=None):
        """Create, register and return an empty MockKey (replacing any old)."""
        mock_key = MockKey(self, key_name)
        self.keys[key_name] = mock_key
        self.acls[key_name] = MockAcl()
        return mock_key

    def delete_key(self, key_name, headers=NOT_IMPL,
                   version_id=NOT_IMPL, mfa_token=NOT_IMPL):
        """Remove ``key_name``; raise StorageResponseError(404) if absent."""
        if key_name not in self.keys:
            raise boto.exception.StorageResponseError(404, 'Not Found')
        del self.keys[key_name]

    def get_all_keys(self, headers=NOT_IMPL):
        """Return a list snapshot of all keys in the bucket.

        A list (rather than a live dict iterator) lets callers delete keys
        while iterating, matching the behavior of list() below.
        """
        # Inside a method 'list' resolves to the builtin, not MockBucket.list.
        return list(self.keys.values())

    def list(self, prefix='', delimiter='', marker=NOT_IMPL,
             headers=NOT_IMPL):
        """Return keys under ``prefix``, rolling nested names into Prefixes.

        When ``delimiter`` is given, each key whose name past ``prefix``
        contains it is reported once as a Prefix ending with the full
        delimiter.  Other keys are returned as fresh MockKey objects that
        carry only the bucket and name.
        """
        prefix = prefix or ''  # Turn None into '' for prefix match.
        # Return list instead of using a generator so we don't get
        # 'dictionary changed size during iteration' error when performing
        # deletions while iterating (e.g., during test cleanup).
        result = []
        key_name_set = set()
        for k in six.itervalues(self.keys):
            if k.name.startswith(prefix):
                k_name_past_prefix = k.name[len(prefix):]
                if delimiter:
                    pos = k_name_past_prefix.find(delimiter)
                else:
                    pos = -1
                if pos != -1:
                    # Include the whole delimiter, not just its first char,
                    # so multi-character delimiters yield correct prefixes.
                    key_or_prefix = Prefix(
                        bucket=self,
                        name=k.name[:len(prefix) + pos + len(delimiter)])
                else:
                    key_or_prefix = MockKey(bucket=self, name=k.name)
                if key_or_prefix.name not in key_name_set:
                    key_name_set.add(key_or_prefix.name)
                    result.append(key_or_prefix)
        return result

    def set_acl(self, acl_or_str, key_name='', headers=NOT_IMPL,
                version_id=NOT_IMPL):
        # We only handle setting ACL XML here; if you pass a canned ACL
        # the get_acl call will just return that string name.
        if key_name:
            # Set ACL for the key.
            self.acls[key_name] = MockAcl(acl_or_str)
        else:
            # Set ACL for the bucket.
            self.acls[self.name] = MockAcl(acl_or_str)

    def set_def_acl(self, acl_or_str, key_name=NOT_IMPL, headers=NOT_IMPL,
                    version_id=NOT_IMPL):
        # We only handle setting ACL XML here; if you pass a canned ACL
        # the get_acl call will just return that string name.
        # Set default ACL for the bucket.  Note that, unlike set_acl, the
        # value is stored as given rather than wrapped in MockAcl.
        self.def_acl = acl_or_str

    def set_subresource(self, subresource, value, key_name=NOT_IMPL,
                        headers=NOT_IMPL, version_id=NOT_IMPL):
        self.subresources[subresource] = value
class MockProvider(object):
    """Tiny provider object that only knows its own name."""

    def __init__(self, provider):
        # Provider name as given, e.g. 's3'.
        self.provider = provider

    def get_provider_name(self):
        """Return the provider name supplied at construction."""
        name = self.provider
        return name
class MockConnection(object):
    """In-memory stand-in for a boto storage connection.

    Buckets are held in ``self.buckets`` keyed by name.  Apart from
    ``provider``, the constructor arguments exist only for signature
    compatibility.
    """

    def __init__(self, aws_access_key_id=NOT_IMPL,
                 aws_secret_access_key=NOT_IMPL, is_secure=NOT_IMPL,
                 port=NOT_IMPL, proxy=NOT_IMPL, proxy_port=NOT_IMPL,
                 proxy_user=NOT_IMPL, proxy_pass=NOT_IMPL,
                 host=NOT_IMPL, debug=NOT_IMPL,
                 https_connection_factory=NOT_IMPL,
                 calling_format=NOT_IMPL,
                 path=NOT_IMPL, provider='s3',
                 bucket_class=NOT_IMPL):
        self.buckets = {}
        self.provider = MockProvider(provider)

    def create_bucket(self, bucket_name, headers=NOT_IMPL, location=NOT_IMPL,
                      policy=NOT_IMPL, storage_class=NOT_IMPL):
        """Create, register and return an empty MockBucket.

        :raises boto.exception.StorageCreateError: (409) if a bucket named
            ``bucket_name`` already exists.
        """
        if bucket_name in self.buckets:
            raise boto.exception.StorageCreateError(
                409, 'BucketAlreadyOwnedByYou',
                "<Message>Your previous request to create the named bucket "
                "succeeded and you already own it.</Message>")
        mock_bucket = MockBucket(name=bucket_name, connection=self)
        self.buckets[bucket_name] = mock_bucket
        return mock_bucket

    def delete_bucket(self, bucket, headers=NOT_IMPL):
        """Delete the bucket named ``bucket``.

        The bucket's contents are not checked, so non-empty buckets are
        deleted as well.

        :raises boto.exception.StorageResponseError: (404) if no such bucket.
        """
        if bucket not in self.buckets:
            raise boto.exception.StorageResponseError(
                404, 'NoSuchBucket', '<Message>no such bucket</Message>')
        del self.buckets[bucket]

    def get_bucket(self, bucket_name, validate=NOT_IMPL, headers=NOT_IMPL):
        """Return the named bucket.

        :raises boto.exception.StorageResponseError: (404) if it does not
            exist.
        """
        if bucket_name not in self.buckets:
            raise boto.exception.StorageResponseError(404, 'NoSuchBucket',
                                                      'Not Found')
        return self.buckets[bucket_name]

    def get_all_buckets(self, headers=NOT_IMPL):
        """Return a list snapshot of all buckets.

        A list (not a live dict iterator) lets callers call delete_bucket
        while iterating, e.g. during test cleanup.
        """
        return list(self.buckets.values())
# We only mock a single provider/connection.
# MockBucketStorageUri.connect() returns this shared instance, so buckets
# and keys created through any mock URI are visible through every other one.
mock_connection = MockConnection()
class MockBucketStorageUri(object):
    """In-memory stand-in for boto's BucketStorageUri.

    Every operation is routed through the module-level ``mock_connection``
    returned by connect(), so all URIs share the same buckets and keys.
    """

    delim = '/'

    def __init__(self, scheme, bucket_name=None, object_name=None,
                 debug=NOT_IMPL, suppress_consec_slashes=NOT_IMPL,
                 version_id=None, generation=None, is_latest=False):
        self.scheme = scheme
        self.bucket_name = bucket_name
        self.object_name = object_name
        self.suppress_consec_slashes = suppress_consec_slashes
        if self.bucket_name and self.object_name:
            self.uri = ('%s://%s/%s' % (self.scheme, self.bucket_name,
                                        self.object_name))
        elif self.bucket_name:
            self.uri = ('%s://%s/' % (self.scheme, self.bucket_name))
        else:
            self.uri = ('%s://' % self.scheme)
        # The URI string never embeds a version, so the versionless form is
        # identical.  Set it for every URI shape (previously only object
        # URIs had the attribute at all).
        self.versionless_uri = self.uri
        self.version_id = version_id
        self.generation = generation and int(generation)
        self.is_version_specific = (bool(self.generation)
                                    or bool(self.version_id))
        self.is_latest = is_latest

    def __repr__(self):
        """Returns string representation of URI."""
        return self.uri

    def acl_class(self):
        return MockAcl

    def canned_acls(self):
        """Return the canned ACL names of boto's 'aws' provider."""
        return boto.provider.Provider('aws').canned_acls

    def clone_replace_name(self, new_name):
        """Return a URI naming ``new_name`` in the same bucket and scheme.

        suppress_consec_slashes is carried over, as in clone_replace_key.
        """
        return self.__class__(
            self.scheme, self.bucket_name, new_name,
            suppress_consec_slashes=self.suppress_consec_slashes)

    def clone_replace_key(self, key):
        """Return a URI naming ``key``, with the scheme from its provider."""
        return self.__class__(
            key.provider.get_provider_name(),
            bucket_name=key.bucket.name,
            object_name=key.name,
            suppress_consec_slashes=self.suppress_consec_slashes,
            version_id=getattr(key, 'version_id', None),
            generation=getattr(key, 'generation', None),
            is_latest=getattr(key, 'is_latest', None))

    def connect(self, access_key_id=NOT_IMPL, secret_access_key=NOT_IMPL):
        """Return the single shared mock connection."""
        return mock_connection

    def create_bucket(self, headers=NOT_IMPL, location=NOT_IMPL,
                      policy=NOT_IMPL, storage_class=NOT_IMPL):
        return self.connect().create_bucket(self.bucket_name)

    def delete_bucket(self, headers=NOT_IMPL):
        return self.connect().delete_bucket(self.bucket_name)

    def get_versioning_config(self, headers=NOT_IMPL):
        """Return the bucket's get_versioning_status() result."""
        # Previously the result was computed and then dropped.
        return self.get_bucket().get_versioning_status(headers)

    def has_version(self):
        return (issubclass(type(self), MockBucketStorageUri)
                and ((self.version_id is not None)
                     or (self.generation is not None)))

    def delete_key(self, validate=NOT_IMPL, headers=NOT_IMPL,
                   version_id=NOT_IMPL, mfa_token=NOT_IMPL):
        self.get_bucket().delete_key(self.object_name)

    def disable_logging(self, validate=NOT_IMPL, headers=NOT_IMPL,
                        version_id=NOT_IMPL):
        self.get_bucket().disable_logging()

    def enable_logging(self, target_bucket, target_prefix, validate=NOT_IMPL,
                       headers=NOT_IMPL, version_id=NOT_IMPL):
        self.get_bucket().enable_logging(target_bucket)

    def get_logging_config(self, validate=NOT_IMPL, headers=NOT_IMPL,
                           version_id=NOT_IMPL):
        return self.get_bucket().get_logging_config()

    def equals(self, uri):
        """Return True if ``uri`` has the same URI string."""
        return self.uri == uri.uri

    def get_acl(self, validate=NOT_IMPL, headers=NOT_IMPL,
                version_id=NOT_IMPL):
        """Return the object's ACL, or the bucket's for a bucket URI."""
        return self.get_bucket().get_acl(self.object_name)

    def get_def_acl(self, validate=NOT_IMPL, headers=NOT_IMPL,
                    version_id=NOT_IMPL):
        return self.get_bucket().get_def_acl(self.object_name)

    def get_subresource(self, subresource, validate=NOT_IMPL,
                        headers=NOT_IMPL, version_id=NOT_IMPL):
        return self.get_bucket().get_subresource(subresource,
                                                 self.object_name)

    def get_all_buckets(self, headers=NOT_IMPL):
        return self.connect().get_all_buckets()

    def get_all_keys(self, validate=NOT_IMPL, headers=NOT_IMPL):
        # Previously 'self' was passed here as the bucket's headers argument.
        return self.get_bucket().get_all_keys()

    def list_bucket(self, prefix='', delimiter='', headers=NOT_IMPL,
                    all_versions=NOT_IMPL):
        return self.get_bucket().list(prefix=prefix, delimiter=delimiter)

    def get_bucket(self, validate=NOT_IMPL, headers=NOT_IMPL):
        return self.connect().get_bucket(self.bucket_name)

    def get_key(self, validate=NOT_IMPL, headers=NOT_IMPL,
                version_id=NOT_IMPL):
        return self.get_bucket().get_key(self.object_name)

    def is_file_uri(self):
        return False

    def is_cloud_uri(self):
        return True

    def names_container(self):
        return bool(not self.object_name)

    def names_singleton(self):
        return bool(self.object_name)

    def names_directory(self):
        return False

    def names_provider(self):
        return bool(not self.bucket_name)

    def names_bucket(self):
        return self.names_container()

    def names_file(self):
        return False

    def names_object(self):
        return not self.names_container()

    def is_stream(self):
        return False

    def new_key(self, validate=NOT_IMPL, headers=NOT_IMPL):
        """Create (or replace) an empty key named by this URI."""
        bucket = self.get_bucket()
        return bucket.new_key(self.object_name)

    def set_acl(self, acl_or_str, key_name='', validate=NOT_IMPL,
                headers=NOT_IMPL, version_id=NOT_IMPL):
        """Set the ACL for ``key_name``, defaulting to this URI's object.

        For a bucket URI with no ``key_name``, the bucket ACL is set.
        Falling back to object_name makes set_acl target the same ACL that
        get_acl reads.
        """
        key_name = key_name or self.object_name or ''
        self.get_bucket().set_acl(acl_or_str, key_name)

    def set_def_acl(self, acl_or_str, key_name=NOT_IMPL, validate=NOT_IMPL,
                    headers=NOT_IMPL, version_id=NOT_IMPL):
        self.get_bucket().set_def_acl(acl_or_str)

    def set_subresource(self, subresource, value, validate=NOT_IMPL,
                        headers=NOT_IMPL, version_id=NOT_IMPL):
        self.get_bucket().set_subresource(subresource, value,
                                          self.object_name)

    def copy_key(self, src_bucket_name, src_key_name, metadata=NOT_IMPL,
                 src_version_id=NOT_IMPL, storage_class=NOT_IMPL,
                 preserve_acl=NOT_IMPL, encrypt_key=NOT_IMPL,
                 headers=NOT_IMPL, query_args=NOT_IMPL,
                 src_generation=NOT_IMPL):
        """Copy ``src_bucket_name``/``src_key_name`` onto this URI's object."""
        dst_bucket = self.get_bucket()
        return dst_bucket.copy_key(new_key_name=self.object_name,
                                   src_bucket_name=src_bucket_name,
                                   src_key_name=src_key_name)

    def set_contents_from_string(self, s, headers=NOT_IMPL, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL,
                                 policy=NOT_IMPL, md5=NOT_IMPL,
                                 reduced_redundancy=NOT_IMPL):
        """Store ``s`` as this URI's object, honoring content headers."""
        key = self.new_key()
        key.set_contents_from_string(s, headers=headers)

    def set_contents_from_file(self, fp, headers=None, replace=NOT_IMPL,
                               cb=NOT_IMPL, num_cb=NOT_IMPL, policy=NOT_IMPL,
                               md5=NOT_IMPL, size=NOT_IMPL, rewind=NOT_IMPL,
                               res_upload_handler=NOT_IMPL):
        key = self.new_key()
        return key.set_contents_from_file(fp, headers=headers)

    def set_contents_from_stream(self, fp, headers=NOT_IMPL, replace=NOT_IMPL,
                                 cb=NOT_IMPL, num_cb=NOT_IMPL,
                                 policy=NOT_IMPL,
                                 reduced_redundancy=NOT_IMPL,
                                 query_args=NOT_IMPL, size=NOT_IMPL):
        """Store the contents of stream ``fp`` as this URI's object."""
        # The old body referenced an undefined 'dst_key' (NameError).
        key = self.new_key()
        key.set_contents_from_stream(fp, headers=headers)

    def get_contents_to_file(self, fp, headers=NOT_IMPL, cb=NOT_IMPL,
                             num_cb=NOT_IMPL, torrent=NOT_IMPL,
                             version_id=NOT_IMPL,
                             res_download_handler=NOT_IMPL,
                             response_headers=NOT_IMPL):
        key = self.get_key()
        key.get_contents_to_file(fp)

    def get_contents_to_stream(self, fp, headers=NOT_IMPL, cb=NOT_IMPL,
                               num_cb=NOT_IMPL, version_id=NOT_IMPL):
        key = self.get_key()
        return key.get_contents_to_file(fp)