From b1dca79872382a53ad3afe048eaf4646b5057a85 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Wed, 29 Aug 2018 18:41:55 +0800 Subject: [PATCH] Patch loops to copy context on task creation. --- contextvars/__init__.py | 59 ++++++++++++++++++++++++++ tests/test_tasks.py | 94 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 tests/test_tasks.py diff --git a/contextvars/__init__.py b/contextvars/__init__.py index 1107627..b6d84fe 100644 --- a/contextvars/__init__.py +++ b/contextvars/__init__.py @@ -1,6 +1,7 @@ import asyncio import collections.abc import threading +import types import immutables @@ -209,3 +210,61 @@ def _get_state(): _state = threading.local() + + +def create_task(loop, coro): + task = loop._orig_create_task(coro) + if task._source_traceback: + del task._source_traceback[-1] + task.context = copy_context() + return task + + +def _patch_loop(loop): + if not hasattr(loop, '_orig_create_task'): + loop._orig_create_task = loop.create_task + loop.create_task = types.MethodType(create_task, loop) + return loop + + +def get_event_loop(policy): + return _patch_loop(policy._orig_methods[0]()) + + +def set_event_loop(policy, loop): + return policy._orig_methods[1](_patch_loop(loop)) + + +def new_event_loop(policy): + return _patch_loop(policy._orig_methods[2]()) + + +def _patch_policy(policy): + if not hasattr(policy, '_orig_methods'): + policy._orig_methods = ( + policy.get_event_loop, + policy.set_event_loop, + policy.new_event_loop, + ) + policy.get_event_loop = types.MethodType(get_event_loop, policy) + policy.set_event_loop = types.MethodType(set_event_loop, policy) + policy.new_event_loop = types.MethodType(new_event_loop, policy) + return policy + + +_orig_getter = asyncio.events.get_event_loop_policy +_orig_setter = asyncio.events.set_event_loop_policy + + +def get_event_loop_policy(): + return _patch_policy(_orig_getter()) + + +def set_event_loop_policy(policy): + return _orig_setter(_patch_policy(policy)) + + +asyncio.events.get_event_loop_policy = get_event_loop_policy +asyncio.events.set_event_loop_policy = set_event_loop_policy +asyncio.get_event_loop_policy = get_event_loop_policy +asyncio.set_event_loop_policy = set_event_loop_policy diff --git a/tests/test_tasks.py b/tests/test_tasks.py new file mode 100644 index 0000000..6eab3e5 --- /dev/null +++ b/tests/test_tasks.py @@ -0,0 +1,94 @@ +# Copied from https://git.io/fAGgA with small updates + +import asyncio +import contextvars +import random +import unittest + + +class TaskTests(unittest.TestCase): + def test_context_1(self): + cvar = contextvars.ContextVar('cvar') + + async def sub(): + await asyncio.sleep(0.01, loop=loop) + self.assertEqual(cvar.get(), 'nope') + cvar.set('something else') + + async def main(): + cvar.set('nope') + self.assertEqual(cvar.get(), 'nope') + subtask = loop.create_task(sub()) + cvar.set('yes') + self.assertEqual(cvar.get(), 'yes') + await subtask + self.assertEqual(cvar.get(), 'yes') + + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(main()) + finally: + loop.close() + + def test_context_2(self): + cvar = contextvars.ContextVar('cvar', default='nope') + + async def main(): + def fut_on_done(fut): + # This change must not pollute the context + # of the "main()" task. + cvar.set('something else') + + self.assertEqual(cvar.get(), 'nope') + + for j in range(2): + fut = loop.create_future() + ctx = contextvars.copy_context() + fut.add_done_callback(lambda f: ctx.run(fut_on_done, f)) + cvar.set('yes{}'.format(j)) + loop.call_soon(fut.set_result, None) + await fut + self.assertEqual(cvar.get(), 'yes{}'.format(j)) + + for i in range(3): + # Test that task passed its context to add_done_callback: + cvar.set('yes{}-{}'.format(i, j)) + await asyncio.sleep(0.001, loop=loop) + self.assertEqual(cvar.get(), 'yes{}-{}'.format(i, j)) + + loop = asyncio.new_event_loop() + try: + task = loop.create_task(main()) + loop.run_until_complete(task) + finally: + loop.close() + + self.assertEqual(cvar.get(), 'nope') + + def test_context_3(self): + # Run 100 Tasks in parallel, each modifying cvar. + + cvar = contextvars.ContextVar('cvar', default=-1) + + async def sub(num): + for i in range(10): + cvar.set(num + i) + await asyncio.sleep( + random.uniform(0.001, 0.05), loop=loop) + self.assertEqual(cvar.get(), num + i) + + async def main(): + tasks = [] + for i in range(100): + task = loop.create_task(sub(random.randint(0, 10))) + tasks.append(task) + + await asyncio.gather(*tasks, loop=loop) + + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(main()) + finally: + loop.close() + + self.assertEqual(cvar.get(), -1)