|
1 | | -import asyncio |
2 | | -import collections |
3 | | -import gc |
4 | | -import logging |
5 | | -import pytest |
6 | | -import re |
7 | | -import socket |
8 | | -import sys |
9 | | -import warnings |
10 | | - |
11 | | -from aiohttp import web |
12 | | - |
13 | | - |
14 | | -class _AssertWarnsContext: |
15 | | - """A context manager used to implement TestCase.assertWarns* methods.""" |
16 | | - |
17 | | - def __init__(self, expected, expected_regex=None): |
18 | | - self.expected = expected |
19 | | - if expected_regex is not None: |
20 | | - expected_regex = re.compile(expected_regex) |
21 | | - self.expected_regex = expected_regex |
22 | | - self.obj_name = None |
23 | | - |
24 | | - def __enter__(self): |
25 | | - # The __warningregistry__'s need to be in a pristine state for tests |
26 | | - # to work properly. |
27 | | - for v in sys.modules.values(): |
28 | | - if getattr(v, '__warningregistry__', None): |
29 | | - v.__warningregistry__ = {} |
30 | | - self.warnings_manager = warnings.catch_warnings(record=True) |
31 | | - self.warnings = self.warnings_manager.__enter__() |
32 | | - warnings.simplefilter("always", self.expected) |
33 | | - return self |
34 | | - |
35 | | - def __exit__(self, exc_type, exc_value, tb): |
36 | | - self.warnings_manager.__exit__(exc_type, exc_value, tb) |
37 | | - if exc_type is not None: |
38 | | - # let unexpected exceptions pass through |
39 | | - return |
40 | | - try: |
41 | | - exc_name = self.expected.__name__ |
42 | | - except AttributeError: |
43 | | - exc_name = str(self.expected) |
44 | | - first_matching = None |
45 | | - for m in self.warnings: |
46 | | - w = m.message |
47 | | - if not isinstance(w, self.expected): |
48 | | - continue |
49 | | - if first_matching is None: |
50 | | - first_matching = w |
51 | | - if (self.expected_regex is not None and |
52 | | - not self.expected_regex.search(str(w))): |
53 | | - continue |
54 | | - # store warning for later retrieval |
55 | | - self.warning = w |
56 | | - self.filename = m.filename |
57 | | - self.lineno = m.lineno |
58 | | - return |
59 | | - # Now we simply try to choose a helpful failure message |
60 | | - if first_matching is not None: |
61 | | - __tracebackhide__ = True |
62 | | - assert 0, '"{}" does not match "{}"'.format( |
63 | | - self.expected_regex.pattern, str(first_matching)) |
64 | | - if self.obj_name: |
65 | | - __tracebackhide__ = True |
66 | | - assert 0, "{} not triggered by {}".format(exc_name, |
67 | | - self.obj_name) |
68 | | - else: |
69 | | - __tracebackhide__ = True |
70 | | - assert 0, "{} not triggered".format(exc_name) |
71 | | - |
72 | | - |
73 | | -_LoggingWatcher = collections.namedtuple("_LoggingWatcher", |
74 | | - ["records", "output"]) |
75 | | - |
76 | | - |
77 | | -class _CapturingHandler(logging.Handler): |
78 | | - """ |
79 | | - A logging handler capturing all (raw and formatted) logging output. |
80 | | - """ |
81 | | - |
82 | | - def __init__(self): |
83 | | - logging.Handler.__init__(self) |
84 | | - self.watcher = _LoggingWatcher([], []) |
85 | | - |
86 | | - def flush(self): |
87 | | - pass |
88 | | - |
89 | | - def emit(self, record): |
90 | | - self.watcher.records.append(record) |
91 | | - msg = self.format(record) |
92 | | - self.watcher.output.append(msg) |
93 | | - |
94 | | - |
95 | | -class _AssertLogsContext: |
96 | | - """A context manager used to implement TestCase.assertLogs().""" |
97 | | - |
98 | | - LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s" |
99 | | - |
100 | | - def __init__(self, logger_name=None, level=None): |
101 | | - self.logger_name = logger_name |
102 | | - if level: |
103 | | - self.level = logging._nameToLevel.get(level, level) |
104 | | - else: |
105 | | - self.level = logging.INFO |
106 | | - self.msg = None |
107 | | - |
108 | | - def __enter__(self): |
109 | | - if isinstance(self.logger_name, logging.Logger): |
110 | | - logger = self.logger = self.logger_name |
111 | | - else: |
112 | | - logger = self.logger = logging.getLogger(self.logger_name) |
113 | | - formatter = logging.Formatter(self.LOGGING_FORMAT) |
114 | | - handler = _CapturingHandler() |
115 | | - handler.setFormatter(formatter) |
116 | | - self.watcher = handler.watcher |
117 | | - self.old_handlers = logger.handlers[:] |
118 | | - self.old_level = logger.level |
119 | | - self.old_propagate = logger.propagate |
120 | | - logger.handlers = [handler] |
121 | | - logger.setLevel(self.level) |
122 | | - logger.propagate = False |
123 | | - return handler.watcher |
124 | | - |
125 | | - def __exit__(self, exc_type, exc_value, tb): |
126 | | - self.logger.handlers = self.old_handlers |
127 | | - self.logger.propagate = self.old_propagate |
128 | | - self.logger.setLevel(self.old_level) |
129 | | - if exc_type is not None: |
130 | | - # let unexpected exceptions pass through |
131 | | - return False |
132 | | - if len(self.watcher.records) == 0: |
133 | | - __tracebackhide__ = True |
134 | | - assert 0, ("no logs of level {} or higher triggered on {}" |
135 | | - .format(logging.getLevelName(self.level), |
136 | | - self.logger.name)) |
137 | | - |
138 | | - |
139 | | -@pytest.yield_fixture |
140 | | -def warning(): |
141 | | - yield _AssertWarnsContext |
142 | | - |
143 | | - |
144 | | -@pytest.yield_fixture |
145 | | -def log(): |
146 | | - yield _AssertLogsContext |
147 | | - |
148 | | - |
149 | | -@pytest.fixture |
150 | | -def unused_port(): |
151 | | - def f(): |
152 | | - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
153 | | - s.bind(('127.0.0.1', 0)) |
154 | | - return s.getsockname()[1] |
155 | | - return f |
156 | | - |
157 | | - |
158 | | -@pytest.yield_fixture |
159 | | -def loop(request): |
160 | | - loop = asyncio.new_event_loop() |
161 | | - asyncio.set_event_loop(None) |
162 | | - |
163 | | - yield loop |
164 | | - |
165 | | - loop.stop() |
166 | | - loop.run_forever() |
167 | | - loop.close() |
168 | | - gc.collect() |
169 | | - asyncio.set_event_loop(None) |
170 | | - |
171 | | - |
172 | | -@pytest.yield_fixture |
173 | | -def create_server(loop, unused_port): |
174 | | - app = handler = srv = None |
175 | | - |
176 | | - @asyncio.coroutine |
177 | | - def create(*, debug=False, ssl_ctx=None, proto='http', **kwargs): |
178 | | - nonlocal app, handler, srv |
179 | | - app = web.Application(loop=loop, **kwargs) |
180 | | - port = unused_port() |
181 | | - handler = app.make_handler(debug=debug, keep_alive_on=False) |
182 | | - srv = yield from loop.create_server(handler, '127.0.0.1', port, |
183 | | - ssl=ssl_ctx) |
184 | | - if ssl_ctx: |
185 | | - proto += 's' |
186 | | - url = "{}://127.0.0.1:{}".format(proto, port) |
187 | | - return app, url |
188 | | - |
189 | | - yield create |
190 | | - |
191 | | - @asyncio.coroutine |
192 | | - def finish(): |
193 | | - yield from handler.finish_connections() |
194 | | - yield from app.finish() |
195 | | - srv.close() |
196 | | - yield from srv.wait_closed() |
197 | | - |
198 | | - loop.run_until_complete(finish()) |
199 | | - |
200 | | - |
201 | | -@pytest.mark.tryfirst |
202 | | -def pytest_pycollect_makeitem(collector, name, obj): |
203 | | - if collector.funcnamefilter(name): |
204 | | - if not callable(obj): |
205 | | - return |
206 | | - item = pytest.Function(name, parent=collector) |
207 | | - if 'run_loop' in item.keywords: |
208 | | - return list(collector._genfunctions(name, obj)) |
209 | | - |
210 | | - |
211 | | -@pytest.mark.tryfirst |
212 | | -def pytest_pyfunc_call(pyfuncitem): |
213 | | - """ |
214 | | - Run asyncio marked test functions in an event loop instead of a normal |
215 | | - function call. |
216 | | - """ |
217 | | - if 'run_loop' in pyfuncitem.keywords: |
218 | | - funcargs = pyfuncitem.funcargs |
219 | | - loop = funcargs['loop'] |
220 | | - testargs = {arg: funcargs[arg] |
221 | | - for arg in pyfuncitem._fixtureinfo.argnames} |
222 | | - loop.run_until_complete(pyfuncitem.obj(**testargs)) |
223 | | - return True |
224 | | - |
225 | | - |
226 | | -def pytest_runtest_setup(item): |
227 | | - if 'run_loop' in item.keywords and 'loop' not in item.fixturenames: |
228 | | - # inject an event loop fixture for all async tests |
229 | | - item.fixturenames.append('loop') |
0 commit comments