[WIP] Free energy fitting #54
base: main
Conversation
… to 79 character lines. Oops -- my IDE had been on 120-character lines this whole time; switching to 79-character lines so future black passes don't turn things into quadruply nested messes.
…with a graph-net in the loop. It's fishy that the initial hydration free energy prediction is so poor; I suspect I may have made a unit mistake in my numpy/jax --> pytorch port.
…another function that converts to espaloma unit system
…ssertions Thanks to @yuanqing-wang for carefully stepping through this with me Co-Authored-By: Yuanqing Wang <[email protected]>
…et using parsley 1.2
the line to compute the reduced work was written as if `solv_energies` were `valence_energies + gbsa_energies`, but of course it was just `gbsa_energies`...
in column `quick_xyz` -- will shortly replace this with a column `xyz` with more thorough parsley 1.2 vacuum sampling
… negative delta g predictions
f = torch.sqrt(r ** 2 + torch.ger(B, B) * torch.exp(
    -r ** 2 / (4 * torch.ger(B, B))))
charge_products = torch.ger(charges, charges)
assert (f.shape == (N, N))
assert (charge_products.shape == (N, N))

ixns = - (
Isn't this missing a -138.935485 conversion from nm/(proton_charge**2) to kJ/mol? The docstring says "everything is in OpenMM native units".
It looks like you might pre-multiply the charges by sqrt(138.935485)? If so, you should probably document that in the docstring.
Gahh -- you're right -- I had dropped this in the current conversion! Thank you for catching this. Charges are not assumed to be premultiplied by sqrt(138.935485), will clarify docstring...
(This conversion was present but poorly labeled in the numpy/jax implementation in bayes-implicit-solvent.)
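For concreteness, here is a minimal sketch (not this PR's implementation) of where the explicit conversion enters, assuming distances r and Born radii B are in nm, charges are in units of the proton charge (not premultiplied), and a Still-style pairwise function; the dielectric constants below are illustrative defaults, not values taken from this PR:

```python
import torch

ONE_4PI_EPS0 = 138.935485  # kJ mol^-1 nm e^-2; value quoted in this thread

def gb_polarization_energy(r, B, charges,
                           solute_dielectric=1.0, solvent_dielectric=78.5):
    """Still-style generalized-Born polarization energy, returned in kJ/mol.

    r: (N, N) pairwise distances in nm; B: (N,) Born radii in nm;
    charges: (N,) partial charges in units of the proton charge.
    """
    f = torch.sqrt(r ** 2 + torch.ger(B, B) * torch.exp(
        -r ** 2 / (4 * torch.ger(B, B))))
    charge_products = torch.ger(charges, charges)
    # 0.5 because the full N x N sum counts each i != j pair twice;
    # the i == j self-terms are intentionally included
    prefactor = -0.5 * ONE_4PI_EPS0 * (1.0 / solute_dielectric - 1.0 / solvent_dielectric)
    return prefactor * (charge_products / f).sum()
```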
thanks to @jchodera for spotting this! #54 (comment) Co-Authored-By: John Chodera <[email protected]>
* confirmed that it can overfit to a small subset of FreeSolv! 🎉
* RMSE on the whole of FreeSolv hasn't yet matched the quality of OBC2 🙏
…ze: validation set RMSE 1.8 kcal/mol
* increased stepsize to 1e-3 rather than 1e-4
* decreased layer and node dimensions from 128 to 32
Incorporating the missing unit conversion John identified, this now appears to be passing integration checks in the demo notebook. A graph-net is used to emit the GBSA parameters.

Can this procedure overfit a GBSA-parameter-emitting graph-net to a small random subsample of FreeSolv (N=10)? Yes:

Can this procedure fit a graph-net to a random half of FreeSolv (N=321) and generalize to the other half (N=321)? Tentatively yes:

A few more important refinements and unit tests are needed before this is ready for final review and merge, but this is now significantly less fishy than it was yesterday.
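For readers following along, the overfitting check is roughly the following loop (a sketch only; `graph_net`, `predict_dg`, and the molecule/experiment containers are illustrative placeholders, not the notebook's actual API):

```python
import torch

def overfit_check(graph_net, predict_dg, molecules, expt_dg,
                  n_epochs=1000, learning_rate=1e-3):
    """Sanity check: drive training MSE toward zero on a tiny FreeSolv subsample.

    graph_net:  torch.nn.Module that emits per-atom GBSA parameters
    predict_dg: callable(graph_net, molecules) -> predicted hydration free energies
    expt_dg:    tensor of experimental hydration free energies for `molecules`
    """
    optimizer = torch.optim.Adam(graph_net.parameters(), lr=learning_rate)
    for _ in range(n_epochs):
        optimizer.zero_grad()
        loss = torch.mean((predict_dg(graph_net, molecules) - expt_dg) ** 2)
        loss.backward()
        optimizer.step()
    return loss.item()  # should approach ~0 if the model can overfit N=10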
+ a few formatting and documentation enhancements
* define random seed once at top of file rather than before each step
* remove verbose flag
* use same learning rate, n_iterations, n_mols_per_batch, n_snapshots_per_mol for both trajectories
element | # of molecules containing it
--- | ---
C | 639
H | 629
O | 344
N | 169
Cl | 114
S | 40
F | 35
Br | 25
P | 14
I | 12
* add training / validation curves for early-stopping
* add bootstrapped RMSEs to final scatterplots
To address the concern about elements that appear only a handful of times in FreeSolv, see this notebook, which counts the number of molecules in FreeSolv containing each element.
A related question is: if we filter the molecules to retain only certain subsets of elements, how many molecules do we retain? Enumerating one sequence of element subsets (including elements in descending order of "popularity"):
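A per-subset retention count along these lines could be computed with RDKit (a sketch; `freesolv_smiles` is a hypothetical list of FreeSolv SMILES strings, not a variable from the notebook):

```python
from rdkit import Chem

def count_retained(freesolv_smiles, allowed_elements=frozenset({"C", "H", "O"})):
    """Count molecules whose atoms (including hydrogens) all come from allowed_elements."""
    n_retained = 0
    for smiles in freesolv_smiles:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        mol = Chem.AddHs(mol)
        if {atom.GetSymbol() for atom in mol.GetAtoms()} <= allowed_elements:
            n_retained += 1
    return n_retained
```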
A not-so-challenging subset of FreeSolv -- that should be free of the infrequently-occurring-element concern -- is the collection of molecules containing only {C, H, O}. This demo notebook fits a GB-parameter-emitting graph-net on this set in about 40 CPU minutes. Training and validation RMSE are reported every epoch for this "mini-Freesolv" subset:

The same plot, zoomed in on the y range 0.5-2.5 kcal/mol:

In this run, the lowest validation-set RMSE happened to be encountered at the very last epoch, but that wouldn't be expected in general due to noise in gradient estimates (and especially if run longer). Plotting predicted vs. reference scatter plots for training and validation sets at that last epoch (labeled with RMSE +/- 95% bootstrapped CI):

Similar plots could easily be generated for every other "mini-Freesolv" enumerated above. If there's an apparent difference between the more restricted vs. the more complete "mini-Freesolv"s, that might be suggestive of difficulty arising from sparsely sampled elements / chemical environments.
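The bootstrapped RMSE intervals mentioned above can be computed along these lines (a sketch; `predicted` and `reference` are placeholder arrays of predicted and experimental hydration free energies in kcal/mol):

```python
import numpy as np

def bootstrap_rmse_ci(predicted, reference, n_bootstrap=1000, seed=0):
    """95% bootstrap confidence interval on RMSE, resampling molecules with replacement."""
    rng = np.random.default_rng(seed)
    n = len(reference)
    rmses = []
    for _ in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)  # resample molecule indices with replacement
        rmses.append(np.sqrt(np.mean((predicted[idx] - reference[idx]) ** 2)))
    return np.percentile(rmses, [2.5, 97.5])
```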
run 10 ns of MD per molecule (rather than the measly 0.01 ns per molecule in 5866029)
Looks great! How about we run with this for the next bioRxiv update (and thesis) and revisit compound splitting on a larger set (maybe including N and Cl) in the next iteration (after thesis submission)?
validation-set rmse ~1.2-1.5 kcal/mol
Thanks!
Sounds good -- compound splitting is subtle and not the primary focus of this demonstration. Because it was convenient (change one line, wait 30 minutes), I re-ran the notebook on the {C, H, O, N, Cl} FreeSolv subset (n=529) to get a preview.

Noting one observation for when we return to this: in this run, the validation loss increased for ~10 epochs before decreasing again. Early-stopping requires the user to pre-specify a "patience" parameter (how many iterations without improvement to tolerate before stopping), and this example suggests it might be better to choose a "patience" >= 10 epochs. Will sync with @yuanqing-wang about how this patience parameter is currently selected.
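For reference, "patience" enters early stopping roughly like this (an illustrative sketch, not the current training code):

```python
def early_stopping_epoch(validation_rmses, patience=10):
    """Return the epoch at which training would stop, given per-epoch validation RMSEs."""
    best_rmse, best_epoch = float("inf"), 0
    for epoch, rmse in enumerate(validation_rmses):
        if rmse < best_rmse:
            best_rmse, best_epoch = rmse, epoch
        elif epoch - best_epoch >= patience:
            return epoch  # stop: no improvement for `patience` consecutive epochs
    return len(validation_rmses) - 1  # never triggered; trained to the last epoch
```

With patience < 10, the ~10-epoch plateau seen in this run would have triggered a premature stop, which is the concern noted above.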
using more thorough vacuum md, specified here 8e50eec
anecdotally, this appears to increase the training-set vs. validation-set error gap, suggesting that insufficient equilibrium sampling might have made the validation-set performance reported in #54 (comment) look more favorable than it should!
To home in on the version of these results that will be reported in the bioRxiv update (and thesis), I re-ran the notebook from #54 (comment) on the updated equilibrium snapshots cached from more thorough vacuum MD. These updated results should supersede the earlier results.

Anecdotally (based on one run with snapshots from short MD vs. one run with snapshots from thorough MD), this update appears to have increased the training-set vs. validation-set error gap, suggesting that insufficient equilibrium sampling might have made the validation-set performance reported in #54 (comment) look more favorable than it should.
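Roughly, the per-molecule vacuum sampling looks like the following (an OpenMM sketch with illustrative settings; the actual protocol is specified in commit 8e50eec, and `topology`, `system`, and `positions` are assumed to come from parameterizing each FreeSolv molecule with parsley 1.2 elsewhere):

```python
from openmm import LangevinMiddleIntegrator, unit
from openmm.app import Simulation

def sample_vacuum_snapshots(topology, system, positions,
                            n_snapshots=1000, steps_per_snapshot=5000):
    """Collect equilibrium vacuum snapshots for one molecule (10 ns total at a 2 fs timestep)."""
    integrator = LangevinMiddleIntegrator(300 * unit.kelvin, 1.0 / unit.picosecond,
                                          2.0 * unit.femtosecond)
    simulation = Simulation(topology, system, integrator)
    simulation.context.setPositions(positions)
    simulation.minimizeEnergy()
    snapshots = []
    for _ in range(n_snapshots):
        simulation.step(steps_per_snapshot)  # 10 ps of dynamics between saved frames
        state = simulation.context.getState(getPositions=True)
        snapshots.append(state.getPositions(asNumpy=True))
    return snapshots
```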
Interesting finding! But I agree that behavior is much closer to what I would have expected from training vs. validation error.
Is this train/validate/test with early stopping, or is it just train/validate with 10% of the dataset split out and no early stopping (with cross-validation over the 10% held-out sets intended to be representative of the test set error)? Are we concerned at all with the experimental strategies being vastly different between the different experiments in the paper for no particular reason?
No early stopping.
Correct. I'm not aiming to do any hyperparameter selection informed by this experiment, just aiming to report on the repeatability / variability of the training procedure if the dataset were slightly different, and to report an estimate of the generalization error of this specific procedure on the chemistry represented by this specific dataset.

In the previous plot, I showed just a single 50/50 split. Would that plot look different if the random seed were different? The way to measure that is to repeat multiple times with different random splits and report all results. The ideal would be to approach leave-one-out (run the procedure 300 times, once on each of the n=299-size subsets of the data). K-fold is a common compromise.
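A sketch of the K-fold compromise (scikit-learn used for illustration; `run_training` is a hypothetical wrapper around the fitting procedure that returns a validation-fold RMSE):

```python
import numpy as np
from sklearn.model_selection import KFold

def kfold_rmse_estimate(molecules, run_training, n_splits=5, seed=0):
    """Repeat the train/validate procedure over K random folds and summarize."""
    splitter = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    fold_rmses = []
    for train_idx, validate_idx in splitter.split(molecules):
        fold_rmses.append(run_training(train_idx, validate_idx))
    return np.mean(fold_rmses), np.std(fold_rmses)
```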
John, I think the different approaches in progress partly reflect differing goals -- here I'm picking a single hyperparameter choice and aiming to report on the variability / repeatability of the training procedure. The valence-fitting experiments, I think, are still highly sensitive to various hyperparameter choices, and the goal of those ongoing experiments is, I think, still to select good hyperparameters. Experiments constructed to simultaneously select hyperparameters and estimate the generalization error once hyperparameters are selected must take care to do nested cross-validation, or to use a held-out test set that is only ever consulted once.
Thanks for the clear explanations! Let's make sure the experimental section describes the motivation and conception of this design, both in the presentation of results and in the Detailed Methods! Those subtleties will be lost on the reader unless we make them explicit.
…alculations vs. experiment on the {C, H, O} subset
Noting here a few more to-do's (of undecided priority level), that I think would help shore up and contextualize these results:
Observations:
* n=10 overfitting seems to achieve a lower error than previously
* 50/50 train/validate seems to initialize and optimize at a higher error than in the first version
Translating numerical demonstrations from https://github.com/openforcefield/bayes-implicit-solvent#differentiable-atom-typing-experiments , upgrading to use message-passing rather than fingerprints + feedforward model.
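A toy sketch of what "message-passing emitting per-atom GB parameters" can look like in plain PyTorch (dimensions, feature choices, and the two output parameters per atom are illustrative, not espaloma's actual architecture):

```python
import torch

class GBParameterNet(torch.nn.Module):
    """Emit positive per-atom parameters (e.g. a Born radius and a scale) from a molecular graph."""

    def __init__(self, in_dim=16, hidden_dim=32, n_layers=3, out_dim=2):
        super().__init__()
        self.embed = torch.nn.Linear(in_dim, hidden_dim)
        self.message_layers = torch.nn.ModuleList(
            torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers))
        self.readout = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, node_features, adjacency):
        # node_features: (N, in_dim) atom features; adjacency: (N, N), 1 where atoms are bonded
        h = torch.relu(self.embed(node_features))
        for layer in self.message_layers:
            h = torch.relu(layer(adjacency @ h) + h)  # aggregate bonded neighbors, residual update
        return torch.nn.functional.softplus(self.readout(h))  # (N, out_dim), strictly positive
```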
Hiccup: porting the autodiff-friendly implementation of GBSA OBC energy from Jax to PyTorch wasn't as simple as replacing `np.` with `torch.` -- I need to track down a likely unit bug I introduced during the conversion, and pass an OpenMM consistency assertion, before merging.
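The OpenMM consistency assertion could take roughly this shape (a sketch; `torch_gbsa_energy` stands in for the ported implementation, and note that GBSAOBCForce also includes a surface-area term, so the torch side must include it too for an apples-to-apples comparison):

```python
import numpy as np
import torch
from openmm import Context, GBSAOBCForce, System, VerletIntegrator, unit

def check_against_openmm(xyz_nm, charges, radii_nm, scales, torch_gbsa_energy, tol_kj=1e-3):
    """Assert that the torch GBSA-OBC energy matches OpenMM's on a single snapshot."""
    system = System()
    force = GBSAOBCForce()
    for charge, radius, scale in zip(charges, radii_nm, scales):
        system.addParticle(1.0)  # masses don't matter for a single-point energy
        force.addParticle(charge, radius, scale)
    system.addForce(force)
    context = Context(system, VerletIntegrator(1.0 * unit.femtosecond))
    context.setPositions(np.asarray(xyz_nm) * unit.nanometer)
    reference = context.getState(getEnergy=True).getPotentialEnergy().value_in_unit(
        unit.kilojoule_per_mole)
    candidate = float(torch_gbsa_energy(torch.tensor(xyz_nm), torch.tensor(charges),
                                        torch.tensor(radii_nm), torch.tensor(scales)))
    assert abs(candidate - reference) < tol_kj, (candidate, reference)
```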