generated from pyiron/pyiron_module_template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add TemporaryEnvironment context manager
This class sets and restores environment variables within the context. I'd like to use this in pyiron_base, but thought it'd be nice here as well.
- Loading branch information
Showing
2 changed files
with
109 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
import contextlib | ||
|
||
@contextlib.contextmanager | ||
def TemporaryEnvironment(**kwargs): | ||
""" | ||
Context manager for temporarily setting environment variables. | ||
Takes any number of keyword arguments where the key is the environment | ||
variable to set and the value is the value to set it to. For the duration | ||
of the context, the environment variables are set as per the provided arguments. | ||
The original environment setting is restored once the context is exited, | ||
even if an exception is raised within the context. | ||
Non-string values are coerced with `str()`. | ||
Can also be used as a function decorator. | ||
Examples: | ||
>>> with TemporaryEnvironment(PATH='/tmp', HOME='/', USER='foobar'): | ||
... print(os.getenv('PATH')) # Outputs: /tmp | ||
... print(os.getenv('HOME')) # Outputs: / | ||
... print(os.getenv('USER')) # Outputs: foobar | ||
""" | ||
old_vars = {} | ||
for k, v in kwargs.items(): | ||
try: | ||
old_vars[k] = os.environ[k] | ||
except KeyError: | ||
pass | ||
os.environ[k] = str(v) | ||
try: | ||
yield | ||
finally: | ||
for k, v in kwargs.items(): | ||
if k in old_vars: | ||
os.environ[k] = v | ||
else: | ||
del os.environ[k] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import os | ||
import unittest | ||
from pyiron_base.utils.tempenv import TemporaryEnvironment | ||
|
||
class TestTemporaryEnvironment(unittest.TestCase): | ||
""" | ||
Class to test TemporaryEnvironment context manager. | ||
""" | ||
|
||
def setUp(self): | ||
"""Ensures the original environment is kept intact before each test.""" | ||
self.old_environ = dict(os.environ) | ||
|
||
def tearDown(self): | ||
"""Ensures the original environment is restored after each test.""" | ||
os.environ.clear() | ||
os.environ.update(self.old_environ) | ||
|
||
def test_value_int(self): | ||
"""Should correctly convert and store integer values as strings.""" | ||
with TemporaryEnvironment(FOO=12): | ||
self.assertEqual(os.environ.get('FOO'), '12', "Failed to convert integer to string") | ||
|
||
def test_value_bool(self): | ||
"""Should correctly convert and store boolean values as strings.""" | ||
with TemporaryEnvironment(FOO=True): | ||
self.assertEqual(os.environ.get('FOO'), 'True', "Failed to convert boolean to string") | ||
|
||
def test_environment_set(self): | ||
"""Should correctly set environment variables inside its scope.""" | ||
with TemporaryEnvironment(FOO='1', BAR='2'): | ||
self.assertEqual(os.environ.get('FOO'), '1', "Failed to set FOO") | ||
self.assertEqual(os.environ.get('BAR'), '2', "Failed to set BAR") | ||
|
||
def test_environment_restored(self): | ||
"""Should restore original environment variables state after leaving its scope.""" | ||
os.environ['FOO'] = '0' | ||
with TemporaryEnvironment(FOO='1'): | ||
self.assertEqual(os.environ.get('FOO'), '1') | ||
self.assertEqual(os.environ.get('FOO'), '0', "Failed to restore original FOO value") | ||
|
||
def test_environment_deleted(self): | ||
"""Should correctly delete environment variables that didn't exist originally after leaving its scope.""" | ||
with TemporaryEnvironment(FOO='1'): | ||
self.assertEqual(os.environ.get('FOO'), '1') | ||
self.assertIsNone(os.environ.get('FOO'), "Failed to delete FOO") | ||
|
||
def test_raise_exception(self): | ||
"""Should restore original environment after handling an exception within its scope.""" | ||
os.environ['FOO'] = '0' | ||
try: | ||
with TemporaryEnvironment(FOO='1'): | ||
self.assertEqual(os.environ.get('FOO'), '1') | ||
raise Exception('Some Error') | ||
except: | ||
self.assertEqual(os.environ.get('FOO'), '0', "Failed to restore environment after exception") | ||
|
||
def test_as_decorator(self): | ||
"""Should correctly set and restore environment variables when used as a decorator.""" | ||
@TemporaryEnvironment(FOO='1') | ||
def test_func(): | ||
return os.environ.get('FOO') | ||
|
||
self.assertEqual(test_func(), '1', "Failed to set environment as decorator") | ||
self.assertEqual(os.environ.get('FOO', None), None, "Failed to restore environment as decorator") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |