Skip to content

Commit 793b817

Browse files
authored
Merge pull request #354 from dotsdl/repex-context-cleanup
Cleanup contexts upon calculation completion, failure
2 parents 48a3214 + b73dfc6 commit 793b817

File tree

2 files changed

+113
-80
lines changed

2 files changed

+113
-80
lines changed

openfe/protocols/openmm_rfe/_rfe_utils/multistate.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ def __init__(self, *args, hybrid_factory=None, **kwargs):
3434
self._hybrid_factory = hybrid_factory
3535
super(HybridCompatibilityMixin, self).__init__(*args, **kwargs)
3636

37-
def setup(self, reporter, platform, lambda_protocol,
37+
def setup(self, reporter, lambda_protocol,
3838
temperature=298.15 * unit.kelvin, n_replicas=None,
39-
endstates=True, minimization_steps=100):
39+
endstates=True, minimization_steps=100,
40+
minimization_platform="CPU"):
4041
"""
4142
Setup MultistateSampler based on the input lambda protocol and number
4243
of replicas.
@@ -45,8 +46,6 @@ def setup(self, reporter, platform, lambda_protocol,
4546
----------
4647
reporter : OpenMM reporter
4748
Simulation reporter to attach to each simulation replica.
48-
platform : openmm.Platform
49-
Platform to perform simulation on.
5049
lambda_protocol : LambdaProtocol
5150
The lambda protocol to be used for simulation. Default to a default
5251
class creation of LambdaProtocol.
@@ -60,7 +59,9 @@ class creation of LambdaProtocol.
6059
Whether or not to generate unsampled endstates (i.e. dispersion
6160
correction).
6261
minimization_steps : int
63-
Number of steps to minimize states.
62+
Number of steps to pre-minimize states.
63+
minimization_platform : str
64+
Platform to do the initial pre-minimization with.
6465
6566
Attributes
6667
----------
@@ -84,8 +85,6 @@ class creation of LambdaProtocol.
8485
thermodynamic_state_list = []
8586
sampler_state_list = []
8687

87-
context_cache = cache.ContextCache(platform)
88-
8988
if n_replicas is None:
9089
msg = (f"setting number of replicas to number of states: {n_states}")
9190
warnings.warn(msg)
@@ -118,12 +117,14 @@ class creation of LambdaProtocol.
118117

119118
# now generating a sampler_state for each thermodyanmic state,
120119
# with relaxed positions
121-
context, context_integrator = context_cache.get_context(
122-
compound_thermostate_copy)
120+
# Note: remove once choderalab/openmmtools#672 is completed
123121
minimize(compound_thermostate_copy, sampler_state,
124-
max_iterations=minimization_steps)
122+
max_iterations=minimization_steps,
123+
platform_name=minimization_platform)
125124
sampler_state_list.append(copy.deepcopy(sampler_state))
126125

126+
del compound_thermostate, sampler_state
127+
127128
# making sure number of sampler states equals n_replicas
128129
if len(sampler_state_list) != n_replicas:
129130
# picking roughly evenly spaced sampler states
@@ -261,8 +262,10 @@ def create_endstates(first_thermostate, last_thermostate):
261262
return unsampled_endstates
262263

263264

264-
def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: states.SamplerState,
265-
max_iterations: int=100) -> states.SamplerState:
265+
def minimize(thermodynamic_state: states.ThermodynamicState,
266+
sampler_state: states.SamplerState,
267+
max_iterations: int=100,
268+
platform_name: str="CPU") -> states.SamplerState:
266269
"""
267270
Adapted from perses.dispersed.feptasks.minimize
268271
@@ -277,20 +280,28 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat
277280
The starting state at which to minimize the system.
278281
max_iterations : int, optional, default 100
279282
The maximum number of minimization steps. Default is 100.
283+
platform_name : str
284+
The OpenMM platform name to carry out the minimization with.
280285
281286
Returns
282287
-------
283288
sampler_state : openmmtools.states.SamplerState
284289
The posititions and accompanying state following minimization
285290
"""
286-
integrator = openmm.VerletIntegrator(1.0) #we won't take any steps, so use a simple integrator
287-
context, integrator = cache.global_context_cache.get_context(
291+
# we won't take any steps, so use a simple integrator
292+
integrator = openmm.VerletIntegrator(1.0)
293+
platform = openmm.Platform.getPlatformByName(platform_name)
294+
dummy_cache = cache.DummyContextCache(platform=platform)
295+
context, integrator = dummy_cache.get_context(
288296
thermodynamic_state, integrator
289297
)
290-
sampler_state.apply_to_context(
291-
context, ignore_velocities=True
292-
)
293-
openmm.LocalEnergyMinimizer.minimize(
294-
context, maxIterations=max_iterations
295-
)
296-
sampler_state.update_from_context(context)
298+
try:
299+
sampler_state.apply_to_context(
300+
context, ignore_velocities=True
301+
)
302+
openmm.LocalEnergyMinimizer.minimize(
303+
context, maxIterations=max_iterations
304+
)
305+
sampler_state.update_from_context(context)
306+
finally:
307+
del context, integrator, dummy_cache

openfe/protocols/openmm_rfe/equil_rfe_methods.py

Lines changed: 81 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -323,20 +323,20 @@ def run(self, *, dry=False, verbose=True,
323323

324324
# a. check timestep correctness + that
325325
# equilibration & production are divisible by n_steps
326-
prototol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings']
326+
protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings']
327327
stateA = self._inputs['stateA']
328328
stateB = self._inputs['stateB']
329329
mapping = self._inputs['ligandmapping']
330330

331-
forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = prototol_settings.forcefield_settings
332-
thermo_settings: settings.ThermoSettings = prototol_settings.thermo_settings
333-
alchem_settings: AlchemicalSettings = prototol_settings.alchemical_settings
334-
system_settings: SystemSettings = prototol_settings.system_settings
335-
solvation_settings: SolvationSettings = prototol_settings.solvation_settings
336-
sampler_settings: AlchemicalSamplerSettings = prototol_settings.alchemical_sampler_settings
337-
sim_settings: SimulationSettings = prototol_settings.simulation_settings
338-
timestep = prototol_settings.integrator_settings.timestep
339-
mc_steps = prototol_settings.integrator_settings.n_steps.m
331+
forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings
332+
thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings
333+
alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings
334+
system_settings: SystemSettings = protocol_settings.system_settings
335+
solvation_settings: SolvationSettings = protocol_settings.solvation_settings
336+
sampler_settings: AlchemicalSamplerSettings = protocol_settings.alchemical_sampler_settings
337+
sim_settings: SimulationSettings = protocol_settings.simulation_settings
338+
timestep = protocol_settings.integrator_settings.timestep
339+
mc_steps = protocol_settings.integrator_settings.n_steps.m
340340

341341
# is the timestep good for the mass?
342342
if forcefield_settings.hydrogen_mass < 3.0:
@@ -536,9 +536,9 @@ def run(self, *, dry=False, verbose=True,
536536
if 'solvent' in stateA.components:
537537
hybrid_factory.hybrid_system.addForce(
538538
openmm.MonteCarloBarostat(
539-
prototol_settings.thermo_settings.pressure.to(unit.bar).m,
540-
prototol_settings.thermo_settings.temperature.m,
541-
prototol_settings.integrator_settings.barostat_frequency.m,
539+
protocol_settings.thermo_settings.pressure.to(unit.bar).m,
540+
protocol_settings.thermo_settings.temperature.m,
541+
protocol_settings.integrator_settings.barostat_frequency.m,
542542
)
543543
)
544544

@@ -572,24 +572,14 @@ def run(self, *, dry=False, verbose=True,
572572
checkpoint_storage=shared_basepath / sim_settings.checkpoint_storage,
573573
)
574574

575-
# 10. Get platform and context caches
575+
# 10. Get platform
576576
platform = _rfe_utils.compute.get_openmm_platform(
577-
prototol_settings.engine_settings.compute_platform
578-
)
579-
580-
# a. Create context caches (energy + sampler)
581-
# Note: these needs to exist on the compute node
582-
energy_context_cache = openmmtools.cache.ContextCache(
583-
capacity=None, time_to_live=None, platform=platform,
584-
)
585-
586-
sampler_context_cache = openmmtools.cache.ContextCache(
587-
capacity=None, time_to_live=None, platform=platform,
577+
protocol_settings.engine_settings.compute_platform
588578
)
589579

590580
# 11. Set the integrator
591581
# a. get integrator settings
592-
integrator_settings = prototol_settings.integrator_settings
582+
integrator_settings = protocol_settings.integrator_settings
593583

594584
# b. create langevin integrator
595585
integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove(
@@ -635,59 +625,91 @@ def run(self, *, dry=False, verbose=True,
635625
sampler.setup(
636626
n_replicas=sampler_settings.n_replicas,
637627
reporter=reporter,
638-
platform=platform,
639628
lambda_protocol=lambdas,
640-
temperature=to_openmm(prototol_settings.thermo_settings.temperature),
629+
temperature=to_openmm(protocol_settings.thermo_settings.temperature),
641630
endstates=alchem_settings.unsampled_endstates,
631+
minimization_platform=platform.getName(),
642632
)
643633

644-
sampler.energy_context_cache = energy_context_cache
645-
sampler.sampler_context_cache = sampler_context_cache
634+
try:
635+
# Create context caches (energy + sampler)
636+
energy_context_cache = openmmtools.cache.ContextCache(
637+
capacity=None, time_to_live=None, platform=platform,
638+
)
646639

647-
if not dry: # pragma: no-cover
648-
# minimize
649-
if verbose:
650-
logger.info("minimizing systems")
640+
sampler_context_cache = openmmtools.cache.ContextCache(
641+
capacity=None, time_to_live=None, platform=platform,
642+
)
651643

652-
sampler.minimize(max_iterations=sim_settings.minimization_steps)
644+
sampler.energy_context_cache = energy_context_cache
645+
sampler.sampler_context_cache = sampler_context_cache
653646

654-
# equilibrate
655-
if verbose:
656-
logger.info("equilibrating systems")
647+
if not dry: # pragma: no-cover
648+
# minimize
649+
if verbose:
650+
logger.info("minimizing systems")
657651

658-
sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore
652+
sampler.minimize(max_iterations=sim_settings.minimization_steps)
659653

660-
# production
661-
if verbose:
662-
logger.info("running production phase")
654+
# equilibrate
655+
if verbose:
656+
logger.info("equilibrating systems")
663657

664-
sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore
658+
sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore
665659

666-
# calculate estimate of results from this individual unit
667-
ana = multistate.MultiStateSamplerAnalyzer(reporter)
668-
est, _ = ana.get_free_energy()
669-
est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole)
670-
est = ensure_quantity(est, 'openff')
660+
# production
661+
if verbose:
662+
logger.info("running production phase")
671663

672-
# close reporter when you're done
664+
sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore
665+
666+
# calculate estimate of results from this individual unit
667+
ana = multistate.MultiStateSamplerAnalyzer(reporter)
668+
est, _ = ana.get_free_energy()
669+
est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole)
670+
est = ensure_quantity(est, 'openff')
671+
672+
nc = shared_basepath / sim_settings.output_filename
673+
chk = shared_basepath / sim_settings.checkpoint_storage
674+
else:
675+
# clean up the reporter file
676+
fns = [shared_basepath / sim_settings.output_filename,
677+
shared_basepath / sim_settings.checkpoint_storage]
678+
for fn in fns:
679+
os.remove(fn)
680+
finally:
681+
# close reporter when you're done, prevent file handle clashes
673682
reporter.close()
683+
del reporter
684+
685+
# clean up the analyzer
686+
if not dry:
687+
ana.clear()
688+
del ana
689+
690+
# clear GPU contexts
691+
# TODO: use cache.empty() calls when openmmtools #690 is resolved
692+
# replace with above
693+
for context in list(energy_context_cache._lru._data.keys()):
694+
del energy_context_cache._lru._data[context]
695+
for context in list(sampler_context_cache._lru._data.keys()):
696+
del sampler_context_cache._lru._data[context]
697+
# cautiously clear out the global context cache too
698+
for context in list(
699+
openmmtools.cache.global_context_cache._lru._data.keys()):
700+
del openmmtools.cache.global_context_cache._lru._data[context]
701+
702+
del sampler_context_cache, energy_context_cache
703+
if not dry:
704+
del integrator, sampler
674705

675-
nc = shared_basepath / sim_settings.output_filename
676-
chk = shared_basepath / sim_settings.checkpoint_storage
706+
if not dry: # pragma: no-cover
677707
return {
678708
'nc': nc,
679709
'last_checkpoint': chk,
680710
'unit_estimate': est,
681711
}
682712
else:
683-
# close reporter when you're done, prevent file handle clashes
684-
reporter.close()
685-
686-
# clean up the reporter file
687-
fns = [shared_basepath / sim_settings.output_filename,
688-
shared_basepath / sim_settings.checkpoint_storage]
689-
for fn in fns:
690-
os.remove(fn)
691713
return {'debug': {'sampler': sampler}}
692714

693715
def _execute(

0 commit comments

Comments
 (0)