# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
"""
Routines for testing WSGI applications.
Most interesting is the ``TestApp`` class
for testing WSGI applications, and the ``TestFileEnvironment``
class for testing the
effects of command-line scripts.
"""
import sys
import random
import urllib
import urlparse
import mimetypes
import time
import cgi
import os
import shutil
import smtplib
import shlex
from Cookie import BaseCookie
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
import re
try:
import subprocess
except ImportError:
from paste.util import subprocess24 as subprocess
from paste import wsgilib
from paste import lint
from paste.response import HeaderDict
def tempnam_no_warning(*args):
    """
    Return whatever ``os.tempnam(*args)`` returns.

    The name suggests the security warning is silenced, but nothing in
    this function does that itself.
    NOTE(review): presumably a warnings filter is installed elsewhere
    in the module -- confirm before relying on the "no warning" part.
    """
    make_name = os.tempnam
    return make_name(*args)
class NoDefault(object):
    """
    Sentinel class used as a default argument value meaning "the caller
    supplied no default" (so that ``None`` stays a usable default).
    """
def sorted(l):
    """
    Return a new, ascending-sorted list built from the iterable ``l``.

    The input is copied first, so the caller's object is never mutated.
    This module-level name shadows any builtin of the same name, so the
    implementation deliberately relies on ``list.sort`` only.
    """
    ordered = list(l)
    ordered.sort()
    return ordered
class Dummy_smtplib(object):
    """
    Deprecated stand-in for ``smtplib.SMTP`` that records the last
    message handed to ``sendmail`` instead of sending anything.

    Only one instance may be live at a time; it is tracked on the class
    as ``existing`` and cleared again by ``reset()``.
    """

    # The live instance created by the most recent constructor call, or
    # None once that instance has been reset.
    existing = None

    def __init__(self, server):
        """
        Record ``server``, mark the connection open and register this
        instance as ``Dummy_smtplib.existing``.

        Emits a ``DeprecationWarning`` and asserts that no earlier
        instance is still registered.
        """
        import warnings
        deprecation_text = 'Dummy_smtplib is not maintained and is deprecated'
        warnings.warn(deprecation_text, DeprecationWarning, 2)
        already_registered = self.existing
        assert not already_registered, (
            "smtplib.SMTP() called again before "
            "Dummy_smtplib.existing.reset() called.")
        self.server = server
        self.open = True
        self.__class__.existing = self

    def quit(self):
        """Mark the connection closed; asserts it was still open."""
        still_open = self.open
        assert still_open, (
            "Called %s.quit() twice" % self)
        self.open = False

    def sendmail(self, from_address, to_addresses, msg):
        """Store the envelope and message on the instance for inspection."""
        self.message = msg
        self.to_addresses = to_addresses
        self.from_address = from_address

    def install(cls):
        """Replace ``smtplib.SMTP`` with this class."""
        setattr(smtplib, 'SMTP', cls)
    install = classmethod(install)

    def reset(self):
        """
        Unregister the live instance; asserts ``quit()`` was called first.
        """
        connection_open = self.open
        assert not connection_open, (
            "SMTP connection not quit")
        self.__class__.existing = None
class AppError(Exception):
    """
    Raised when a response from the application under test has an
    unexpected status or when the application logged errors.
    """
class TestApp(object):
    """
    Wraps a WSGI application in a more convenient interface for testing:
    requests are built from simple arguments, run through the
    application, and the outcome is checked for status and logged errors.
    """

    # for py.test
    disabled = True
    # Current pytest releases look at ``__test__`` rather than
    # ``disabled`` to decide whether a ``Test*`` class is a test case.
    __test__ = False

    def __init__(self, app, namespace=None, relative_to=None,
                 extra_environ=None, pre_request_hook=None,
                 post_request_hook=None):
        """
        Wraps a WSGI application in a more convenient interface for
        testing.

        ``app`` may be an application, or a Paste Deploy app
        URI, like ``'config:filename.ini#test'``.

        ``namespace`` is a dictionary that will be written to (if
        provided).  This can be used with doctest or some other
        system, and the variable ``res`` will be assigned every time
        you make a request (instead of returning the response).

        ``relative_to`` is a directory, and filenames used for file
        uploads are calculated relative to this.  Also ``config:``
        URIs that aren't absolute.

        ``extra_environ`` is a dictionary of values that should go
        into the environment for each request.  These can provide a
        communication channel with the application.

        ``pre_request_hook`` is a function to be called prior to
        making requests (such as ``post`` or ``get``).  This function
        must take one argument (the instance of the TestApp).

        ``post_request_hook`` is a function, similar to
        ``pre_request_hook``, to be called after requests are made.
        """
        if isinstance(app, (str, unicode)):
            from paste.deploy import loadapp
            # @@: Should pick up relative_to from calling module's
            # __file__
            app = loadapp(app, relative_to=relative_to)
        self.app = app
        self.namespace = namespace
        self.relative_to = relative_to
        if extra_environ is None:
            extra_environ = {}
        self.extra_environ = extra_environ
        self.pre_request_hook = pre_request_hook
        self.post_request_hook = post_request_hook
        self.reset()

    def reset(self):
        """
        Resets the state of the application; currently just clears
        saved cookies.
        """
        self.cookies = {}

    def _make_environ(self):
        """
        Return a fresh environ dict: a copy of ``self.extra_environ``
        with ``paste.throw_errors`` switched on.
        """
        environ = self.extra_environ.copy()
        environ['paste.throw_errors'] = True
        return environ

    def get(self, url, params=None, headers=None, extra_environ=None,
            status=None, expect_errors=False):
        """
        Get the given url (well, actually a path like
        ``'/page.html'``).

        ``params``:
            A query string, or a dictionary that will be encoded
            into a query string.  You may also include a query
            string on the ``url``.

        ``headers``:
            A dictionary of extra headers to send.

        ``extra_environ``:
            A dictionary of environmental variables that should
            be added to the request.

        ``status``:
            The integer status code you expect (if not 200 or 3xx).
            If you expect a 404 response, for instance, you must give
            ``status=404`` or it will be an error.  You can also give
            a wildcard, like ``'3*'`` or ``'*'``.

        ``expect_errors``:
            If this is not true, then if anything is written to
            ``wsgi.errors`` it will be an error.  If it is true, then
            non-200/3xx responses are also okay.

        Returns the response object, or ``None`` when a ``namespace``
        was given (the response is then stored as ``namespace['res']``).
        """
        if extra_environ is None:
            extra_environ = {}
        # Hide from py.test:
        __tracebackhide__ = True
        if params:
            if not isinstance(params, (str, unicode)):
                params = urllib.urlencode(params, doseq=True)
            if '?' in url:
                url += '&'
            else:
                url += '?'
            url += params
        environ = self._make_environ()
        url = str(url)
        if '?' in url:
            url, environ['QUERY_STRING'] = url.split('?', 1)
        else:
            environ['QUERY_STRING'] = ''
        self._set_headers(headers, environ)
        # Per-request environ wins over headers and instance defaults.
        environ.update(extra_environ)
        req = TestRequest(url, environ, expect_errors)
        return self.do_request(req, status=status)

    def _gen_request(self, method, url, params='', headers=None,
                     extra_environ=None, status=None, upload_files=None,
                     expect_errors=False):
        """
        Do a generic request with a body: ``params`` (a string, a
        list/tuple/dict, or any object with ``items()``) becomes the
        request body, or a multipart body when ``upload_files`` is given.
        """
        if headers is None:
            headers = {}
        if extra_environ is None:
            extra_environ = {}
        environ = self._make_environ()
        # @@: Should this be all non-strings?
        if isinstance(params, (list, tuple, dict)):
            params = urllib.urlencode(params)
        if hasattr(params, 'items'):
            # Some other multi-dict like format
            params = urllib.urlencode(params.items())
        if upload_files:
            # Multipart encoding needs name/value pairs, not a string.
            params = cgi.parse_qsl(params, keep_blank_values=True)
            content_type, params = self.encode_multipart(
                params, upload_files)
            environ['CONTENT_TYPE'] = content_type
        elif params:
            environ.setdefault('CONTENT_TYPE',
                               'application/x-www-form-urlencoded')
        if '?' in url:
            url, environ['QUERY_STRING'] = url.split('?', 1)
        else:
            environ['QUERY_STRING'] = ''
        environ['CONTENT_LENGTH'] = str(len(params))
        environ['REQUEST_METHOD'] = method
        environ['wsgi.input'] = StringIO(params)
        self._set_headers(headers, environ)
        environ.update(extra_environ)
        req = TestRequest(url, environ, expect_errors)
        return self.do_request(req, status=status)

    def post(self, url, params='', headers=None, extra_environ=None,
             status=None, upload_files=None, expect_errors=False):
        """
        Do a POST request.  Very like the ``.get()`` method.
        ``params`` are put in the body of the request.

        ``upload_files`` is for file uploads.  It should be a list of
        ``[(fieldname, filename, file_content)]``.  You can also use
        just ``[(fieldname, filename)]`` and the file content will be
        read from disk.

        Returns the same thing ``.get()`` does.
        """
        return self._gen_request('POST', url, params=params, headers=headers,
                                 extra_environ=extra_environ, status=status,
                                 upload_files=upload_files,
                                 expect_errors=expect_errors)

    def put(self, url, params='', headers=None, extra_environ=None,
            status=None, upload_files=None, expect_errors=False):
        """
        Do a PUT request.  Very like the ``.get()`` method.
        ``params`` are put in the body of the request.

        ``upload_files`` is for file uploads.  It should be a list of
        ``[(fieldname, filename, file_content)]``.  You can also use
        just ``[(fieldname, filename)]`` and the file content will be
        read from disk.

        Returns the same thing ``.get()`` does.
        """
        return self._gen_request('PUT', url, params=params, headers=headers,
                                 extra_environ=extra_environ, status=status,
                                 upload_files=upload_files,
                                 expect_errors=expect_errors)

    def delete(self, url, params='', headers=None, extra_environ=None,
               status=None, expect_errors=False):
        """
        Do a DELETE request.  Very like the ``.get()`` method.
        ``params`` are put in the body of the request.

        Returns the same thing ``.get()`` does.
        """
        return self._gen_request('DELETE', url, params=params,
                                 headers=headers,
                                 extra_environ=extra_environ, status=status,
                                 upload_files=None,
                                 expect_errors=expect_errors)

    def _set_headers(self, headers, environ):
        """
        Turn any headers into environ variables.

        ``Content-Type`` and ``Content-Length`` map to their unprefixed
        CGI names; every other header becomes ``HTTP_<NAME>``.
        """
        if not headers:
            return
        for header, value in headers.items():
            if header.lower() == 'content-type':
                var = 'CONTENT_TYPE'
            elif header.lower() == 'content-length':
                var = 'CONTENT_LENGTH'
            else:
                var = 'HTTP_%s' % header.replace('-', '_').upper()
            environ[var] = value

    def encode_multipart(self, params, files):
        """
        Encodes a set of parameters (typically a name/value list) and
        a set of files (a list of (name, filename, file_body)) into a
        typical POST body, returning the (content_type, body).

        Files whose type cannot be guessed from the filename are sent
        as ``application/octet-stream``.
        """
        boundary = '----------a_BoUnDaRy%s$' % random.random()
        lines = []
        for key, value in params:
            lines.append('--'+boundary)
            lines.append('Content-Disposition: form-data; name="%s"' % key)
            lines.append('')
            lines.append(value)
        for file_info in files:
            key, filename, value = self._get_file_info(file_info)
            lines.append('--'+boundary)
            lines.append(
                'Content-Disposition: form-data; name="%s"; filename="%s"'
                % (key, filename))
            fcontent = mimetypes.guess_type(filename)[0]
            # The parentheses matter: ``%`` binds tighter than ``or``, so
            # without them an unknown type rendered as "Content-Type: None".
            lines.append('Content-Type: %s' %
                         (fcontent or 'application/octet-stream'))
            lines.append('')
            lines.append(value)
        lines.append('--' + boundary + '--')
        lines.append('')
        body = '\r\n'.join(lines)
        content_type = 'multipart/form-data; boundary=%s' % boundary
        return content_type, body

    def _get_file_info(self, file_info):
        """
        Normalize one ``upload_files`` entry to
        ``(fieldname, filename, content)``.

        A two-item entry has its content read from disk (relative to
        ``self.relative_to`` when that is set); a three-item entry is
        returned unchanged.  Anything else raises ``ValueError``.
        """
        if len(file_info) == 2:
            # It only has a filename
            filename = file_info[1]
            if self.relative_to:
                filename = os.path.join(self.relative_to, filename)
            f = open(filename, 'rb')
            try:
                content = f.read()
            finally:
                # Close even when the read fails, so no handle leaks.
                f.close()
            return (file_info[0], filename, content)
        elif len(file_info) == 3:
            return file_info
        else:
            raise ValueError(
                "upload_files need to be a list of tuples of (fieldname, "
                "filename, filecontent) or (fieldname, filename); "
                "you gave: %r"
                % repr(file_info)[:100])

    def do_request(self, req, status):
        """
        Executes the given request (``req``), with the expected
        ``status``.  Generally ``.get()`` and ``.post()`` are used
        instead.

        Saved cookies are sent with the request and cookies set by the
        response are saved.  Returns the response, or ``None`` when a
        ``namespace`` was given.
        """
        if self.pre_request_hook:
            self.pre_request_hook(self)
        __tracebackhide__ = True
        if self.cookies:
            c = BaseCookie()
            for name, value in self.cookies.items():
                c[name] = value
            hc = '; '.join(['='.join([m.key, m.value]) for m in c.values()])
            req.environ['HTTP_COOKIE'] = hc
        req.environ['paste.testing'] = True
        req.environ['paste.testing_variables'] = {}
        app = lint.middleware(self.app)
        old_stdout = sys.stdout
        out = CaptureStdout(old_stdout)
        try:
            sys.stdout = out
            start_time = time.time()
            raise_on_wsgi_error = not req.expect_errors
            raw_res = wsgilib.raw_interactive(
                app, req.url,
                raise_on_wsgi_error=raise_on_wsgi_error,
                **req.environ)
            end_time = time.time()
        finally:
            # Always restore stdout, and echo what was captured to stderr.
            sys.stdout = old_stdout
            sys.stderr.write(out.getvalue())
        res = self._make_response(raw_res, end_time - start_time)
        res.request = req
        # Values the application exported for the test become attributes
        # of the response, as long as they do not clash with existing ones.
        for name, value in req.environ['paste.testing_variables'].items():
            if hasattr(res, name):
                raise ValueError(
                    "paste.testing_variables contains the variable %r, but "
                    "the response object already has an attribute by that "
                    "name" % name)
            setattr(res, name, value)
        if self.namespace is not None:
            self.namespace['res'] = res
        if not req.expect_errors:
            self._check_status(status, res)
            self._check_errors(res)
        res.cookies_set = {}
        for header in res.all_headers('set-cookie'):
            c = BaseCookie(header)
            for key, morsel in c.items():
                self.cookies[key] = morsel.value
                res.cookies_set[key] = morsel.value
        if self.post_request_hook:
            self.post_request_hook(self)
        if self.namespace is None:
            # It's annoying to return the response in doctests, as it'll
            # be printed, so we only return it if we couldn't assign
            # it anywhere
            return res

    def _check_status(self, status, res):
        """
        Raise ``AppError`` unless ``res.status`` is acceptable.

        ``status`` may be ``'*'`` (anything goes), a list/tuple of
        allowed codes, ``None`` (any 2xx or 3xx), a prefix wildcard
        such as ``'3*'``, or one exact code.
        """
        __tracebackhide__ = True
        if status == '*':
            return
        if isinstance(status, (list, tuple)):
            if res.status not in status:
                raise AppError(
                    "Bad response: %s (not one of %s for %s)\n%s"
                    % (res.full_status, ', '.join(map(str, status)),
                       res.request.url, res.body))
            return
        if status is None:
            if res.status >= 200 and res.status < 400:
                return
            raise AppError(
                "Bad response: %s (not 200 OK or 3xx redirect for %s)\n%s"
                % (res.full_status, res.request.url,
                   res.body))
        # Prefix wildcards ('3*', '40*') are promised by the ``status``
        # documentation on ``get``; match them against the code's digits.
        if hasattr(status, 'endswith') and status.endswith('*'):
            if str(res.status).startswith(status[:-1]):
                return
            raise AppError(
                "Bad response: %s (not %s)" % (res.full_status, status))
        if status != res.status:
            raise AppError(
                "Bad response: %s (not %s)" % (res.full_status, status))

    def _check_errors(self, res):
        """Raise ``AppError`` if the response carries any logged errors."""
        if res.errors:
            raise AppError(
                "Application had errors logged:\n%s" % res.errors)

    def _make_response(self, raw_response, total_time):
        """
        Build a ``TestResponse`` from ``raw_response``, a
        ``(status, headers, body, errors)`` sequence, and the elapsed
        ``total_time``.
        """
        # Unpacked here rather than in the signature: tuple parameters
        # are not valid syntax on Python 3.
        status, headers, body, errors = raw_response
        return TestResponse(self, status, headers, body, errors,
                            total_time)
class CaptureStdout(object):
    """
    File-like tee: everything written is recorded in an in-memory
    buffer and also passed on to the wrapped ``actual`` stream.
    """

    def __init__(self, actual):
        """Wrap ``actual``, the stream that output is forwarded to."""
        self.actual = actual
        self.captured = StringIO()

    def write(self, s):
        """Record ``s`` and then forward it to the wrapped stream."""
        record = self.captured.write
        forward = self.actual.write
        record(s)
        forward(s)

    def flush(self):
        """Flush the wrapped stream (the capture buffer needs no flush)."""
        self.actual.flush()

    def writelines(self, lines):
        """Write each item of ``lines`` through ``write``."""
        for chunk in lines:
            self.write(chunk)

    def getvalue(self):
        """Return everything captured so far as one string."""
        return self.captured.getvalue()
class TestResponse(object):
# for py.test
disabled = True
"""
Instances of this class are returned by ``TestApp``.
"""
def __init__(self, test_app, status, headers, body, errors,
             total_time):
    """
    Store the pieces of one response.

    ``status`` is the full status line; its first whitespace-separated
    token is kept as the integer ``self.status`` and the whole string
    as ``self.full_status``.  ``headers`` is kept as given and also
    indexed through ``HeaderDict.fromlist``.  ``total_time`` is stored
    as ``self.time``.
    """
    self.test_app = test_app
    status_code = status.split()[0]
    self.status = int(status_code)
    self.full_status = status
    self.headers = headers
    self.header_dict = HeaderDict.fromlist(headers)
    self.body = body
    self.errors = errors
    self.time = total_time
    # Caches that start out empty.
    self._normal_body = None
    self._forms_indexed = None
def forms__get(self):
"""
Returns a dictionary of ``Form`` objects. Indexes are both in
order (from zero) and by form id (if the form is given an id).
"""
if self._forms_indexed is None:
self._parse_forms()
return self._forms_indexed
forms = property(forms__get,
doc="""
A list of unexpected at %s" % match.start())
form_texts.append(self.body[started:match.end()])
started = None
else:
assert not started, (
"Nested form tags at %s" % match.start())
started = match.start()
assert not started, (
"Danging form: %r" % self.body[started:])
for i, text in enumerate(form_texts):
form = Form(self, text)
forms[i] = form
if form.id:
forms[form.id] = form
def header(self, name, default=NoDefault):
"""
Returns the named header; an error if there is not exactly one
matching header (unless you give a default -- always an error
if there is more than one header)
"""
found = None
for cur_name, value in self.headers:
if cur_name.lower() == name.lower():
assert not found, (
"Ambiguous header: %s matches %r and %r"
% (name, found, value))
found = value
if found is None:
if default is NoDefault:
raise KeyError(
"No header found: %r (from %s)"
% (name, ', '.join([n for n, v in self.headers])))
else:
return default
return found
def all_headers(self, name):
"""
Gets all headers by the ``name``, returns as a list
"""
found = []
for cur_name, value in self.headers:
if cur_name.lower() == name.lower():
found.append(value)
return found
def follow(self, **kw):
"""
If this request is a redirect, follow that redirect. It
is an error if this is not a redirect response. Returns
another response object.
"""
assert self.status >= 300 and self.status < 400, (
"You can only follow redirect responses (not %s)"
% self.full_status)
location = self.header('location')
type, rest = urllib.splittype(location)
host, path = urllib.splithost(rest)
# @@: We should test that it's not a remote redirect
return self.test_app.get(location, **kw)
def click(self, description=None, linkid=None, href=None,
          anchor=None, index=None, verbose=False):
    """
    Click the link as described and return the result of going to it.

    Each of ``description``, ``linkid``, ``href`` and ``anchor`` is a
    *pattern*: a string (regular expression), a compiled regular
    expression (an object with a ``search`` method), or a callable
    returning true or false.  All the given patterns are ANDed
    together:

    * ``description`` matches the contents of the anchor (HTML and
      all -- everything between the opening and closing anchor tags)

    * ``linkid`` matches the ``id`` attribute of the anchor.  It will
      receive the empty string if no id is given.

    * ``href`` matches the ``href`` of the anchor; the literal content
      of that attribute, not the fully qualified attribute.

    * ``anchor`` matches the entire anchor, with its contents.

    If more than one link matches, then the ``index`` link is
    followed.  If ``index`` is not given and more than one link
    matches, or if no link matches, then ``IndexError`` will be
    raised.

    If you give ``verbose`` then messages will be printed about
    each link, and why it does or doesn't match.  If you use
    ``app.click(verbose=True)`` you'll see a list of all the
    links.

    You can use multiple criteria to essentially assert multiple
    aspects about the link, e.g., where the link's destination is.

    The search itself is delegated to ``self._find_element`` and the
    navigation to ``self.goto``.
    """
    __tracebackhide__ = True
    search = dict(
        tag='a',
        href_attr='href',
        href_extract=None,
        content=description,
        id=linkid,
        href_pattern=href,
        html_pattern=anchor,
        index=index,
        verbose=verbose)
    # Only the attributes of the matched anchor are needed here.
    link_html, link_desc, link_attrs = self._find_element(**search)
    target = link_attrs['uri']
    return self.goto(target)
def clickbutton(self, description=None, buttonid=None, href=None,
button=None, index=None, verbose=False):
"""
Like ``.click()``, except looks for link-like buttons.
This kind of button should look like
``