Skip to content

Commit ab794fe

Browse files
authored
Add "reinstall()" method to simplify logger sharing with "spawn" multiprocessing (#1069)
1 parent f64b9e3 commit ab794fe

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

loguru/_logger.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,6 +1812,40 @@ def configure(self, *, handlers=None, levels=None, extra=None, patcher=None, act
18121812

18131813
return [self.add(**params) for params in handlers]
18141814

1815+
def reinstall(self):
    """Adopt this logger's core as the process-global logger core.

    With "spawn" multiprocessing, the child process re-imports ``loguru`` and
    gets a fresh, unconfigured global ``logger``. Pass the configured logger
    to the child's target function, call this method once there, and every
    subsequent ``from loguru import logger`` in that process will share the
    parent's handlers — no need to thread the logger through each call.

    Examples
    --------
    >>> def subworker(logger):
    ...     logger.reinstall()
    ...     logger.info("Child")
    ...     deeper_subworker()

    >>> def deeper_subworker():
    ...     logger.info("Grandchild")

    >>> def test_process_spawn():
    ...     spawn_context = multiprocessing.get_context("spawn")
    ...     logger.add("file.log", context=spawn_context, enqueue=True, catch=False)
    ...
    ...     process = spawn_context.Process(target=subworker, args=(logger,))
    ...     process.start()
    ...     process.join()
    ...
    ...     assert process.exitcode == 0
    ...
    ...     logger.info("Main")
    ...     logger.remove()
    """
    # Imported lazily to avoid a circular import at module load time.
    import loguru

    # Point the freshly-imported global logger at this instance's core so
    # both objects share handlers, levels and extra state.
    loguru.logger._core = self._core
1848+
18151849
def _change_activation(self, name, status):
18161850
if not (name is None or isinstance(name, str)):
18171851
raise TypeError(

tests/test_reinstall.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import multiprocessing
2+
import os
3+
4+
import pytest
5+
6+
from loguru import logger
7+
8+
9+
@pytest.fixture
def fork_context():
    """Multiprocessing context using the "fork" start method (POSIX only)."""
    # No teardown needed, so a plain return is equivalent to a bare yield.
    return multiprocessing.get_context("fork")
12+
13+
14+
@pytest.fixture
def spawn_context():
    """Multiprocessing context using the "spawn" start method (all platforms)."""
    # No teardown needed, so a plain return is equivalent to a bare yield.
    return multiprocessing.get_context("spawn")
17+
18+
19+
class Writer:
    """In-memory sink: collects every message written to it for later inspection."""

    def __init__(self):
        # Accumulate chunks in a list; joined on demand in read().
        self._chunks = []

    def write(self, message):
        self._chunks.append(message)

    def read(self):
        return "".join(self._chunks)
28+
29+
30+
def subworker(logger):
    # Child-process entry point: adopt the parent's logger core first, so
    # the module-global logger used by deeper_subworker() shares it too.
    logger.reinstall()
    logger.info("Child")
    deeper_subworker()
34+
35+
36+
def deeper_subworker():
    # Uses the module-global logger (not a passed-in one) to prove that
    # reinstall() made the parent's handlers visible process-wide.
    logger.info("Grandchild")
38+
39+
40+
@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking")
def test_process_fork(fork_context):
    """reinstall() must propagate the parent's core to a forked child."""
    sink = Writer()

    logger.add(sink, context=fork_context, format="{message}", enqueue=True, catch=False)

    child = fork_context.Process(target=subworker, args=(logger,))
    child.start()
    child.join()

    assert child.exitcode == 0

    logger.info("Main")
    logger.remove()

    assert sink.read() == "Child\nGrandchild\nMain\n"
56+
57+
58+
def test_process_spawn(spawn_context):
    """reinstall() must propagate the parent's core to a spawned child."""
    sink = Writer()

    logger.add(sink, context=spawn_context, format="{message}", enqueue=True, catch=False)

    child = spawn_context.Process(target=subworker, args=(logger,))
    child.start()
    child.join()

    assert child.exitcode == 0

    logger.info("Main")
    logger.remove()

    assert sink.read() == "Child\nGrandchild\nMain\n"

0 commit comments

Comments
 (0)