Merge vssr_pourbaix #34
Conversation
- change paths to models in `chgnet` dir; remove `chgnet` models in `NeuralForceField`
- …utional layer unfreezing in training script
- …ion layer unfreezing in MACE models
- …in MACE model training; able to specify whether to fix pooling in args
- …ch size in evaluation script
- …tter visualization
- …r in ChgnetLayerFreezer
Pull Request Overview
This pull request "Merge vssr_pourbaix" introduces additional fine-tuning options and improvements for evaluation/plotting in the NFF framework. Key changes include extended command-line arguments for training and evaluation, revised freezing/unfreezing logic in the transfer learning utilities, and updated dependency/configuration settings.
Reviewed Changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| scripts/train_nff.py | Added new CLI arguments for fine-tuning, including custom layers and unfreezing options; updated model loading logic using to_tensor. |
| scripts/evaluate_nff.py | Introduced new plotting options (plot_type, batch_size, per_atom_energy) and adjusted test loader settings. |
| pyproject.toml | Updated the ASE dependency version from 3.22.1 to 3.23.0. |
| nff/utils/cuda.py | Wrapped device selection in a try/except to gracefully handle potential NVIDIA SMI errors. |
| nff/train/transfer.py | Integrated debug print statements in the transfer learning functions and modified unfreezing functions in MaceLayerFreezer and ChgnetLayerFreezer. |
| nff/nn/models/chgnet.py | Updated file paths for pretrained model checkpoints and tweaked module imports. |
| nff/io/chgnet.py | Added helper functions for converting CHGNet structure targets and expanded support for structure data. |
| nff/io/ase_calcs.py | Added a TODO comment for updating atoms only when necessary. |
| nff/io/ase.py | Enhanced AtomsBatch by deep copying arrays and constraints upon initialization and copying. |
| nff/data/stats.py | Reformatted standard deviation and reference mean calculations for outlier removal. |
| nff/data/dataset.py | Added condition for splitting when there is no validation set. |
| nff/analysis/parity_plot.py | Adjusted figure size and changed saving format from PNG to PDF, as well as rasterized plotting elements. |
| nff/analysis/mpl_settings.py | Updated several Matplotlib settings including DPI, font sizes, and line widths. |
| nff/analysis/loss_plot.py | Reduced figure size to (5, 2.5) in the loss plotting routine. |
| models/foundation_models/chgnet/0.3.0/README.md | Removed legacy README to streamline documentation for the 0.3.0 model. |
| models/foundation_models/chgnet/0.2.0/README.md | Removed legacy README to streamline documentation for the 0.2.0 model. |
nff/train/transfer.py (Outdated)

```
    Function to transfer learn a model. Defined in the subclasses.
    """
    pass
```
Copilot AI · Apr 23, 2025
The 'model_tl' function is now implemented as 'pass'; please confirm that this omission is intentional and that proper transfer learning logic is provided elsewhere or will be implemented later.
Suggested change:

```diff
-    Function to transfer learn a model. Defined in the subclasses.
-    """
-    pass
+    Function to transfer learn a model. This method must be implemented
+    by subclasses to define specific transfer learning logic.
+
+    Args:
+        model (torch.nn.Module): model to be transfer learned
+        freeze_gap_embedding (bool): whether to freeze gap embedding layers
+        freeze_pooling (bool): whether to freeze pooling layers
+        freeze_skip (bool): whether to freeze skip connections
+        custom_layers (List[str]): list of layers to unfreeze specified by the user
+        **kwargs: additional arguments for transfer learning
+    """
+    raise NotImplementedError(
+        "The 'model_tl' method must be implemented by subclasses of LayerFreezer."
+    )
```
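The suggested replacement follows the standard Python pattern for abstract base-class methods: raising `NotImplementedError` instead of silently passing makes a missing override fail loudly. A minimal sketch of the pattern (the subclass body here is illustrative, not the actual NFF logic):

```python
class LayerFreezer:
    """Base class; subclasses supply the transfer-learning logic."""

    def model_tl(self, model, **kwargs):
        raise NotImplementedError(
            "The 'model_tl' method must be implemented by subclasses of LayerFreezer."
        )


class MaceLayerFreezer(LayerFreezer):
    def model_tl(self, model, **kwargs):
        # Illustrative only: report which custom layers would be unfrozen.
        return kwargs.get("custom_layers", [])
```

Calling `model_tl` on the base class now raises immediately, while subclasses behave as before.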
I think most of the code looks good; there are a few places where it looks like there are some duplicate lines, though.
nff/analysis/loss_plot.py (Outdated)

```
ax_fig[0].set_xlabel("Epoch")
ax_fig[0].set_ylabel("Loss")
ax_fig[0].set_xlabel("Epoch")
ax_fig[0].set_ylabel("Loss")
```
Are these lines repeated for a reason?
nff/analysis/loss_plot.py (Outdated)

```
ax_fig[1].set_xlabel("Epoch")
ax_fig[1].set_ylabel("Loss")
ax_fig[1].set_xlabel("Epoch")
ax_fig[1].set_ylabel("Loss")
```
These lines also seem redundant.
nff/analysis/mpl_settings.py (Outdated)

```
"""Converts hex to rgb colors.
Args:
    value (str): string of 6 characters representing a hex colour.
```
Your US education is clashing with the British English that (I presume?) is taught in Singapore 😂
(Just to be clear, no fix is needed here)
Whoops, I think what happened was I copied and pasted the arg description from somewhere else while the top line was ChatGPT generated. XD
nff/analysis/parity_plot.py (Outdated)

```
kernel = gaussian_kde(
    np.vstack([x.sample(n=len(x), random_state=2), y.sample(n=len(y), random_state=2)])
)
kernel = gaussian_kde(
    np.vstack([x.sample(n=len(x), random_state=2), y.sample(n=len(y), random_state=2)])
)
```
These lines appear to be repeated
nff/data/stats.py (Outdated)

```diff
-mean = reference_mean if reference_mean else np.mean(stats_array)
-std = reference_std if reference_std else np.std(stats_array)
+if reference_mean is None:
+    mean = np.mean(stats_array)
+else:
+    mean = reference_mean
+if reference_std is None:
+    std = np.std(stats_array)
+else:
+    std = reference_std
```
Does this modification do anything different? I feel like the code that this update replaces should function the same as this new code but is more succinct. None values should be falsy.
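The two forms do agree whenever the reference values are `None` or nonzero, but there is one edge case worth noting: a reference mean (or std) of exactly 0 is falsy but not `None`, so the succinct form would silently recompute it. A quick stand-alone check (names are stand-ins for the values in `stats.py`):

```python
import numpy as np

stats_array = np.array([1.0, 2.0, 3.0])


def mean_truthy(reference_mean):
    # Succinct form: falls back whenever reference_mean is falsy (None or 0).
    return reference_mean if reference_mean else np.mean(stats_array)


def mean_is_none(reference_mean):
    # Explicit form: falls back only when reference_mean is None.
    return np.mean(stats_array) if reference_mean is None else reference_mean
```

Both return the computed mean for `None`, but for a reference mean of exactly `0.0` the truthy version recomputes while the explicit version keeps `0.0`.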
I must have made an error when merging. I wanted to take the incoming (master) rather than current (vssr_pourbaix).
nff/io/ase_calcs.py (Outdated)

```
# TODO: update atoms only when necessary
atoms.update_nbr_list(update_atoms=True)
```
Could you provide some additional clarity about the TODO here? Is this issue more persistent/affecting performance significantly enough that it merits its own issue on GitHub? If so, we might want to open one.
This is used to update the number of atoms between MCMC steps. It might not be necessary for the general user running MD. Let me remove it for the main branch.
```diff
 for i, block in enumerate(model.readouts):
-    if unfreeze_skip or i == num_readouts - 1:
+    if unfreeze_skip:
+        self.unfreeze_parameters(block)
+    elif i == num_readouts - 1:
         self.unfreeze_parameters(block)
```
Is there a reason this was split into two if/elif statements that do the same thing?
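For reference, the combined condition and the if/elif split select exactly the same blocks; a small sketch comparing the two (the function and variable names are stand-ins for the MACE readout loop):

```python
def unfreeze_indices_combined(num_readouts, unfreeze_skip):
    # Combined condition: every block, or only the last one.
    return [i for i in range(num_readouts)
            if unfreeze_skip or i == num_readouts - 1]


def unfreeze_indices_split(num_readouts, unfreeze_skip):
    # if/elif split: same behaviour, just more verbose.
    out = []
    for i in range(num_readouts):
        if unfreeze_skip:
            out.append(i)
        elif i == num_readouts - 1:
            out.append(i)
    return out
```

Both select all readout indices when `unfreeze_skip` is set, and only the final index otherwise.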
nff/utils/cuda.py (Outdated)

```
    return f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
except nvidia_smi.NVMLError:
    return "cuda:0"
return f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
```
With the above try/except statement, should this return line be removed?
you're right!
Good catch
Co-authored-by: Copilot <[email protected]>
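The applied fix drops the trailing return, which was unreachable once both the try and except branches return. A sketch of the corrected control flow with stand-in names (the real code calls `cuda_devices_sorted_by_free_mem()` and catches `nvidia_smi.NVMLError`, which cannot be exercised here):

```python
def pick_device(sorted_free_mem_devices, simulate_failure=False):
    # Stand-ins: sorted_free_mem_devices mimics cuda_devices_sorted_by_free_mem(),
    # and RuntimeError mimics nvidia_smi.NVMLError.
    try:
        if simulate_failure:
            raise RuntimeError("simulated NVML error")
        # Last element has the most free memory in the sorted order.
        return f"cuda:{sorted_free_mem_devices[-1]}"
    except RuntimeError:
        # Graceful fallback when the NVML query fails.
        return "cuda:0"
    # No trailing return: both paths above already return.
```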
Looks great!
Looks great to me; just had one suggestion.
```
def convert_data_batch(
    data_batch: Dict,
    cutoff: float = 5.0,
    shuffle: bool = True,
```
Can we set shuffle to False? I assume shuffling has already been done when the dataloader is created, and this function is a wrapper.
It's a dummy variable but made the change as you suggested to make it less confusing (I guess)!
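The requested change amounts to flipping the default of an unused flag. A hypothetical sketch of the revised signature (the body is illustrative only; the real function builds CHGNet graph inputs):

```python
from typing import Dict, List


def convert_data_batch(
    data_batch: Dict,
    cutoff: float = 5.0,
    shuffle: bool = False,  # default flipped: the DataLoader has already shuffled
) -> List[str]:
    # Illustrative body: return the batch keys, optionally shuffled.
    keys = sorted(data_batch)
    if shuffle:
        import random
        random.shuffle(keys)
    return keys
```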
Additional and improved methods for fine-tuning and evaluating/plotting NFF