Conditional diffusion part1 #49
…re network base class to do conditional forward randomly
@@ -42,6 +43,9 @@ class ScoreNetworkParameters:
    """Base Hyper-parameters for score networks."""
    architecture: str
    spatial_dimension: int = 3  # the dimension of Euclidean space where atoms live.
    conditional_prob: float = 0.  # probability of making an conditional forward - else, do a unconditional forward
'an conditional' -> 'a conditional'
'a unconditional' -> 'an unconditional'
    """Model forward.

    Args:
        batch : dictionary containing the data to be processed by the model.
        conditional: if True, do an conditional forward, if False, do a unconditional forward. If None, choose
'a unconditional' -> 'an unconditional'
It seems to me that the behavior of the MLP score network has been changed. Review if this is intentional.
LGTM!
In this PR, I implement a conditional version of the score network based on the classifier-free guidance approach. In short, the modified score becomes:


where I used $\gamma = 1 + w$ to match the convention in MatterGen (hyper-parameter in the config file - set to 2 by default as per MatterGen). The second term on the right-hand side is the usual score network. The first is a modified version that takes a condition $c$ - here the Cartesian forces - as an input. This modification takes the form of an added value in MatterGen, schematically:

$$H \leftarrow H + f(c)$$

where $H$ is the node features at a given layer in the GNN used (GemNet in the case of MatterGen) and $f$ is a learned map of the condition. Here, I implemented the added value as a linear layer on top of each layer of the MLP approach. I didn't tackle MACE yet.
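For clarity, the guidance combination described above can be sketched as below. This is a minimal illustration, not the actual API of this PR: `score_net` and its `condition` keyword are hypothetical placeholders.

```python
import torch


def guided_score(score_net, x, t, c, gamma: float = 2.0):
    """Classifier-free-guidance combination of two forwards.

    Sketch only: `score_net(x, t, condition=...)` is a hypothetical
    signature. gamma = 1 + w follows the MatterGen convention
    (gamma = 2 corresponds to guidance weight w = 1).
    """
    s_cond = score_net(x, t, condition=c)       # conditional forward
    s_uncond = score_net(x, t, condition=None)  # unconditional forward
    return gamma * s_cond + (1.0 - gamma) * s_uncond
```

Note that with $\gamma = 1$ this reduces to the purely conditional score, and with $\gamma = 0$ to the unconditional one.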
During training, the forward pass for a batch can be done unconditionally (as if the property were null) or conditionally. The probability is set by a hyper-parameter in the config file. This is different from MatterGen, where training was first done unconditionally and then fine-tuned on the conditions.
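The random choice between the two forwards could look like the following sketch; the function and argument names here are illustrative placeholders, not the PR's actual code.

```python
import torch


def forward_with_random_conditioning(score_net, batch, condition,
                                     conditional_prob: float):
    """Randomly pick a conditional or unconditional forward per batch.

    Hypothetical sketch: with probability `conditional_prob` the condition
    (e.g. the Cartesian forces) is passed to the network; otherwise it is
    dropped, as in classifier-free-guidance training.
    """
    use_condition = torch.rand(()) < conditional_prob
    if use_condition:
        return score_net(batch, condition=condition)
    return score_net(batch, condition=None)
```

Setting `conditional_prob = 0.` (the default shown in the diff) recovers purely unconditional training.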