Introduce gSASRec Model with Custom Loss and Training Integration #38
Pull Request Overview
This PR introduces the new gSASRec model, extending the SASRec sequential recommendation model with a custom loss function and new training logic. Key changes include modifications in the SASRec constructor parameters and docstrings, new files for the gSASRec model and its loss function, and updates in model registration, training integration, configuration, and documentation.
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
projects/recsys-candidate-generation/src/models/sasrec.py | Updated constructor parameters and docstrings to support additional arguments. |
projects/recsys-candidate-generation/src/models/gsasrec.py | Added new gSASRec model and custom loss implementation. |
projects/recsys-candidate-generation/src/models/__init__.py | Registered the new gSASRecModule. |
projects/recsys-candidate-generation/src/fit.py | Integrated gSASRec module in the training pipeline. |
projects/recsys-candidate-generation/src/config/model/gsasrec.yaml | Added YAML configuration file for the gSASRec model. |
projects/recsys-candidate-generation/README.md | Updated documentation to include details for the gSASRec model. |
Comments suppressed due to low confidence (2)
projects/recsys-candidate-generation/src/models/sasrec.py:127
- Update the constructor docstring to reflect the changed parameters, including the removal of 'embedding_dim' and 'item_id_dim' and the addition of 'num_heads', 'ffn_dropout', 'eval_top_k', and 'optimizer_params'.
Args:
projects/recsys-candidate-generation/src/config/model/gsasrec.yaml:1
- Consider updating the 'name' field to 'gSASRec' to avoid confusion, since this configuration file is intended for the gSASRec model.
name: "SASRec"
Pull Request Overview
This PR adds the gSASRec model—a calibrated extension of SASRec with a custom loss—and integrates it into the existing training pipeline, configs, and docs.
- Implement gSASRecModule & gSASRecLoss (custom calibration loss)
- Integrate gSASRec into `fit.py` and expose it in `__init__.py`
- Add YAML config and update README for gSASRec
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
File | Description |
---|---|
src/models/sasrec.py | Updated constructor docstring to align with new params |
src/models/gsasrec.py | Added gSASRecModule, gSASRecLoss, training/validation |
src/models/__init__.py | Exposed gSASRecModule |
src/fit.py | Integrated gSASRecModule into main training logic |
src/config/model/gsasrec.yaml | Added configuration for gSASRec |
README.md | Documented gSASRec with file path and paper link |
Comments suppressed due to low confidence (4)
src/models/gsasrec.py:194
- gSASRecLoss.forward expects separate positive and negative logits, not combined logits and labels. Replace this call with `self.loss_fn(pos_logits, neg_logits)`.
loss: torch.Tensor = self.loss_fn(logits, labels)
src/models/gsasrec.py:227
- Same as above in validation_step: use `self.loss_fn(_pos_logits, _neg_logits)` instead of passing combined logits and labels.
loss: torch.Tensor = self.loss_fn(logits, labels)
src/config/model/gsasrec.yaml:1
- The `name` field should reflect the new model (`gSASRec`) instead of `SASRec` to correctly register and select the model.
name: "SASRec"
src/models/gsasrec.py:321
- The docstring for `summary` mentions `pos_sample_size`, which isn't a parameter; remove or correct this entry to match the method signature.
pos_sample_size: positive sample size
logits, labels = create_classification_inputs(pos_logits, neg_logits)
loss: torch.Tensor = self.loss_fn(logits, labels)
Suggestion: The `gSASRecLoss` expects `positive_logits` and `negative_logits` as separate arguments, not the combined `logits` and `labels` from `create_classification_inputs`. Passing `logits` and `labels` will result in incorrect loss computation or a runtime error. [possible issue, importance: 10]
Suggested change:
- logits, labels = create_classification_inputs(pos_logits, neg_logits)
- loss: torch.Tensor = self.loss_fn(logits, labels)
+ loss: torch.Tensor = self.loss_fn(pos_logits, neg_logits)
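The signature mismatch flagged in this suggestion is easier to see with the loss written out. What follows is a minimal, framework-free sketch of the generalized binary cross-entropy (gBCE) as described in the gSASRec paper, not the PR's `gSASRecLoss`, whose code does not appear on this page. The function names `log_sigmoid`, `gbce_beta`, and `gbce_loss`, the `num_items` argument, and the per-example (non-batched) shape are all assumptions made for illustration; `t` and `neg_sample_size` are the parameters named in the PR description.

```python
import math
from typing import Sequence


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)) = min(x, 0) - log(1 + exp(-|x|))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def gbce_beta(t: float, neg_sample_size: int, num_items: int) -> float:
    """Exponent applied to the positive probability (hypothetical helper).

    alpha is the negative sampling rate. t = 0 gives beta = 1 (plain BCE);
    t = 1 gives beta = alpha (full calibration, per the paper).
    """
    alpha = neg_sample_size / (num_items - 1)
    return alpha * (t * (1.0 - 1.0 / alpha) + 1.0 / alpha)


def gbce_loss(pos_logit: float, neg_logits: Sequence[float], beta: float) -> float:
    """gBCE for one positive and its sampled negatives (hypothetical helper).

    The positive and negative scores are treated differently, which is why
    they must arrive as separate arguments rather than as (logits, labels).
    """
    # log(sigmoid(s+) ** beta) = beta * log(sigmoid(s+))
    pos_term = beta * log_sigmoid(pos_logit)
    # log(1 - sigmoid(s-)) = log(sigmoid(-s-)), summed over all negatives
    neg_term = sum(log_sigmoid(-s) for s in neg_logits)
    return -(pos_term + neg_term) / (len(neg_logits) + 1)


# Example call with three sampled negatives out of an assumed 1000-item catalog.
beta = gbce_beta(t=0.75, neg_sample_size=3, num_items=1000)
loss = gbce_loss(pos_logit=2.0, neg_logits=[-1.0, 0.5, -3.0], beta=beta)
```

Because only the positive term is scaled by `beta`, a loss with this shape cannot recover which score is the positive one from a flattened logits tensor without also re-deriving it from the labels, which is consistent with the reviewer's point that the combined inputs are meant for metrics rather than for this loss.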
# for imbalanced, extract the first item logits, shape (batch_size, 1)
_pos_logits, _neg_logits = pos_logits[:, 0:1], neg_logits[:, 0:1]
logits, labels = create_classification_inputs(_pos_logits, _neg_logits)
loss: torch.Tensor = self.loss_fn(logits, labels)
Suggestion: The `gSASRecLoss` should receive all `neg_sample_size` negative logits for its calculation, consistent with the training objective. Currently, only one negative sample is used (`neg_logits[:, 0:1]`), and the loss function is incorrectly called with combined `logits` and `labels`. This leads to an inconsistent and potentially misleading validation loss. [possible issue, importance: 10]
Suggested change:
- # for imbalanced, extract the first item logits, shape (batch_size, 1)
- _pos_logits, _neg_logits = pos_logits[:, 0:1], neg_logits[:, 0:1]
- logits, labels = create_classification_inputs(_pos_logits, _neg_logits)
- loss: torch.Tensor = self.loss_fn(logits, labels)
+ # The gSASRecLoss expects all negative samples for its calculation.
+ # The create_classification_inputs is for metrics, not the gSASRecLoss itself.
+ loss: torch.Tensor = self.loss_fn(pos_logits, neg_logits)
… calculation in gSASRecModule
Pull Request Overview
This PR introduces the `gSASRec` model, a calibrated extension of SASRec, by adding its implementation, custom loss, and training integration, along with corresponding configuration and documentation updates.
- Added `gSASRecModule` and `gSASRecLoss` classes to implement the new model and loss.
- Integrated `gSASRecModule` into `fit.py` and registered it in `models/__init__.py`.
- Created `gsasrec.yaml` config, updated docs in `README.md`, and adjusted SASRec docstrings for consistency.
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
src/models/sasrec.py | Corrected constructor docstring for SASRec (removed outdated params) |
src/models/gsasrec.py | Added gSASRecModule, gSASRecLoss, and training/evaluation logic |
src/models/__init__.py | Registered gSASRecModule for package export |
src/fit.py | Integrated gSASRecModule into the main training pipeline |
src/config/model/gsasrec.yaml | Added YAML configuration for the gSASRec model |
README.md | Documented gSASRec model and paper reference |
.github/workflows/pr_agent.yml | Removed unused Vertex AI workflow variables |
Co-authored-by: Copilot <[email protected]>
User description
This pull request introduces the `gSASRec` model, an extension of the `SASRec` sequential recommendation model, with significant updates to support its implementation. The changes include the addition of new files, model configurations, and logic for training and evaluation. Below is a summary of the most important changes:
Addition of the `gSASRec` Model
- Added `gSASRecModule` in `src/models/gsasrec.py`, which extends `SASRec` with a custom loss function (`gSASRecLoss`) to reduce overconfidence in sequential recommendation models trained with negative sampling. This includes methods for training, validation, and optimizer configuration.
- Added a configuration file (`src/config/model/gsasrec.yaml`) specifying parameters such as output dimensions, number of heads, dropout rates, and the calibration parameter `t` for the `gSASRec` loss.
Integration into the Codebase
- Updated `src/models/__init__.py` to include `gSASRecModule` in the available models.
- Updated `src/fit.py` to handle `gSASRec` in the main training function, including initialization of the module with the appropriate configuration.
Documentation Updates
- Updated `README.md` to include a description of `gSASRec`, its file path, and a reference to the associated research paper.
Minor Adjustments for Compatibility
- `SASRec`: Adjusted the constructor of `SASRec` to align with the new parameters introduced for `gSASRec`, such as `eval_top_k` and `optimizer_params`.
PR Type
Enhancement, Documentation
Description
Implements the new gSASRec model with a custom loss function.
Integrates gSASRec into the main training and validation pipeline.
Adds a dedicated configuration file for the gSASRec model.
Updates documentation to include details about the gSASRec model.
Changes walkthrough 📝
fit.py
Integrate gSASRec model into training pipeline
projects/recsys-candidate-generation/src/fit.py
- … `gSASRecModule`.
- … `gSASRecModule` based on the configuration.
- … `gSASRec` parameters like `t` and `neg_sample_size` during initialization.
gsasrec.py
Implement gSASRec model and custom loss function
projects/recsys-candidate-generation/src/models/gsasrec.py
- … `gSASRecLoss` class for the custom gSASRec loss function.
- … `gSASRecModule` PyTorch Lightning module, extending `BaseModule`.
- … `SASRec` as a sub-module and customizes training/validation steps.
- … `configure_optimizers` and `lr_scheduler_step` for specific handling.
__init__.py
Register gSASRecModule for package export
projects/recsys-candidate-generation/src/models/__init__.py
- … `gSASRecModule` to the `__all__` list.
- … `gSASRecModule` importable from the `models` package.
gsasrec.yaml
Add gSASRec model configuration file
projects/recsys-candidate-generation/src/config/model/gsasrec.yaml
- … `gSASRec` model.
- … `out_dim`, `num_heads`, `num_blocks`, and dropout rates.
- … `t` parameter for the gSASRec loss function.
sasrec.py
Standardize SASRecModule constructor docstring
projects/recsys-candidate-generation/src/models/sasrec.py
- … `SASRecModule` constructor.
- … `gSASRecModule`'s constructor.
README.md
Document gSASRec model and research paper
projects/recsys-candidate-generation/README.md
- … `gSASRec` model.
- … paper.