blob: df1c75d68c6dfe9fb32293637876a3e55dda26d2 [file] [log] [blame]
# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
"""
Routines for testing WSGI applications.
Most interesting is the `TestApp <class-paste.fixture.TestApp.html>`_
for testing WSGI applications, and the `TestFileEnvironment
<class-paste.fixture.TestFileEnvironment.html>`_ class for testing the
effects of command-line scripts.
"""
from __future__ import print_function
import sys
import random
import mimetypes
import time
import os
import shutil
import smtplib
import shlex
import re
import six
import subprocess
from six.moves import cStringIO as StringIO
from six.moves.urllib.parse import urlencode
from six.moves.urllib import parse as urlparse
try:
# Python 3
from http.cookies import BaseCookie
from urllib.parse import splittype, splithost
except ImportError:
# Python 2
from Cookie import BaseCookie
from urllib import splittype, splithost
from paste import wsgilib
from paste import lint
from paste.response import HeaderDict
def tempnam_no_warning(*args):
"""
An os.tempnam with the warning turned off, because sometimes
you just need to use this and don't care about the stupid
security warning.
"""
return os.tempnam(*args)
class NoDefault(object):
pass
def sorted(l):
l = list(l)
l.sort()
return l
class Dummy_smtplib(object):
existing = None
def __init__(self, server):
import warnings
warnings.warn(
'Dummy_smtplib is not maintained and is deprecated',
DeprecationWarning, 2)
assert not self.existing, (
"smtplib.SMTP() called again before Dummy_smtplib.existing.reset() "
"called.")
self.server = server
self.open = True
self.__class__.existing = self
def quit(self):
assert self.open, (
"Called %s.quit() twice" % self)
self.open = False
def sendmail(self, from_address, to_addresses, msg):
self.from_address = from_address
self.to_addresses = to_addresses
self.message = msg
def install(cls):
smtplib.SMTP = cls
install = classmethod(install)
def reset(self):
assert not self.open, (
"SMTP connection not quit")
self.__class__.existing = None
class AppError(Exception):
pass
class TestApp(object):
# for py.test
disabled = True
def __init__(self, app, namespace=None, relative_to=None,
extra_environ=None, pre_request_hook=None,
post_request_hook=None):
"""
Wraps a WSGI application in a more convenient interface for
testing.
``app`` may be an application, or a Paste Deploy app
URI, like ``'config:filename.ini#test'``.
``namespace`` is a dictionary that will be written to (if
provided). This can be used with doctest or some other
system, and the variable ``res`` will be assigned everytime
you make a request (instead of returning the request).
``relative_to`` is a directory, and filenames used for file
uploads are calculated relative to this. Also ``config:``
URIs that aren't absolute.
``extra_environ`` is a dictionary of values that should go
into the environment for each request. These can provide a
communication channel with the application.
``pre_request_hook`` is a function to be called prior to
making requests (such as ``post`` or ``get``). This function
must take one argument (the instance of the TestApp).
``post_request_hook`` is a function, similar to
``pre_request_hook``, to be called after requests are made.
"""
if isinstance(app, (six.binary_type, six.text_type)):
from paste.deploy import loadapp
# @@: Should pick up relative_to from calling module's
# __file__
app = loadapp(app, relative_to=relative_to)
self.app = app
self.namespace = namespace
self.relative_to = relative_to
if extra_environ is None:
extra_environ = {}
self.extra_environ = extra_environ
self.pre_request_hook = pre_request_hook
self.post_request_hook = post_request_hook
self.reset()
def reset(self):
"""
Resets the state of the application; currently just clears
saved cookies.
"""
self.cookies = {}
def _make_environ(self):
environ = self.extra_environ.copy()
environ['paste.throw_errors'] = True
return environ
def get(self, url, params=None, headers=None, extra_environ=None,
status=None, expect_errors=False):
"""
Get the given url (well, actually a path like
``'/page.html'``).
``params``:
A query string, or a dictionary that will be encoded
into a query string. You may also include a query
string on the ``url``.
``headers``:
A dictionary of extra headers to send.
``extra_environ``:
A dictionary of environmental variables that should
be added to the request.
``status``:
The integer status code you expect (if not 200 or 3xx).
If you expect a 404 response, for instance, you must give
``status=404`` or it will be an error. You can also give
a wildcard, like ``'3*'`` or ``'*'``.
``expect_errors``:
If this is not true, then if anything is written to
``wsgi.errors`` it will be an error. If it is true, then
non-200/3xx responses are also okay.
Returns a `response object
<class-paste.fixture.TestResponse.html>`_
"""
if extra_environ is None:
extra_environ = {}
# Hide from py.test:
__tracebackhide__ = True
if params:
if not isinstance(params, (six.binary_type, six.text_type)):
params = urlencode(params, doseq=True)
if '?' in url:
url += '&'
else:
url += '?'
url += params
environ = self._make_environ()
url = str(url)
if '?' in url:
url, environ['QUERY_STRING'] = url.split('?', 1)
else:
environ['QUERY_STRING'] = ''
self._set_headers(headers, environ)
environ.update(extra_environ)
req = TestRequest(url, environ, expect_errors)
return self.do_request(req, status=status)
def _gen_request(self, method, url, params=b'', headers=None, extra_environ=None,
status=None, upload_files=None, expect_errors=False):
"""
Do a generic request.
"""
if headers is None:
headers = {}
if extra_environ is None:
extra_environ = {}
environ = self._make_environ()
# @@: Should this be all non-strings?
if isinstance(params, (list, tuple, dict)):
params = urlencode(params)
if hasattr(params, 'items'):
# Some other multi-dict like format
params = urlencode(params.items())
if six.PY3:
params = params.encode('utf8')
if upload_files:
params = urlparse.parse_qsl(params, keep_blank_values=True)
content_type, params = self.encode_multipart(
params, upload_files)
environ['CONTENT_TYPE'] = content_type
elif params:
environ.setdefault('CONTENT_TYPE', 'application/x-www-form-urlencoded')
if '?' in url:
url, environ['QUERY_STRING'] = url.split('?', 1)
else:
environ['QUERY_STRING'] = ''
environ['CONTENT_LENGTH'] = str(len(params))
environ['REQUEST_METHOD'] = method
environ['wsgi.input'] = six.BytesIO(params)
self._set_headers(headers, environ)
environ.update(extra_environ)
req = TestRequest(url, environ, expect_errors)
return self.do_request(req, status=status)
def post(self, url, params=b'', headers=None, extra_environ=None,
status=None, upload_files=None, expect_errors=False):
"""
Do a POST request. Very like the ``.get()`` method.
``params`` are put in the body of the request.
``upload_files`` is for file uploads. It should be a list of
``[(fieldname, filename, file_content)]``. You can also use
just ``[(fieldname, filename)]`` and the file content will be
read from disk.
Returns a `response object
<class-paste.fixture.TestResponse.html>`_
"""
return self._gen_request('POST', url, params=params, headers=headers,
extra_environ=extra_environ,status=status,
upload_files=upload_files,
expect_errors=expect_errors)
def put(self, url, params=b'', headers=None, extra_environ=None,
status=None, upload_files=None, expect_errors=False):
"""
Do a PUT request. Very like the ``.get()`` method.
``params`` are put in the body of the request.
``upload_files`` is for file uploads. It should be a list of
``[(fieldname, filename, file_content)]``. You can also use
just ``[(fieldname, filename)]`` and the file content will be
read from disk.
Returns a `response object
<class-paste.fixture.TestResponse.html>`_
"""
return self._gen_request('PUT', url, params=params, headers=headers,
extra_environ=extra_environ,status=status,
upload_files=upload_files,
expect_errors=expect_errors)
def delete(self, url, params=b'', headers=None, extra_environ=None,
status=None, expect_errors=False):
"""
Do a DELETE request. Very like the ``.get()`` method.
``params`` are put in the body of the request.
Returns a `response object
<class-paste.fixture.TestResponse.html>`_
"""
return self._gen_request('DELETE', url, params=params, headers=headers,
extra_environ=extra_environ,status=status,
upload_files=None, expect_errors=expect_errors)
def _set_headers(self, headers, environ):
"""
Turn any headers into environ variables
"""
if not headers:
return
for header, value in headers.items():
if header.lower() == 'content-type':
var = 'CONTENT_TYPE'
elif header.lower() == 'content-length':
var = 'CONTENT_LENGTH'
else:
var = 'HTTP_%s' % header.replace('-', '_').upper()
environ[var] = value
def encode_multipart(self, params, files):
"""
Encodes a set of parameters (typically a name/value list) and
a set of files (a list of (name, filename, file_body)) into a
typical POST body, returning the (content_type, body).
"""
boundary = '----------a_BoUnDaRy%s$' % random.random()
content_type = 'multipart/form-data; boundary=%s' % boundary
if six.PY3:
boundary = boundary.encode('ascii')
lines = []
for key, value in params:
lines.append(b'--'+boundary)
line = 'Content-Disposition: form-data; name="%s"' % key
if six.PY3:
line = line.encode('utf8')
lines.append(line)
lines.append(b'')
line = value
if six.PY3 and isinstance(line, six.text_type):
line = line.encode('utf8')
lines.append(line)
for file_info in files:
key, filename, value = self._get_file_info(file_info)
lines.append(b'--'+boundary)
line = ('Content-Disposition: form-data; name="%s"; filename="%s"'
% (key, filename))
if six.PY3:
line = line.encode('utf8')
lines.append(line)
fcontent = mimetypes.guess_type(filename)[0]
line = ('Content-Type: %s'
% (fcontent or 'application/octet-stream'))
if six.PY3:
line = line.encode('utf8')
lines.append(line)
lines.append(b'')
lines.append(value)
lines.append(b'--' + boundary + b'--')
lines.append(b'')
body = b'\r\n'.join(lines)
return content_type, body
def _get_file_info(self, file_info):
if len(file_info) == 2:
# It only has a filename
filename = file_info[1]
if self.relative_to:
filename = os.path.join(self.relative_to, filename)
f = open(filename, 'rb')
content = f.read()
f.close()
return (file_info[0], filename, content)
elif len(file_info) == 3:
return file_info
else:
raise ValueError(
"upload_files need to be a list of tuples of (fieldname, "
"filename, filecontent) or (fieldname, filename); "
"you gave: %r"
% repr(file_info)[:100])
def do_request(self, req, status):
"""
Executes the given request (``req``), with the expected
``status``. Generally ``.get()`` and ``.post()`` are used
instead.
"""
if self.pre_request_hook:
self.pre_request_hook(self)
__tracebackhide__ = True
if self.cookies:
c = BaseCookie()
for name, value in self.cookies.items():
c[name] = value
hc = '; '.join(['='.join([m.key, m.value]) for m in c.values()])
req.environ['HTTP_COOKIE'] = hc
req.environ['paste.testing'] = True
req.environ['paste.testing_variables'] = {}
app = lint.middleware(self.app)
old_stdout = sys.stdout
out = CaptureStdout(old_stdout)
try:
sys.stdout = out
start_time = time.time()
raise_on_wsgi_error = not req.expect_errors
raw_res = wsgilib.raw_interactive(
app, req.url,
raise_on_wsgi_error=raise_on_wsgi_error,
**req.environ)
end_time = time.time()
finally:
sys.stdout = old_stdout
sys.stderr.write(out.getvalue())
res = self._make_response(raw_res, end_time - start_time)
res.request = req
for name, value in req.environ['paste.testing_variables'].items():
if hasattr(res, name):
raise ValueError(
"paste.testing_variables contains the variable %r, but "
"the response object already has an attribute by that "
"name" % name)
setattr(res, name, value)
if self.namespace is not None:
self.namespace['res'] = res
if not req.expect_errors:
self._check_status(status, res)
self._check_errors(res)
res.cookies_set = {}
for header in res.all_headers('set-cookie'):
c = BaseCookie(header)
for key, morsel in c.items():
self.cookies[key] = morsel.value
res.cookies_set[key] = morsel.value
if self.post_request_hook:
self.post_request_hook(self)
if self.namespace is None:
# It's annoying to return the response in doctests, as it'll
# be printed, so we only return it is we couldn't assign
# it anywhere
return res
def _check_status(self, status, res):
__tracebackhide__ = True
if status == '*':
return
if isinstance(status, (list, tuple)):
if res.status not in status:
raise AppError(
"Bad response: %s (not one of %s for %s)\n%s"
% (res.full_status, ', '.join(map(str, status)),
res.request.url, res.body))
return
if status is None:
if res.status >= 200 and res.status < 400:
return
body = res.body
if six.PY3:
body = body.decode('utf8', 'xmlcharrefreplace')
raise AppError(
"Bad response: %s (not 200 OK or 3xx redirect for %s)\n%s"
% (res.full_status, res.request.url,
body))
if status != res.status:
raise AppError(
"Bad response: %s (not %s)" % (res.full_status, status))
def _check_errors(self, res):
if res.errors:
raise AppError(
"Application had errors logged:\n%s" % res.errors)
def _make_response(self, resp, total_time):
status, headers, body, errors = resp
return TestResponse(self, status, headers, body, errors,
total_time)
class CaptureStdout(object):
def __init__(self, actual):
self.captured = StringIO()
self.actual = actual
def write(self, s):
self.captured.write(s)
self.actual.write(s)
def flush(self):
self.actual.flush()
def writelines(self, lines):
for item in lines:
self.write(item)
def getvalue(self):
return self.captured.getvalue()
class TestResponse(object):
# for py.test
disabled = True
"""
Instances of this class are return by `TestApp
<class-paste.fixture.TestApp.html>`_
"""
def __init__(self, test_app, status, headers, body, errors,
total_time):
self.test_app = test_app
self.status = int(status.split()[0])
self.full_status = status
self.headers = headers
self.header_dict = HeaderDict.fromlist(self.headers)
self.body = body
self.errors = errors
self._normal_body = None
self.time = total_time
self._forms_indexed = None
def forms__get(self):
"""
Returns a dictionary of ``Form`` objects. Indexes are both in
order (from zero) and by form id (if the form is given an id).
"""
if self._forms_indexed is None:
self._parse_forms()
return self._forms_indexed
forms = property(forms__get,
doc="""
A list of <form>s found on the page (instances of
`Form <class-paste.fixture.Form.html>`_)
""")
def form__get(self):
forms = self.forms
if not forms:
raise TypeError(
"You used response.form, but no forms exist")
if 1 in forms:
# There is more than one form
raise TypeError(
"You used response.form, but more than one form exists")
return forms[0]
form = property(form__get,
doc="""
Returns a single `Form
<class-paste.fixture.Form.html>`_ instance; it
is an error if there are multiple forms on the
page.
""")
_tag_re = re.compile(r'<(/?)([:a-z0-9_\-]*)(.*?)>', re.S|re.I)
def _parse_forms(self):
forms = self._forms_indexed = {}
form_texts = []
started = None
for match in self._tag_re.finditer(self.body):
end = match.group(1) == '/'
tag = match.group(2).lower()
if tag != 'form':
continue
if end:
assert started, (
"</form> unexpected at %s" % match.start())
form_texts.append(self.body[started:match.end()])
started = None
else:
assert not started, (
"Nested form tags at %s" % match.start())
started = match.start()
assert not started, (
"Danging form: %r" % self.body[started:])
for i, text in enumerate(form_texts):
form = Form(self, text)
forms[i] = form
if form.id:
forms[form.id] = form
def header(self, name, default=NoDefault):
"""
Returns the named header; an error if there is not exactly one
matching header (unless you give a default -- always an error
if there is more than one header)
"""
found = None
for cur_name, value in self.headers:
if cur_name.lower() == name.lower():
assert not found, (
"Ambiguous header: %s matches %r and %r"
% (name, found, value))
found = value
if found is None:
if default is NoDefault:
raise KeyError(
"No header found: %r (from %s)"
% (name, ', '.join([n for n, v in self.headers])))
else:
return default
return found
def all_headers(self, name):
"""
Gets all headers by the ``name``, returns as a list
"""
found = []
for cur_name, value in self.headers:
if cur_name.lower() == name.lower():
found.append(value)
return found
def follow(self, **kw):
"""
If this request is a redirect, follow that redirect. It
is an error if this is not a redirect response. Returns
another response object.
"""
assert self.status >= 300 and self.status < 400, (
"You can only follow redirect responses (not %s)"
% self.full_status)
location = self.header('location')
type, rest = splittype(location)
host, path = splithost(rest)
# @@: We should test that it's not a remote redirect
return self.test_app.get(location, **kw)
def click(self, description=None, linkid=None, href=None,
anchor=None, index=None, verbose=False):
"""
Click the link as described. Each of ``description``,
``linkid``, and ``url`` are *patterns*, meaning that they are
either strings (regular expressions), compiled regular
expressions (objects with a ``search`` method), or callables
returning true or false.
All the given patterns are ANDed together:
* ``description`` is a pattern that matches the contents of the
anchor (HTML and all -- everything between ``<a...>`` and
``</a>``)
* ``linkid`` is a pattern that matches the ``id`` attribute of
the anchor. It will receive the empty string if no id is
given.
* ``href`` is a pattern that matches the ``href`` of the anchor;
the literal content of that attribute, not the fully qualified
attribute.
* ``anchor`` is a pattern that matches the entire anchor, with
its contents.
If more than one link matches, then the ``index`` link is
followed. If ``index`` is not given and more than one link
matches, or if no link matches, then ``IndexError`` will be
raised.
If you give ``verbose`` then messages will be printed about
each link, and why it does or doesn't match. If you use
``app.click(verbose=True)`` you'll see a list of all the
links.
You can use multiple criteria to essentially assert multiple
aspects about the link, e.g., where the link's destination is.
"""
__tracebackhide__ = True
found_html, found_desc, found_attrs = self._find_element(
tag='a', href_attr='href',
href_extract=None,
content=description,
id=linkid,
href_pattern=href,
html_pattern=anchor,
index=index, verbose=verbose)
return self.goto(found_attrs['uri'])
def clickbutton(self, description=None, buttonid=None, href=None,
button=None, index=None, verbose=False):
"""
Like ``.click()``, except looks for link-like buttons.
This kind of button should look like
``<button onclick="...location.href='url'...">``.
"""
__tracebackhide__ = True
found_html, found_desc, found_attrs = self._find_element(
tag='button', href_attr='onclick',
href_extract=re.compile(r"location\.href='(.*?)'"),
content=description,
id=buttonid,
href_pattern=href,
html_pattern=button,
index=index, verbose=verbose)
return self.goto(found_attrs['uri'])
def _find_element(self, tag, href_attr, href_extract,
content, id,
href_pattern,
html_pattern,
index, verbose):
content_pat = _make_pattern(content)
id_pat = _make_pattern(id)
href_pat = _make_pattern(href_pattern)
html_pat = _make_pattern(html_pattern)
_tag_re = re.compile(r'<%s\s+(.*?)>(.*?)</%s>' % (tag, tag),
re.I+re.S)
def printlog(s):
if verbose:
print(s)
found_links = []
total_links = 0
for match in _tag_re.finditer(self.body):
el_html = match.group(0)
el_attr = match.group(1)
el_content = match.group(2)
attrs = _parse_attrs(el_attr)
if verbose:
printlog('Element: %r' % el_html)
if not attrs.get(href_attr):
printlog(' Skipped: no %s attribute' % href_attr)
continue
el_href = attrs[href_attr]
if href_extract:
m = href_extract.search(el_href)
if not m:
printlog(" Skipped: doesn't match extract pattern")
continue
el_href = m.group(1)
attrs['uri'] = el_href
if el_href.startswith('#'):
printlog(' Skipped: only internal fragment href')
continue
if el_href.startswith('javascript:'):
printlog(' Skipped: cannot follow javascript:')
continue
total_links += 1
if content_pat and not content_pat(el_content):
printlog(" Skipped: doesn't match description")
continue
if id_pat and not id_pat(attrs.get('id', '')):
printlog(" Skipped: doesn't match id")
continue
if href_pat and not href_pat(el_href):
printlog(" Skipped: doesn't match href")
continue
if html_pat and not html_pat(el_html):
printlog(" Skipped: doesn't match html")
continue
printlog(" Accepted")
found_links.append((el_html, el_content, attrs))
if not found_links:
raise IndexError(
"No matching elements found (from %s possible)"
% total_links)
if index is None:
if len(found_links) > 1:
raise IndexError(
"Multiple links match: %s"
% ', '.join([repr(anc) for anc, d, attr in found_links]))
found_link = found_links[0]
else:
try:
found_link = found_links[index]
except IndexError:
raise IndexError(
"Only %s (out of %s) links match; index %s out of range"
% (len(found_links), total_links, index))
return found_link
def goto(self, href, method='get', **args):
"""
Go to the (potentially relative) link ``href``, using the
given method (``'get'`` or ``'post'``) and any extra arguments
you want to pass to the ``app.get()`` or ``app.post()``
methods.
All hostnames and schemes will be ignored.
"""
scheme, host, path, query, fragment = urlparse.urlsplit(href)
# We
scheme = host = fragment = ''
href = urlparse.urlunsplit((scheme, host, path, query, fragment))
href = urlparse.urljoin(self.request.full_url, href)
method = method.lower()
assert method in ('get', 'post'), (
'Only "get" or "post" are allowed for method (you gave %r)'
% method)
if method == 'get':
method = self.test_app.get
else:
method = self.test_app.post
return method(href, **args)
_normal_body_regex = re.compile(br'[ \n\r\t]+')
def normal_body__get(self):
if self._normal_body is None:
self._normal_body = self._normal_body_regex.sub(
b' ', self.body)
return self._normal_body
normal_body = property(normal_body__get,
doc="""
Return the whitespace-normalized body
""")
def __contains__(self, s):
"""
A response 'contains' a string if it is present in the body
of the response. Whitespace is normalized when searching
for a string.
"""
if not isinstance(s, (six.binary_type, six.text_type)):
s = str(s)
if isinstance(s, six.text_type):
## FIXME: we don't know that this response uses utf8:
s = s.encode('utf8')
return (self.body.find(s) != -1
or self.normal_body.find(s) != -1)
def mustcontain(self, *strings, **kw):
"""
Assert that the response contains all of the strings passed
in as arguments.
Equivalent to::
assert string in res
"""
if 'no' in kw:
no = kw['no']
del kw['no']
if isinstance(no, (six.binary_type, six.text_type)):
no = [no]
else:
no = []
if kw:
raise TypeError(
"The only keyword argument allowed is 'no'")
for s in strings:
if not s in self:
print("Actual response (no %r):" % s, file=sys.stderr)
print(self, file=sys.stderr)
raise IndexError(
"Body does not contain string %r" % s)
for no_s in no:
if no_s in self:
print("Actual response (has %r)" % no_s, file=sys.stderr)
print(self, file=sys.stderr)
raise IndexError(
"Body contains string %r" % s)
def __repr__(self):
body = self.body
if six.PY3:
body = body.decode('utf8', 'xmlcharrefreplace')
body = body[:20]
return '<Response %s %r>' % (self.full_status, body)
def __str__(self):
simple_body = b'\n'.join([l for l in self.body.splitlines()
if l.strip()])
if six.PY3:
simple_body = simple_body.decode('utf8', 'xmlcharrefreplace')
return 'Response: %s\n%s\n%s' % (
self.status,
'\n'.join(['%s: %s' % (n, v) for n, v in self.headers]),
simple_body)
def showbrowser(self):
"""
Show this response in a browser window (for debugging purposes,
when it's hard to read the HTML).
"""
import webbrowser
fn = tempnam_no_warning(None, 'paste-fixture') + '.html'
f = open(fn, 'wb')
f.write(self.body)
f.close()
url = 'file:' + fn.replace(os.sep, '/')
webbrowser.open_new(url)
class TestRequest(object):
# for py.test
disabled = True
"""
Instances of this class are created by `TestApp
<class-paste.fixture.TestApp.html>`_ with the ``.get()`` and
``.post()`` methods, and are consumed there by ``.do_request()``.
Instances are also available as a ``.req`` attribute on
`TestResponse <class-paste.fixture.TestResponse.html>`_ instances.
Useful attributes:
``url``:
The url (actually usually the path) of the request, without
query string.
``environ``:
The environment dictionary used for the request.
``full_url``:
The url/path, with query string.
"""
def __init__(self, url, environ, expect_errors=False):
if url.startswith('http://localhost'):
url = url[len('http://localhost'):]
self.url = url
self.environ = environ
if environ.get('QUERY_STRING'):
self.full_url = url + '?' + environ['QUERY_STRING']
else:
self.full_url = url
self.expect_errors = expect_errors
class Form(object):
"""
This object represents a form that has been found in a page.
This has a couple useful attributes:
``text``:
the full HTML of the form.
``action``:
the relative URI of the action.
``method``:
the method (e.g., ``'GET'``).
``id``:
the id, or None if not given.
``fields``:
a dictionary of fields, each value is a list of fields by
that name. ``<input type=\"radio\">`` and ``<select>`` are
both represented as single fields with multiple options.
"""
# @@: This really should be using Mechanize/ClientForm or
# something...
_tag_re = re.compile(r'<(/?)([:a-z0-9_\-]*)([^>]*?)>', re.I)
def __init__(self, response, text):
self.response = response
self.text = text
self._parse_fields()
self._parse_action()
def _parse_fields(self):
in_select = None
in_textarea = None
fields = {}
for match in self._tag_re.finditer(self.text):
end = match.group(1) == '/'
tag = match.group(2).lower()
if tag not in ('input', 'select', 'option', 'textarea',
'button'):
continue
if tag == 'select' and end:
assert in_select, (
'%r without starting select' % match.group(0))
in_select = None
continue
if tag == 'textarea' and end:
assert in_textarea, (
"</textarea> with no <textarea> at %s" % match.start())
in_textarea[0].value = html_unquote(self.text[in_textarea[1]:match.start()])
in_textarea = None
continue
if end:
continue
attrs = _parse_attrs(match.group(3))
if 'name' in attrs:
name = attrs.pop('name')
else:
name = None
if tag == 'option':
in_select.options.append((attrs.get('value'),
'selected' in attrs))
continue
if tag == 'input' and attrs.get('type') == 'radio':
field = fields.get(name)
if not field:
field = Radio(self, tag, name, match.start(), **attrs)
fields.setdefault(name, []).append(field)
else:
field = field[0]
assert isinstance(field, Radio)
field.options.append((attrs.get('value'),
'checked' in attrs))
continue
tag_type = tag
if tag == 'input':
tag_type = attrs.get('type', 'text').lower()
FieldClass = Field.classes.get(tag_type, Field)
field = FieldClass(self, tag, name, match.start(), **attrs)
if tag == 'textarea':
assert not in_textarea, (
"Nested textareas: %r and %r"
% (in_textarea, match.group(0)))
in_textarea = field, match.end()
elif tag == 'select':
assert not in_select, (
"Nested selects: %r and %r"
% (in_select, match.group(0)))
in_select = field
fields.setdefault(name, []).append(field)
self.fields = fields
def _parse_action(self):
self.action = None
for match in self._tag_re.finditer(self.text):
end = match.group(1) == '/'
tag = match.group(2).lower()
if tag != 'form':
continue
if end:
break
attrs = _parse_attrs(match.group(3))
self.action = attrs.get('action', '')
self.method = attrs.get('method', 'GET')
self.id = attrs.get('id')
# @@: enctype?
else:
assert 0, "No </form> tag found"
assert self.action is not None, (
"No <form> tag found")
def __setitem__(self, name, value):
"""
Set the value of the named field. If there is 0 or multiple
fields by that name, it is an error.
Setting the value of a ``<select>`` selects the given option
(and confirms it is an option). Setting radio fields does the
same. Checkboxes get boolean values. You cannot set hidden
fields or buttons.
Use ``.set()`` if there is any ambiguity and you must provide
an index.
"""
fields = self.fields.get(name)
assert fields is not None, (
"No field by the name %r found (fields: %s)"
% (name, ', '.join(map(repr, self.fields.keys()))))
assert len(fields) == 1, (
"Multiple fields match %r: %s"
% (name, ', '.join(map(repr, fields))))
fields[0].value = value
def __getitem__(self, name):
"""
Get the named field object (ambiguity is an error).
"""
fields = self.fields.get(name)
assert fields is not None, (
"No field by the name %r found" % name)
assert len(fields) == 1, (
"Multiple fields match %r: %s"
% (name, ', '.join(map(repr, fields))))
return fields[0]
def set(self, name, value, index=None):
"""
Set the given name, using ``index`` to disambiguate.
"""
if index is None:
self[name] = value
else:
fields = self.fields.get(name)
assert fields is not None, (
"No fields found matching %r" % name)
field = fields[index]
field.value = value
def get(self, name, index=None, default=NoDefault):
"""
Get the named/indexed field object, or ``default`` if no field
is found.
"""
fields = self.fields.get(name)
if fields is None and default is not NoDefault:
return default
if index is None:
return self[name]
else:
fields = self.fields.get(name)
assert fields is not None, (
"No fields found matching %r" % name)
field = fields[index]
return field
def select(self, name, value, index=None):
"""
Like ``.set()``, except also confirms the target is a
``<select>``.
"""
field = self.get(name, index=index)
assert isinstance(field, Select)
field.value = value
def submit(self, name=None, index=None, **args):
"""
Submits the form. If ``name`` is given, then also select that
button (using ``index`` to disambiguate)``.
Any extra keyword arguments are passed to the ``.get()`` or
``.post()`` method.
Returns a response object.
"""
fields = self.submit_fields(name, index=index)
return self.response.goto(self.action, method=self.method,
params=fields, **args)
def submit_fields(self, name=None, index=None):
"""
Return a list of ``[(name, value), ...]`` for the current
state of the form.
"""
submit = []
if name is not None:
field = self.get(name, index=index)
submit.append((field.name, field.value_if_submitted()))
for name, fields in self.fields.items():
if name is None:
continue
for field in fields:
value = field.value
if value is None:
continue
submit.append((name, value))
return submit
_attr_re = re.compile(r'([^= \n\r\t]+)[ \n\r\t]*(?:=[ \n\r\t]*(?:"([^"]*)"|([^"][^ \n\r\t>]*)))?', re.S)
def _parse_attrs(text):
attrs = {}
for match in _attr_re.finditer(text):
attr_name = match.group(1).lower()
attr_body = match.group(2) or match.group(3)
attr_body = html_unquote(attr_body or '')
attrs[attr_name] = attr_body
return attrs
class Field(object):
"""
Field object.
"""
# Dictionary of field types (select, radio, etc) to classes
classes = {}
settable = True
def __init__(self, form, tag, name, pos,
value=None, id=None, **attrs):
self.form = form
self.tag = tag
self.name = name
self.pos = pos
self._value = value
self.id = id
self.attrs = attrs
def value__set(self, value):
if not self.settable:
raise AttributeError(
"You cannot set the value of the <%s> field %r"
% (self.tag, self.name))
self._value = value
def force_value(self, value):
"""
Like setting a value, except forces it even for, say, hidden
fields.
"""
self._value = value
def value__get(self):
return self._value
value = property(value__get, value__set)
class Select(Field):
"""
Field representing ``<select>``
"""
def __init__(self, *args, **attrs):
super(Select, self).__init__(*args, **attrs)
self.options = []
self.multiple = attrs.get('multiple')
assert not self.multiple, (
"<select multiple> not yet supported")
# Undetermined yet:
self.selectedIndex = None
def value__set(self, value):
for i, (option, checked) in enumerate(self.options):
if option == str(value):
self.selectedIndex = i
break
else:
raise ValueError(
"Option %r not found (from %s)"
% (value, ', '.join(
[repr(o) for o, c in self.options])))
def value__get(self):
if self.selectedIndex is not None:
return self.options[self.selectedIndex][0]
else:
for option, checked in self.options:
if checked:
return option
else:
if self.options:
return self.options[0][0]
else:
return None
value = property(value__get, value__set)
Field.classes['select'] = Select
class Radio(Select):
"""
Field representing ``<input type="radio">``
"""
Field.classes['radio'] = Radio
class Checkbox(Field):
"""
Field representing ``<input type="checkbox">``
"""
def __init__(self, *args, **attrs):
super(Checkbox, self).__init__(*args, **attrs)
self.checked = 'checked' in attrs
def value__set(self, value):
self.checked = not not value
def value__get(self):
if self.checked:
if self._value is None:
return 'on'
else:
return self._value
else:
return None
value = property(value__get, value__set)
Field.classes['checkbox'] = Checkbox
class Text(Field):
"""
Field representing ``<input type="text">``
"""
def __init__(self, form, tag, name, pos,
value='', id=None, **attrs):
#text fields default to empty string
Field.__init__(self, form, tag, name, pos,
value=value, id=id, **attrs)
Field.classes['text'] = Text
class Textarea(Text):
"""
Field representing ``<textarea>``
"""
Field.classes['textarea'] = Textarea
class Hidden(Text):
"""
Field representing ``<input type="hidden">``
"""
Field.classes['hidden'] = Hidden
class Submit(Field):
"""
Field representing ``<input type="submit">`` and ``<button>``
"""
settable = False
def value__get(self):
return None
value = property(value__get)
def value_if_submitted(self):
return self._value
Field.classes['submit'] = Submit
Field.classes['button'] = Submit
Field.classes['image'] = Submit
############################################################
## Command-line testing
############################################################
class TestFileEnvironment(object):
"""
This represents an environment in which files will be written, and
scripts will be run.
"""
# for py.test
disabled = True
def __init__(self, base_path, template_path=None,
script_path=None,
environ=None, cwd=None, start_clear=True,
ignore_paths=None, ignore_hidden=True):
"""
Creates an environment. ``base_path`` is used as the current
working directory, and generally where changes are looked for.
``template_path`` is the directory to look for *template*
files, which are files you'll explicitly add to the
environment. This is done with ``.writefile()``.
``script_path`` is the PATH for finding executables. Usually
grabbed from ``$PATH``.
``environ`` is the operating system environment,
``os.environ`` if not given.
``cwd`` is the working directory, ``base_path`` by default.
If ``start_clear`` is true (default) then the ``base_path``
will be cleared (all files deleted) when an instance is
created. You can also use ``.clear()`` to clear the files.
``ignore_paths`` is a set of specific filenames that should be
ignored when created in the environment. ``ignore_hidden``
means, if true (default) that filenames and directories
starting with ``'.'`` will be ignored.
"""
self.base_path = base_path
self.template_path = template_path
if environ is None:
environ = os.environ.copy()
self.environ = environ
if script_path is None:
if sys.platform == 'win32':
script_path = environ.get('PATH', '').split(';')
else:
script_path = environ.get('PATH', '').split(':')
self.script_path = script_path
if cwd is None:
cwd = base_path
self.cwd = cwd
if start_clear:
self.clear()
elif not os.path.exists(base_path):
os.makedirs(base_path)
self.ignore_paths = ignore_paths or []
self.ignore_hidden = ignore_hidden
def run(self, script, *args, **kw):
"""
Run the command, with the given arguments. The ``script``
argument can have space-separated arguments, or you can use
the positional arguments.
Keywords allowed are:
``expect_error``: (default False)
Don't raise an exception in case of errors
``expect_stderr``: (default ``expect_error``)
Don't raise an exception if anything is printed to stderr
``stdin``: (default ``""``)
Input to the script
``printresult``: (default True)
Print the result after running
``cwd``: (default ``self.cwd``)
The working directory to run in
Returns a `ProcResponse
<class-paste.fixture.ProcResponse.html>`_ object.
"""
__tracebackhide__ = True
expect_error = _popget(kw, 'expect_error', False)
expect_stderr = _popget(kw, 'expect_stderr', expect_error)
cwd = _popget(kw, 'cwd', self.cwd)
stdin = _popget(kw, 'stdin', None)
printresult = _popget(kw, 'printresult', True)
args = list(map(str, args))
assert not kw, (
"Arguments not expected: %s" % ', '.join(kw.keys()))
if ' ' in script:
assert not args, (
"You cannot give a multi-argument script (%r) "
"and arguments (%s)" % (script, args))
script, args = script.split(None, 1)
args = shlex.split(args)
script = self._find_exe(script)
all = [script] + args
files_before = self._find_files()
proc = subprocess.Popen(all, stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
cwd=cwd,
env=self.environ)
stdout, stderr = proc.communicate(stdin)
files_after = self._find_files()
result = ProcResult(
self, all, stdin, stdout, stderr,
returncode=proc.returncode,
files_before=files_before,
files_after=files_after)
if printresult:
print(result)
print('-'*40)
if not expect_error:
result.assert_no_error()
if not expect_stderr:
result.assert_no_stderr()
return result
def _find_exe(self, script_name):
if self.script_path is None:
script_name = os.path.join(self.cwd, script_name)
if not os.path.exists(script_name):
raise OSError(
"Script %s does not exist" % script_name)
return script_name
for path in self.script_path:
fn = os.path.join(path, script_name)
if os.path.exists(fn):
return fn
raise OSError(
"Script %s could not be found in %s"
% (script_name, ':'.join(self.script_path)))
def _find_files(self):
result = {}
for fn in os.listdir(self.base_path):
if self._ignore_file(fn):
continue
self._find_traverse(fn, result)
return result
def _ignore_file(self, fn):
if fn in self.ignore_paths:
return True
if self.ignore_hidden and os.path.basename(fn).startswith('.'):
return True
return False
def _find_traverse(self, path, result):
full = os.path.join(self.base_path, path)
if os.path.isdir(full):
result[path] = FoundDir(self.base_path, path)
for fn in os.listdir(full):
fn = os.path.join(path, fn)
if self._ignore_file(fn):
continue
self._find_traverse(fn, result)
else:
result[path] = FoundFile(self.base_path, path)
def clear(self):
"""
Delete all the files in the base directory.
"""
if os.path.exists(self.base_path):
shutil.rmtree(self.base_path)
os.mkdir(self.base_path)
def writefile(self, path, content=None,
frompath=None):
"""
Write a file to the given path. If ``content`` is given then
that text is written, otherwise the file in ``frompath`` is
used. ``frompath`` is relative to ``self.template_path``
"""
full = os.path.join(self.base_path, path)
if not os.path.exists(os.path.dirname(full)):
os.makedirs(os.path.dirname(full))
f = open(full, 'wb')
if content is not None:
f.write(content)
if frompath is not None:
if self.template_path:
frompath = os.path.join(self.template_path, frompath)
f2 = open(frompath, 'rb')
f.write(f2.read())
f2.close()
f.close()
return FoundFile(self.base_path, path)
class ProcResult(object):
"""
Represents the results of running a command in
`TestFileEnvironment
<class-paste.fixture.TestFileEnvironment.html>`_.
Attributes to pay particular attention to:
``stdout``, ``stderr``:
What is produced
``files_created``, ``files_deleted``, ``files_updated``:
Dictionaries mapping filenames (relative to the ``base_dir``)
to `FoundFile <class-paste.fixture.FoundFile.html>`_ or
`FoundDir <class-paste.fixture.FoundDir.html>`_ objects.
"""
def __init__(self, test_env, args, stdin, stdout, stderr,
returncode, files_before, files_after):
self.test_env = test_env
self.args = args
self.stdin = stdin
self.stdout = stdout
self.stderr = stderr
self.returncode = returncode
self.files_before = files_before
self.files_after = files_after
self.files_deleted = {}
self.files_updated = {}
self.files_created = files_after.copy()
for path, f in files_before.items():
if path not in files_after:
self.files_deleted[path] = f
continue
del self.files_created[path]
if f.mtime < files_after[path].mtime:
self.files_updated[path] = files_after[path]
def assert_no_error(self):
__tracebackhide__ = True
assert self.returncode == 0, (
"Script returned code: %s" % self.returncode)
def assert_no_stderr(self):
__tracebackhide__ = True
if self.stderr:
print('Error output:')
print(self.stderr)
raise AssertionError("stderr output not expected")
def __str__(self):
s = ['Script result: %s' % ' '.join(self.args)]
if self.returncode:
s.append(' return code: %s' % self.returncode)
if self.stderr:
s.append('-- stderr: --------------------')
s.append(self.stderr)
if self.stdout:
s.append('-- stdout: --------------------')
s.append(self.stdout)
for name, files, show_size in [
('created', self.files_created, True),
('deleted', self.files_deleted, True),
('updated', self.files_updated, True)]:
if files:
s.append('-- %s: -------------------' % name)
files = files.items()
files.sort()
last = ''
for path, f in files:
t = ' %s' % _space_prefix(last, path, indent=4,
include_sep=False)
last = path
if show_size and f.size != 'N/A':
t += ' (%s bytes)' % f.size
s.append(t)
return '\n'.join(s)
class FoundFile(object):
"""
Represents a single file found as the result of a command.
Has attributes:
``path``:
The path of the file, relative to the ``base_path``
``full``:
The full path
``stat``:
The results of ``os.stat``. Also ``mtime`` and ``size``
contain the ``.st_mtime`` and ``st_size`` of the stat.
``bytes``:
The contents of the file.
You may use the ``in`` operator with these objects (tested against
the contents of the file), and the ``.mustcontain()`` method.
"""
file = True
dir = False
def __init__(self, base_path, path):
self.base_path = base_path
self.path = path
self.full = os.path.join(base_path, path)
self.stat = os.stat(self.full)
self.mtime = self.stat.st_mtime
self.size = self.stat.st_size
self._bytes = None
def bytes__get(self):
if self._bytes is None:
f = open(self.full, 'rb')
self._bytes = f.read()
f.close()
return self._bytes
bytes = property(bytes__get)
def __contains__(self, s):
return s in self.bytes
def mustcontain(self, s):
__tracebackhide__ = True
bytes_ = self.bytes
if s not in bytes_:
print('Could not find %r in:' % s)
print(bytes_)
assert s in bytes_
def __repr__(self):
return '<%s %s:%s>' % (
self.__class__.__name__,
self.base_path, self.path)
class FoundDir(object):
"""
Represents a directory created by a command.
"""
file = False
dir = True
def __init__(self, base_path, path):
self.base_path = base_path
self.path = path
self.full = os.path.join(base_path, path)
self.size = 'N/A'
self.mtime = 'N/A'
def __repr__(self):
return '<%s %s:%s>' % (
self.__class__.__name__,
self.base_path, self.path)
def _popget(d, key, default=None):
"""
Pop the key if found (else return default)
"""
if key in d:
return d.pop(key)
return default
def _space_prefix(pref, full, sep=None, indent=None, include_sep=True):
"""
Anything shared by pref and full will be replaced with spaces
in full, and full returned.
"""
if sep is None:
sep = os.path.sep
pref = pref.split(sep)
full = full.split(sep)
padding = []
while pref and full and pref[0] == full[0]:
if indent is None:
padding.append(' ' * (len(full[0]) + len(sep)))
else:
padding.append(' ' * indent)
full.pop(0)
pref.pop(0)
if padding:
if include_sep:
return ''.join(padding) + sep + sep.join(full)
else:
return ''.join(padding) + sep.join(full)
else:
return sep.join(full)
def _make_pattern(pat):
if pat is None:
return None
if isinstance(pat, (six.binary_type, six.text_type)):
pat = re.compile(pat)
if hasattr(pat, 'search'):
return pat.search
if callable(pat):
return pat
assert 0, (
"Cannot make callable pattern object out of %r" % pat)
def setup_module(module=None):
"""
This is used by py.test if it is in the module, so you can
import this directly.
Use like::
from paste.fixture import setup_module
"""
# Deprecated June 2008
import warnings
warnings.warn(
'setup_module is deprecated',
DeprecationWarning, 2)
if module is None:
# The module we were called from must be the module...
module = sys._getframe().f_back.f_globals['__name__']
if isinstance(module, (six.binary_type, six.text_type)):
module = sys.modules[module]
if hasattr(module, 'reset_state'):
module.reset_state()
def html_unquote(v):
"""
Unquote (some) entities in HTML. (incomplete)
"""
for ent, repl in [('&nbsp;', ' '), ('&gt;', '>'),
('&lt;', '<'), ('&quot;', '"'),
('&amp;', '&')]:
v = v.replace(ent, repl)
return v