Skip to content

Commit

Permalink
Patch loops to copy context on task creation.
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed Aug 29, 2018
1 parent 278ad10 commit b1dca79
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 0 deletions.
59 changes: 59 additions & 0 deletions contextvars/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import collections.abc
import threading
import types

import immutables

Expand Down Expand Up @@ -209,3 +210,61 @@ def _get_state():


_state = threading.local()


def create_task(loop, coro):
task = loop._orig_create_task(coro)
if task._source_traceback:
del task._source_traceback[-1]
task.context = copy_context()
return task


def _patch_loop(loop):
if not hasattr(loop, '_orig_create_task'):
loop._orig_create_task = loop.create_task
loop.create_task = types.MethodType(create_task, loop)
return loop


def get_event_loop(policy):
return _patch_loop(policy._orig_methods[0]())


def set_event_loop(policy, loop):
return policy._orig_methods[1](_patch_loop(loop))


def new_event_loop(policy):
return _patch_loop(policy._orig_methods[2]())


def _patch_policy(policy):
if not hasattr(policy, '_orig_methods'):
policy._orig_methods = (
policy.get_event_loop,
policy.set_event_loop,
policy.new_event_loop,
)
policy.get_event_loop = types.MethodType(get_event_loop, policy)
policy.set_event_loop = types.MethodType(set_event_loop, policy)
policy.new_event_loop = types.MethodType(new_event_loop, policy)
return policy


_orig_getter = asyncio.events.get_event_loop_policy
_orig_setter = asyncio.events.set_event_loop_policy


def get_event_loop_policy():
return _patch_policy(_orig_getter())


def set_event_loop_policy(policy):
return _orig_setter(_patch_policy(policy))


asyncio.events.get_event_loop_policy = get_event_loop_policy
asyncio.events.set_event_loop_policy = set_event_loop_policy
asyncio.get_event_loop_policy = get_event_loop_policy
asyncio.set_event_loop_policy = set_event_loop_policy
94 changes: 94 additions & 0 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copied from https://git.io/fAGgA with small updates

import asyncio
import contextvars
import random
import unittest


class TaskTests(unittest.TestCase):
def test_context_1(self):
cvar = contextvars.ContextVar('cvar')

async def sub():
await asyncio.sleep(0.01, loop=loop)
self.assertEqual(cvar.get(), 'nope')
cvar.set('something else')

async def main():
cvar.set('nope')
self.assertEqual(cvar.get(), 'nope')
subtask = loop.create_task(sub())
cvar.set('yes')
self.assertEqual(cvar.get(), 'yes')
await subtask
self.assertEqual(cvar.get(), 'yes')

loop = asyncio.new_event_loop()
try:
loop.run_until_complete(main())
finally:
loop.close()

def test_context_2(self):
cvar = contextvars.ContextVar('cvar', default='nope')

async def main():
def fut_on_done(fut):
# This change must not pollute the context
# of the "main()" task.
cvar.set('something else')

self.assertEqual(cvar.get(), 'nope')

for j in range(2):
fut = loop.create_future()
ctx = contextvars.copy_context()
fut.add_done_callback(lambda f: ctx.run(fut_on_done, f))
cvar.set('yes{}'.format(j))
loop.call_soon(fut.set_result, None)
await fut
self.assertEqual(cvar.get(), 'yes{}'.format(j))

for i in range(3):
# Test that task passed its context to add_done_callback:
cvar.set('yes{}-{}'.format(i, j))
await asyncio.sleep(0.001, loop=loop)
self.assertEqual(cvar.get(), 'yes{}-{}'.format(i, j))

loop = asyncio.new_event_loop()
try:
task = loop.create_task(main())
loop.run_until_complete(task)
finally:
loop.close()

self.assertEqual(cvar.get(), 'nope')

def test_context_3(self):
# Run 100 Tasks in parallel, each modifying cvar.

cvar = contextvars.ContextVar('cvar', default=-1)

async def sub(num):
for i in range(10):
cvar.set(num + i)
await asyncio.sleep(
random.uniform(0.001, 0.05), loop=loop)
self.assertEqual(cvar.get(), num + i)

async def main():
tasks = []
for i in range(100):
task = loop.create_task(sub(random.randint(0, 10)))
tasks.append(task)

await asyncio.gather(*tasks, loop=loop)

loop = asyncio.new_event_loop()
try:
loop.run_until_complete(main())
finally:
loop.close()

self.assertEqual(cvar.get(), -1)

0 comments on commit b1dca79

Please sign in to comment.