@@ -323,20 +323,20 @@ def run(self, *, dry=False, verbose=True,
323323
324324 # a. check timestep correctness + that
325325 # equilibration & production are divisible by n_steps
326- prototol_settings : RelativeHybridTopologyProtocolSettings = self ._inputs ['settings' ]
326+ protocol_settings : RelativeHybridTopologyProtocolSettings = self ._inputs ['settings' ]
327327 stateA = self ._inputs ['stateA' ]
328328 stateB = self ._inputs ['stateB' ]
329329 mapping = self ._inputs ['ligandmapping' ]
330330
331- forcefield_settings : settings .OpenMMSystemGeneratorFFSettings = prototol_settings .forcefield_settings
332- thermo_settings : settings .ThermoSettings = prototol_settings .thermo_settings
333- alchem_settings : AlchemicalSettings = prototol_settings .alchemical_settings
334- system_settings : SystemSettings = prototol_settings .system_settings
335- solvation_settings : SolvationSettings = prototol_settings .solvation_settings
336- sampler_settings : AlchemicalSamplerSettings = prototol_settings .alchemical_sampler_settings
337- sim_settings : SimulationSettings = prototol_settings .simulation_settings
338- timestep = prototol_settings .integrator_settings .timestep
339- mc_steps = prototol_settings .integrator_settings .n_steps .m
331+ forcefield_settings : settings .OpenMMSystemGeneratorFFSettings = protocol_settings .forcefield_settings
332+ thermo_settings : settings .ThermoSettings = protocol_settings .thermo_settings
333+ alchem_settings : AlchemicalSettings = protocol_settings .alchemical_settings
334+ system_settings : SystemSettings = protocol_settings .system_settings
335+ solvation_settings : SolvationSettings = protocol_settings .solvation_settings
336+ sampler_settings : AlchemicalSamplerSettings = protocol_settings .alchemical_sampler_settings
337+ sim_settings : SimulationSettings = protocol_settings .simulation_settings
338+ timestep = protocol_settings .integrator_settings .timestep
339+ mc_steps = protocol_settings .integrator_settings .n_steps .m
340340
341341 # is the timestep good for the mass?
342342 if forcefield_settings .hydrogen_mass < 3.0 :
@@ -536,9 +536,9 @@ def run(self, *, dry=False, verbose=True,
536536 if 'solvent' in stateA .components :
537537 hybrid_factory .hybrid_system .addForce (
538538 openmm .MonteCarloBarostat (
539- prototol_settings .thermo_settings .pressure .to (unit .bar ).m ,
540- prototol_settings .thermo_settings .temperature .m ,
541- prototol_settings .integrator_settings .barostat_frequency .m ,
539+ protocol_settings .thermo_settings .pressure .to (unit .bar ).m ,
540+ protocol_settings .thermo_settings .temperature .m ,
541+ protocol_settings .integrator_settings .barostat_frequency .m ,
542542 )
543543 )
544544
@@ -572,24 +572,14 @@ def run(self, *, dry=False, verbose=True,
572572 checkpoint_storage = shared_basepath / sim_settings .checkpoint_storage ,
573573 )
574574
575- # 10. Get platform and context caches
575+ # 10. Get platform
576576 platform = _rfe_utils .compute .get_openmm_platform (
577- prototol_settings .engine_settings .compute_platform
578- )
579-
580- # a. Create context caches (energy + sampler)
581- # Note: these needs to exist on the compute node
582- energy_context_cache = openmmtools .cache .ContextCache (
583- capacity = None , time_to_live = None , platform = platform ,
584- )
585-
586- sampler_context_cache = openmmtools .cache .ContextCache (
587- capacity = None , time_to_live = None , platform = platform ,
577+ protocol_settings .engine_settings .compute_platform
588578 )
589579
590580 # 11. Set the integrator
591581 # a. get integrator settings
592- integrator_settings = prototol_settings .integrator_settings
582+ integrator_settings = protocol_settings .integrator_settings
593583
594584 # b. create langevin integrator
595585 integrator = openmmtools .mcmc .LangevinSplittingDynamicsMove (
@@ -635,59 +625,91 @@ def run(self, *, dry=False, verbose=True,
635625 sampler .setup (
636626 n_replicas = sampler_settings .n_replicas ,
637627 reporter = reporter ,
638- platform = platform ,
639628 lambda_protocol = lambdas ,
640- temperature = to_openmm (prototol_settings .thermo_settings .temperature ),
629+ temperature = to_openmm (protocol_settings .thermo_settings .temperature ),
641630 endstates = alchem_settings .unsampled_endstates ,
631+ minimization_platform = platform .getName (),
642632 )
643633
644- sampler .energy_context_cache = energy_context_cache
645- sampler .sampler_context_cache = sampler_context_cache
634+ try :
635+ # Create context caches (energy + sampler)
636+ energy_context_cache = openmmtools .cache .ContextCache (
637+ capacity = None , time_to_live = None , platform = platform ,
638+ )
646639
647- if not dry : # pragma: no-cover
648- # minimize
649- if verbose :
650- logger .info ("minimizing systems" )
640+ sampler_context_cache = openmmtools .cache .ContextCache (
641+ capacity = None , time_to_live = None , platform = platform ,
642+ )
651643
652- sampler .minimize (max_iterations = sim_settings .minimization_steps )
644+ sampler .energy_context_cache = energy_context_cache
645+ sampler .sampler_context_cache = sampler_context_cache
653646
654- # equilibrate
655- if verbose :
656- logger .info ("equilibrating systems" )
647+ if not dry : # pragma: no-cover
648+ # minimize
649+ if verbose :
650+ logger .info ("minimizing systems" )
657651
658- sampler .equilibrate ( int ( equil_steps . m / mc_steps )) # type: ignore
652+ sampler .minimize ( max_iterations = sim_settings . minimization_steps )
659653
660- # production
661- if verbose :
662- logger .info ("running production phase " )
654+ # equilibrate
655+ if verbose :
656+ logger .info ("equilibrating systems " )
663657
664- sampler .extend (int (prod_steps .m / mc_steps )) # type: ignore
658+ sampler .equilibrate (int (equil_steps .m / mc_steps )) # type: ignore
665659
666- # calculate estimate of results from this individual unit
667- ana = multistate .MultiStateSamplerAnalyzer (reporter )
668- est , _ = ana .get_free_energy ()
669- est = (est [0 , - 1 ] * ana .kT ).in_units_of (omm_unit .kilocalories_per_mole )
670- est = ensure_quantity (est , 'openff' )
660+ # production
661+ if verbose :
662+ logger .info ("running production phase" )
671663
672- # close reporter when you're done
664+ sampler .extend (int (prod_steps .m / mc_steps )) # type: ignore
665+
666+ # calculate estimate of results from this individual unit
667+ ana = multistate .MultiStateSamplerAnalyzer (reporter )
668+ est , _ = ana .get_free_energy ()
669+ est = (est [0 , - 1 ] * ana .kT ).in_units_of (omm_unit .kilocalories_per_mole )
670+ est = ensure_quantity (est , 'openff' )
671+
672+ nc = shared_basepath / sim_settings .output_filename
673+ chk = shared_basepath / sim_settings .checkpoint_storage
674+ else :
675+ # clean up the reporter file
676+ fns = [shared_basepath / sim_settings .output_filename ,
677+ shared_basepath / sim_settings .checkpoint_storage ]
678+ for fn in fns :
679+ os .remove (fn )
680+ finally :
681+ # close reporter when you're done, prevent file handle clashes
673682 reporter .close ()
683+ del reporter
684+
685+ # clean up the analyzer
686+ if not dry :
687+ ana .clear ()
688+ del ana
689+
690+ # clear GPU contexts
691+ # TODO: use cache.empty() calls when openmmtools #690 is resolved
692+ # replace with above
693+ for context in list (energy_context_cache ._lru ._data .keys ()):
694+ del energy_context_cache ._lru ._data [context ]
695+ for context in list (sampler_context_cache ._lru ._data .keys ()):
696+ del sampler_context_cache ._lru ._data [context ]
697+ # cautiously clear out the global context cache too
698+ for context in list (
699+ openmmtools .cache .global_context_cache ._lru ._data .keys ()):
700+ del openmmtools .cache .global_context_cache ._lru ._data [context ]
701+
702+ del sampler_context_cache , energy_context_cache
703+ if not dry :
704+ del integrator , sampler
674705
675- nc = shared_basepath / sim_settings .output_filename
676- chk = shared_basepath / sim_settings .checkpoint_storage
706+ if not dry : # pragma: no-cover
677707 return {
678708 'nc' : nc ,
679709 'last_checkpoint' : chk ,
680710 'unit_estimate' : est ,
681711 }
682712 else :
683- # close reporter when you're done, prevent file handle clashes
684- reporter .close ()
685-
686- # clean up the reporter file
687- fns = [shared_basepath / sim_settings .output_filename ,
688- shared_basepath / sim_settings .checkpoint_storage ]
689- for fn in fns :
690- os .remove (fn )
691713 return {'debug' : {'sampler' : sampler }}
692714
693715 def _execute (
0 commit comments