Conversation

@xiaochendu (Contributor)

Additional and improved methods for fine-tuning and evaluating/plotting NFF

- change paths to models in `chgnet` dir
- remove `chgnet` models in `NeuralForceField`
…in MACE model training

- able to specify whether to fix pooling in args

Copilot AI left a comment

Pull Request Overview

This pull request “Merge vssr_pourbaix” introduces additional fine-tuning options and improvements for evaluation/plotting in the NFF framework. Key changes include extended command-line arguments for training and evaluation, revised freezing/unfreezing logic in the transfer learning utilities, and updated dependency/configuration settings.

Reviewed Changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| scripts/train_nff.py | Added new CLI arguments for fine-tuning, including custom layers and unfreezing options; updated model loading logic using `to_tensor`. |
| scripts/evaluate_nff.py | Introduced new plotting options (`plot_type`, `batch_size`, `per_atom_energy`) and adjusted test loader settings. |
| pyproject.toml | Updated the ASE dependency version from 3.22.1 to 3.23.0. |
| nff/utils/cuda.py | Wrapped device selection in a try/except to gracefully handle potential NVIDIA SMI errors. |
| nff/train/transfer.py | Integrated debug print statements in the transfer learning functions and modified unfreezing functions in `MaceLayerFreezer` and `ChgnetLayerFreezer`. |
| nff/nn/models/chgnet.py | Updated file paths for pretrained model checkpoints and tweaked module imports. |
| nff/io/chgnet.py | Added helper functions for converting CHGNet structure targets and expanded support for structure data. |
| nff/io/ase_calcs.py | Added a TODO comment for updating atoms only when necessary. |
| nff/io/ase.py | Enhanced `AtomsBatch` by deep-copying arrays and constraints upon initialization and copying. |
| nff/data/stats.py | Reformatted standard deviation and reference mean calculations for outlier removal. |
| nff/data/dataset.py | Added a condition for splitting when there is no validation set. |
| nff/analysis/parity_plot.py | Adjusted figure size, changed the saving format from PNG to PDF, and rasterized plotting elements. |
| nff/analysis/mpl_settings.py | Updated several Matplotlib settings, including DPI, font sizes, and line widths. |
| nff/analysis/loss_plot.py | Reduced figure size to (5, 2.5) in the loss plotting routine. |
| models/foundation_models/chgnet/0.3.0/README.md | Removed legacy README to streamline documentation for the 0.3.0 model. |
| models/foundation_models/chgnet/0.2.0/README.md | Removed legacy README to streamline documentation for the 0.2.0 model. |

Comment on lines 72 to 74:

```python
        Function to transfer learn a model. Defined in the subclasses.
        """
        pass
```

Copilot AI (Apr 23, 2025)

The `model_tl` function is now implemented as `pass`; please confirm that this omission is intentional and that proper transfer learning logic is provided elsewhere or will be implemented later.

Suggested change:

```diff
-        Function to transfer learn a model. Defined in the subclasses.
-        """
-        pass
+        Function to transfer learn a model. This method must be implemented
+        by subclasses to define specific transfer learning logic.
+
+        Args:
+            model (torch.nn.Module): model to be transfer learned
+            freeze_gap_embedding (bool): whether to freeze gap embedding layers
+            freeze_pooling (bool): whether to freeze pooling layers
+            freeze_skip (bool): whether to freeze skip connections
+            custom_layers (List[str]): list of layers to unfreeze specified by the user
+            **kwargs: additional arguments for transfer learning
+        """
+        raise NotImplementedError(
+            "The 'model_tl' method must be implemented by subclasses of LayerFreezer."
+        )
```

pafervi previously approved these changes Apr 23, 2025

@ajhoffman1229 (Contributor) left a comment

Most of the code looks good, but there are a few places where there appear to be duplicate lines.

Comment on lines 25 to 28:

```python
ax_fig[0].set_xlabel("Epoch")
ax_fig[0].set_ylabel("Loss")
ax_fig[0].set_xlabel("Epoch")
ax_fig[0].set_ylabel("Loss")
```

Are these lines repeated for a reason?

Comment on lines 35 to 38:

```python
ax_fig[1].set_xlabel("Epoch")
ax_fig[1].set_ylabel("Loss")
ax_fig[1].set_xlabel("Epoch")
ax_fig[1].set_ylabel("Loss")
```

These lines also seem redundant.

Comment on lines 77 to 80:

```python
"""Converts hex to rgb colors.

Args:
    value (str): string of 6 characters representing a hex colour.
```

Your US education is clashing with the British English that (I presume?) is taught in Singapore 😂
(Just to be clear, no fix is needed here)

@xiaochendu (Author)

Whoops, I think what happened was that I copied and pasted the arg description from somewhere else, while the top line was ChatGPT-generated. XD
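
For context, a conventional implementation matching that docstring might look like the following (a sketch, not necessarily the PR's actual helper):

```python
def hex_to_rgb(value: str) -> tuple:
    """Convert a 6-character hex colour string (e.g. "ff8800") to an (r, g, b) tuple."""
    value = value.lstrip("#")
    # Parse each pair of hex digits into one 0-255 channel value.
    return tuple(int(value[i : i + 2], 16) for i in (0, 2, 4))


print(hex_to_rgb("ff8800"))  # (255, 136, 0)
```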

Comment on lines 169 to 174:

```python
kernel = gaussian_kde(
    np.vstack([x.sample(n=len(x), random_state=2), y.sample(n=len(y), random_state=2)])
)
kernel = gaussian_kde(
    np.vstack([x.sample(n=len(x), random_state=2), y.sample(n=len(y), random_state=2)])
)
```

These lines appear to be repeated.

Comment on lines 56 to 65:

```python
mean = reference_mean if reference_mean else np.mean(stats_array)
std = reference_std if reference_std else np.std(stats_array)
if reference_mean is None:
    mean = np.mean(stats_array)
else:
    mean = reference_mean
if reference_std is None:
    std = np.std(stats_array)
else:
    std = reference_std
```

Does this modification do anything different? I feel like the code that this update replaces should function the same as this new code but is more succinct. None values should be falsy.

@xiaochendu (Author)

I must have made an error when merging. I wanted to take the incoming (`master`) rather than the current (`vssr_pourbaix`).
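
Worth noting: the two forms agree only when the reference values are never falsy. A minimal sketch of the corner case where they diverge (assuming a reference of 0.0 is meaningful):

```python
import numpy as np

stats_array = np.array([1.0, 2.0, 3.0])
reference_mean = 0.0  # an explicitly supplied reference that happens to be falsy

# Truthiness check: 0.0 is falsy, so the supplied reference is silently ignored.
mean_truthy = reference_mean if reference_mean else np.mean(stats_array)

# Identity check: only None triggers the fallback to the computed mean.
mean_is_none = np.mean(stats_array) if reference_mean is None else reference_mean

print(mean_truthy, mean_is_none)  # 2.0 0.0
```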

Comment on lines 135 to 136:

```python
# TODO: update atoms only when necessary
atoms.update_nbr_list(update_atoms=True)
```

Could you provide some additional clarity about the TODO here? Is this issue persistent, or does it affect performance significantly enough to merit its own GitHub issue? If so, we might want to open one.

@xiaochendu (Author), Apr 23, 2025

This is used to update the number of atoms between MCMC steps. It might not be necessary for the general user running MD. Let me remove it for the main branch.

Comment on lines 219 to 263:

```python
for i, block in enumerate(model.readouts):
    if unfreeze_skip or i == num_readouts - 1:
        if unfreeze_skip:
            self.unfreeze_parameters(block)
        elif i == num_readouts - 1:
            self.unfreeze_parameters(block)
```

Is there a reason this was split into two if/elif statements that do the same thing?
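
For reference, since both branches call the same method, the nested checks collapse back to the outer condition alone (a sketch reusing the names from the snippet above):

```python
for i, block in enumerate(model.readouts):
    # Unfreeze every readout block when unfreeze_skip is set; otherwise only the last one.
    if unfreeze_skip or i == num_readouts - 1:
        self.unfreeze_parameters(block)
```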

From `nff/utils/cuda.py`:

```python
    return f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
except nvidia_smi.NVMLError:
    return "cuda:0"
return f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
```

With the above try/except statement, should this return line be removed?

@xiaochendu (Author)

You're right!
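
A sketch of how the helper might read once the unreachable trailing return is dropped (`best_cuda_device` is an illustrative name; `cuda_devices_sorted_by_free_mem` is assumed importable from `nff/utils/cuda.py`, where the snippet lives):

```python
import nvidia_smi

from nff.utils.cuda import cuda_devices_sorted_by_free_mem


def best_cuda_device() -> str:
    """Pick the CUDA device with the most free memory, falling back to cuda:0."""
    try:
        return f"cuda:{cuda_devices_sorted_by_free_mem()[-1]}"
    except nvidia_smi.NVMLError:
        # NVML queries can fail (e.g. driver/library mismatch); use the first device.
        return "cuda:0"
```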

@xiaochendu (Author)

Good catch.

ajhoffman1229 previously approved these changes Apr 28, 2025

@ajhoffman1229 (Contributor) left a comment

Looks great!

HojeChun previously approved these changes Apr 28, 2025

@HojeChun (Contributor) left a comment

Looks great to me; I just had one suggestion.

```python
def convert_data_batch(
    data_batch: Dict,
    cutoff: float = 5.0,
    shuffle: bool = True,
```

Can we set `shuffle` to `False`? I assume shuffling has already been done when you make the dataloader, and this function is a wrapper.

@xiaochendu (Author)

It's a dummy variable, but I made the change as you suggested to make it less confusing (I guess)!
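
A sketch of the revised signature after the review (any remaining parameters and the body are elided here):

```python
from typing import Dict


def convert_data_batch(
    data_batch: Dict,
    cutoff: float = 5.0,
    shuffle: bool = False,  # default flipped per the review; the flag is a dummy here
):
    ...
```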

@xiaochendu dismissed stale reviews from HojeChun and ajhoffman1229 via 95937c9 on May 1, 2025 01:41
@HojeChun merged commit 08d142f into master on May 1, 2025
1 check passed
