Skip to content

Commit

Permalink
Add TemporaryEnvironment context manager
Browse files Browse the repository at this point in the history
This class sets and restores environment variables within the context.
I'd like to use this in pyiron_base, but thought it'd be nice here as
well.
  • Loading branch information
pmrv committed Jun 18, 2024
1 parent ce59a49 commit 6d9a089
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
40 changes: 40 additions & 0 deletions pyiron_snippets/tempenv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import contextlib

@contextlib.contextmanager
def TemporaryEnvironment(**kwargs):
"""
Context manager for temporarily setting environment variables.
Takes any number of keyword arguments where the key is the environment
variable to set and the value is the value to set it to. For the duration
of the context, the environment variables are set as per the provided arguments.
The original environment setting is restored once the context is exited,
even if an exception is raised within the context.
Non-string values are coerced with `str()`.
Can also be used as a function decorator.
Examples:
>>> with TemporaryEnvironment(PATH='/tmp', HOME='/', USER='foobar'):
... print(os.getenv('PATH')) # Outputs: /tmp
... print(os.getenv('HOME')) # Outputs: /
... print(os.getenv('USER')) # Outputs: foobar
"""
old_vars = {}
for k, v in kwargs.items():
try:
old_vars[k] = os.environ[k]
except KeyError:
pass
os.environ[k] = str(v)
try:
yield
finally:
for k, v in kwargs.items():
if k in old_vars:
os.environ[k] = v
else:
del os.environ[k]
69 changes: 69 additions & 0 deletions tests/unit/test_tempenv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import unittest
from pyiron_base.utils.tempenv import TemporaryEnvironment

class TestTemporaryEnvironment(unittest.TestCase):
"""
Class to test TemporaryEnvironment context manager.
"""

def setUp(self):
"""Ensures the original environment is kept intact before each test."""
self.old_environ = dict(os.environ)

def tearDown(self):
"""Ensures the original environment is restored after each test."""
os.environ.clear()
os.environ.update(self.old_environ)

def test_value_int(self):
"""Should correctly convert and store integer values as strings."""
with TemporaryEnvironment(FOO=12):
self.assertEqual(os.environ.get('FOO'), '12', "Failed to convert integer to string")

def test_value_bool(self):
"""Should correctly convert and store boolean values as strings."""
with TemporaryEnvironment(FOO=True):
self.assertEqual(os.environ.get('FOO'), 'True', "Failed to convert boolean to string")

def test_environment_set(self):
"""Should correctly set environment variables inside its scope."""
with TemporaryEnvironment(FOO='1', BAR='2'):
self.assertEqual(os.environ.get('FOO'), '1', "Failed to set FOO")
self.assertEqual(os.environ.get('BAR'), '2', "Failed to set BAR")

def test_environment_restored(self):
"""Should restore original environment variables state after leaving its scope."""
os.environ['FOO'] = '0'
with TemporaryEnvironment(FOO='1'):
self.assertEqual(os.environ.get('FOO'), '1')
self.assertEqual(os.environ.get('FOO'), '0', "Failed to restore original FOO value")

def test_environment_deleted(self):
"""Should correctly delete environment variables that didn't exist originally after leaving its scope."""
with TemporaryEnvironment(FOO='1'):
self.assertEqual(os.environ.get('FOO'), '1')
self.assertIsNone(os.environ.get('FOO'), "Failed to delete FOO")

def test_raise_exception(self):
"""Should restore original environment after handling an exception within its scope."""
os.environ['FOO'] = '0'
try:
with TemporaryEnvironment(FOO='1'):
self.assertEqual(os.environ.get('FOO'), '1')
raise Exception('Some Error')
except:
self.assertEqual(os.environ.get('FOO'), '0', "Failed to restore environment after exception")

def test_as_decorator(self):
"""Should correctly set and restore environment variables when used as a decorator."""
@TemporaryEnvironment(FOO='1')
def test_func():
return os.environ.get('FOO')

self.assertEqual(test_func(), '1', "Failed to set environment as decorator")
self.assertEqual(os.environ.get('FOO', None), None, "Failed to restore environment as decorator")


if __name__ == "__main__":
unittest.main()

0 comments on commit 6d9a089

Please sign in to comment.