
Introduce gSASRec Model with Custom Loss and Training Integration #38


Merged
haru-256 merged 7 commits into main from feat/gsasrec-module on May 26, 2025

Conversation


@haru-256 haru-256 commented May 26, 2025

User description

This pull request introduces the gSASRec model, an extension of the SASRec sequential recommendation model, with significant updates to support its implementation. The changes include the addition of new files, model configurations, and logic for training and evaluation. Below is a summary of the most important changes:

Addition of the gSASRec Model

  • New Model Implementation: Introduced gSASRecModule in src/models/gsasrec.py, which extends SASRec with a custom loss function (gSASRecLoss) to reduce overconfidence in sequential recommendation models trained with negative sampling. This includes methods for training, validation, and optimizer configuration.
  • Model Configuration: Added a YAML configuration file (src/config/model/gsasrec.yaml) specifying parameters such as output dimensions, number of heads, dropout rates, and the calibration parameter t for the gSASRec loss.
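
The calibration parameter t referenced above comes from the gSASRec paper (Petrov & Macdonald, RecSys 2023), where it controls how strongly the gBCE loss is calibrated: t = 0 reduces to plain BCE over the sampled negatives, while t = 1 pushes the loss toward the behaviour of training against the full catalogue. The paper maps t to an exponent beta applied to the positive probability. Below is a minimal sketch of that mapping, assuming the PR follows the paper; the helper name calc_beta is hypothetical, not code from this PR.

# Hypothetical helper illustrating the t -> beta mapping from the gSASRec paper;
# whether the PR computes beta this way is an assumption.
def calc_beta(t: float, neg_sample_size: int, num_items: int) -> float:
    alpha = neg_sample_size / (num_items - 1)  # negative sampling rate
    return alpha * (t * (1 - 1 / alpha) + 1 / alpha)

# Example: 256 sampled negatives out of a 50,000-item catalogue with t = 0.75
beta = calc_beta(t=0.75, neg_sample_size=256, num_items=50_000)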

Integration into the Codebase

  • Model Registration: Updated src/models/__init__.py to include gSASRecModule in the available models.
  • Training Pipeline: Modified src/fit.py to handle gSASRec in the main training function, including initialization of the module with the appropriate configuration. [1] [2]
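
The exact branching in src/fit.py is not reproduced in this thread; the sketch below only illustrates the kind of dispatch described above. The configuration field names and constructor arguments, other than the t, neg_sample_size, eval_top_k, and optimizer_params parameters mentioned in this PR, are assumptions for illustration.

# Hypothetical sketch of the model-selection logic in fit.py; the cfg structure and
# constructor signatures are assumed, not taken from the PR.
from src.models import SASRecModule, gSASRecModule

def build_module(cfg):
    if cfg.model.name == "gSASRec":
        return gSASRecModule(
            t=cfg.model.t,  # calibration parameter for gSASRecLoss
            neg_sample_size=cfg.model.neg_sample_size,
            eval_top_k=cfg.model.eval_top_k,
            optimizer_params=cfg.model.optimizer_params,
        )
    return SASRecModule(
        eval_top_k=cfg.model.eval_top_k,
        optimizer_params=cfg.model.optimizer_params,
    )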

Documentation Updates

  • README Update: Expanded the README.md to include a description of gSASRec, its file path, and a reference to the associated research paper.

Minor Adjustments for Compatibility

  • Refactoring in SASRec: Adjusted the constructor of SASRec to align with the new parameters introduced for gSASRec, such as eval_top_k and optimizer_params.

PR Type

Enhancement, Documentation


Description

  • Implements the new gSASRec model with a custom loss function.

  • Integrates gSASRec into the main training and validation pipeline.

  • Adds a dedicated configuration file for the gSASRec model.

  • Updates documentation to include details about the gSASRec model.


Changes walkthrough 📝

Relevant files:

Enhancement

  fit.py: Integrate gSASRec model into training pipeline
  projects/recsys-candidate-generation/src/fit.py
  • Imports the newly added gSASRecModule.
  • Extends the model initialization logic to include gSASRecModule based on the configuration.
  • Passes specific gSASRec parameters like t and neg_sample_size during initialization.
  • +19/-1

  gsasrec.py: Implement gSASRec model and custom loss function
  projects/recsys-candidate-generation/src/models/gsasrec.py
  • Implements the gSASRecLoss class for the custom gSASRec loss function.
  • Defines the gSASRecModule PyTorch Lightning module, extending BaseModule.
  • Incorporates SASRec as a sub-module and customizes training/validation steps.
  • Overrides configure_optimizers and lr_scheduler_step for specific handling.
  • +344/-0

Configuration changes

  __init__.py: Register gSASRecModule for package export
  projects/recsys-candidate-generation/src/models/__init__.py
  • Adds gSASRecModule to the __all__ list.
  • Makes gSASRecModule importable from the models package.
  • +2/-1

  gsasrec.yaml: Add gSASRec model configuration file
  projects/recsys-candidate-generation/src/config/model/gsasrec.yaml
  • Creates a new YAML configuration file for the gSASRec model.
  • Defines model parameters such as out_dim, num_heads, num_blocks, and dropout rates.
  • Specifies the t parameter for the gSASRec loss function.
  • +12/-0

Documentation

  sasrec.py: Standardize SASRecModule constructor docstring
  projects/recsys-candidate-generation/src/models/sasrec.py
  • Updates the docstring for the SASRecModule constructor.
  • Aligns parameter descriptions with the new gSASRecModule's constructor.
  • +4/-5

  README.md: Document gSASRec model and research paper
  projects/recsys-candidate-generation/README.md
  • Adds a new section dedicated to the gSASRec model.
  • Includes the file path and a reference to the associated research paper.
  • +7/-0

    @haru-256 haru-256 requested a review from Copilot May 26, 2025 14:32

    @Copilot Copilot AI left a comment

    Pull Request Overview

    This PR introduces the new gSASRec model, extending the SASRec sequential recommendation model with a custom loss function and new training logic. Key changes include modifications in the SASRec constructor parameters and docstrings, new files for the gSASRec model and its loss function, and updates in model registration, training integration, configuration, and documentation.

    Reviewed Changes

    Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.

    Show a summary per file:
    • projects/recsys-candidate-generation/src/models/sasrec.py: Updated constructor parameters and docstrings to support additional arguments.
    • projects/recsys-candidate-generation/src/models/gsasrec.py: Added the new gSASRec model and custom loss implementation.
    • projects/recsys-candidate-generation/src/models/__init__.py: Registered the new gSASRecModule.
    • projects/recsys-candidate-generation/src/fit.py: Integrated the gSASRec module into the training pipeline.
    • projects/recsys-candidate-generation/src/config/model/gsasrec.yaml: Added a YAML configuration file for the gSASRec model.
    • projects/recsys-candidate-generation/README.md: Updated documentation to include details for the gSASRec model.

    Comments suppressed due to low confidence (2)

    projects/recsys-candidate-generation/src/models/sasrec.py:127

    • Update the constructor docstring to reflect the changed parameters, including the removal of 'embedding_dim' and 'item_id_dim' and the addition of 'num_heads', 'ffn_dropout', 'eval_top_k', and 'optimizer_params'.
    Args:
    

    projects/recsys-candidate-generation/src/config/model/gsasrec.yaml:1

    • Consider updating the 'name' field to 'gSASRec' to avoid confusion, since this configuration file is intended for the gSASRec model.
    name: "SASRec"
    

    @github-actions github-actions bot changed the title from "feat: implement gSASRec model with configuration and loss function" to "Implement gSASRec Model with Custom Loss and Training Integration" May 26, 2025

    @haru-256 haru-256 requested a review from Copilot May 26, 2025 14:55

    @Copilot Copilot AI left a comment

    Pull Request Overview

    This PR adds the gSASRec model—a calibrated extension of SASRec with a custom loss—and integrates it into the existing training pipeline, configs, and docs.

    • Implement gSASRecModule & gSASRecLoss (custom calibration loss)
    • Integrate gSASRec into fit.py and expose it in __init__.py
    • Add YAML config and update README for gSASRec

    Reviewed Changes

    Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.

    Show a summary per file:
    • src/models/sasrec.py: Updated constructor docstring to align with new params
    • src/models/gsasrec.py: Added gSASRecModule, gSASRecLoss, training/validation
    • src/models/__init__.py: Exposed gSASRecModule
    • src/fit.py: Integrated gSASRecModule into main training logic
    • src/config/model/gsasrec.yaml: Added configuration for gSASRec
    • README.md: Documented gSASRec with file path and paper link

    Comments suppressed due to low confidence (4)

    src/models/gsasrec.py:194

    • gSASRecLoss.forward expects separate positive and negative logits, not combined logits and labels. Replace this call with self.loss_fn(pos_logits, neg_logits).
    loss: torch.Tensor = self.loss_fn(logits, labels)
    

    src/models/gsasrec.py:227

    • Same as above in validation_step: use self.loss_fn(_pos_logits, _neg_logits) instead of passing combined logits and labels.
    loss: torch.Tensor = self.loss_fn(logits, labels)
    

    src/config/model/gsasrec.yaml:1

    • The name field should reflect the new model (gSASRec) instead of SASRec to correctly register and select the model.
    name: "SASRec"
    

    src/models/gsasrec.py:321

    • The docstring for summary mentions pos_sample_size, which isn't a parameter; remove or correct this entry to match the method signature.
                pos_sample_size: positive sample size
    


    @github-actions github-actions bot changed the title from "Implement gSASRec Model with Custom Loss and Training Integration" to "Introduce gSASRec Model with Custom Loss and Training Integration" May 26, 2025

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    ⏱️ Estimated effort to review: 3 🔵🔵🔵⚪⚪
    🧪 No relevant tests
    🔒 No security concerns identified
    🔀 No multiple PR themes
    ⚡ Recommended focus areas for review

    Numerical Stability

    The gSASRecLoss implementation uses torch.float64 and torch.clamp for numerical stability. Review the clamping bounds and the overall logic to ensure that these operations do not inadvertently clip valid ranges or introduce unexpected behavior, especially in the positive_probs_adjusted calculation.

    positive_logits = positive_logits.to(torch.float64)
    negative_logits = negative_logits.to(positive_logits.dtype)
    
    positive_probs = torch.clamp(torch.sigmoid(positive_logits), self.eps, 1 - self.eps)
    positive_probs_adjusted = torch.clamp(
        positive_probs.pow(-self.beta), 1 + self.eps, torch.finfo(torch.float64).max
    )
    to_log = torch.clamp(
        torch.div(1.0, (positive_probs_adjusted - 1)), self.eps, torch.finfo(torch.float64).max
    )
    positive_logits_transformed = to_log.log()
    
    logits = torch.cat([positive_logits_transformed, negative_logits], dim=1)
    labels = torch.cat(
        [torch.ones_like(positive_logits), torch.zeros_like(negative_logits)], dim=1
    )
    loss = self.bce(logits, labels)
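
    As a sanity check on what this block computes: the clamp/log chain is a numerically stable way of turning a positive logit s into a new logit whose sigmoid equals sigmoid(s) ** beta, so the subsequent BCE-with-logits effectively optimises -log(sigmoid(s) ** beta) for positives, which matches the gBCE objective from the gSASRec paper. A self-contained sketch; the beta and eps values here are arbitrary and chosen only for illustration:

    import torch

    beta, eps = 0.5, 1e-10  # illustrative values; the module's actual beta/eps may differ
    s = torch.tensor([[2.0, -1.0, 0.3]], dtype=torch.float64)  # toy positive logits

    p = torch.clamp(torch.sigmoid(s), eps, 1 - eps)
    p_adj = torch.clamp(p.pow(-beta), 1 + eps, torch.finfo(torch.float64).max)
    to_log = torch.clamp(1.0 / (p_adj - 1), eps, torch.finfo(torch.float64).max)
    transformed = to_log.log()

    # sigmoid(log(1 / (p**-beta - 1))) == p**beta, so BCE on `transformed` targets sigmoid(s)**beta
    assert torch.allclose(torch.sigmoid(transformed), p.pow(beta))
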
    Logits Calculation

    The static method _calc_logits is critical for computing the positive and negative logits. Verify that the tensor manipulations (unsqueezing, transposing, bmm, squeezing) correctly align dimensions and produce the expected logits for the loss function, especially considering the seq_len dimension of out.

    @staticmethod
    def _calc_logits(
        out: torch.Tensor, pos_item_emb: torch.Tensor, neg_item_emb: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculate logits
    
        Args:
            out: output tensor of SASRec, shape (batch_size, seq_len, hidden_size)
            pos_item_emb: positive item embedding, shape (batch_size, hidden_size)
            neg_item_emb: negative item embedding, shape (batch_size, neg_sample_size, hidden_size)
    
        Returns:
            pos_logits: positive logits, shape (batch_size, 1)
            neg_logits: negative logits, shape (batch_size, neg_sample_size)
    
        """
        # extract the last hidden state, shape (batch_size, 1, hidden_size)
        out = out[:, -1, :].unsqueeze(1)
    
        pos_item_emb = pos_item_emb.unsqueeze(1)  # shape (batch_size, 1, hidden_size)
        # shape (batch_size, 1)
        pos_logits = torch.bmm(out, pos_item_emb.transpose(1, 2)).squeeze(1)
        # shape (batch_size, neg_sample_size)
        neg_logits = torch.bmm(out, neg_item_emb.transpose(1, 2)).squeeze(1)
    
        return pos_logits, neg_logits
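
    A quick shape check of the tensor manipulations above with toy sizes (batch_size=4, seq_len=10, hidden_size=8, neg_sample_size=16; values are random and purely illustrative):

    import torch

    batch, seq_len, hidden, n_neg = 4, 10, 8, 16
    out = torch.randn(batch, seq_len, hidden)          # SASRec output
    pos_item_emb = torch.randn(batch, hidden)          # one positive item per sequence
    neg_item_emb = torch.randn(batch, n_neg, hidden)   # sampled negatives per sequence

    last = out[:, -1, :].unsqueeze(1)                  # (batch, 1, hidden): only the last position is scored
    pos_logits = torch.bmm(last, pos_item_emb.unsqueeze(1).transpose(1, 2)).squeeze(1)  # (batch, 1)
    neg_logits = torch.bmm(last, neg_item_emb.transpose(1, 2)).squeeze(1)               # (batch, n_neg)

    assert pos_logits.shape == (batch, 1) and neg_logits.shape == (batch, n_neg)
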
    LR Scheduler Override

    The lr_scheduler_step method is overridden due to CosineLRScheduler not inheriting from torch.optim.lr_scheduler.LRScheduler. Ensure that the custom logic for advancing the scheduler (scheduler.step(epoch=steps)) correctly handles both 'epoch' and 'step' step_unit configurations and that the metric parameter is passed appropriately when required.

    @override
    def lr_scheduler_step(self, scheduler: CosineLRScheduler, metric: Any | None) -> None:  # type: ignore
        """CosineLRSchedulerのstepを進める
        CosineLRSchedulerがtorch.optim.lr_scheduler.LRSchedulerを継承していないためoverride
        """
        match self.optimizer_params.lr_scheduler.step_unit:
            case "epoch":
                steps = self.current_epoch
            case "step":
                steps = self.global_step
            case _:
                raise ValueError(
                    f"Invalid step unit: {self.optimizer_params.lr_scheduler.step_unit}"
                )
        if metric is None:
            scheduler.step(epoch=steps)  # NOTE: the argument is named `epoch`, but an epoch or a step count both work
        else:
            scheduler.step(epoch=steps, metric=metric)
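
    For context, this override only takes effect if configure_optimizers hands Lightning the timm CosineLRScheduler with an interval matching step_unit. A minimal sketch under that assumption; the optimizer choice and hyperparameter values are illustrative, not taken from this PR:

    import torch
    from timm.scheduler import CosineLRScheduler

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = CosineLRScheduler(optimizer, t_initial=10, warmup_t=1, lr_min=1e-5)
        # Lightning invokes lr_scheduler_step() per epoch or per step depending on "interval"
        interval = "epoch" if self.optimizer_params.lr_scheduler.step_unit == "epoch" else "step"
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": interval},
        }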

    Comment on lines 193 to 194
    logits, labels = create_classification_inputs(pos_logits, neg_logits)
    loss: torch.Tensor = self.loss_fn(logits, labels)


    Suggestion: The gSASRecLoss expects positive_logits and negative_logits as separate arguments, not the combined logits and labels from create_classification_inputs. Passing logits and labels will result in incorrect loss computation or a runtime error. [possible issue, importance: 10]

    Suggested change
    - logits, labels = create_classification_inputs(pos_logits, neg_logits)
    - loss: torch.Tensor = self.loss_fn(logits, labels)
    + loss: torch.Tensor = self.loss_fn(pos_logits, neg_logits)

    Comment on lines 224 to 227
    # for imbalanced, extract the first item logits, shape (batch_size, 1)
    _pos_logits, _neg_logits = pos_logits[:, 0:1], neg_logits[:, 0:1]
    logits, labels = create_classification_inputs(_pos_logits, _neg_logits)
    loss: torch.Tensor = self.loss_fn(logits, labels)


    Suggestion: The gSASRecLoss should receive all neg_sample_size negative logits for its calculation, consistent with the training objective. Currently, only one negative sample is used (neg_logits[:, 0:1]), and the loss function is incorrectly called with combined logits and labels. This leads to an inconsistent and potentially misleading validation loss. [possible issue, importance: 10]

    Suggested change
    - # for imbalanced, extract the first item logits, shape (batch_size, 1)
    - _pos_logits, _neg_logits = pos_logits[:, 0:1], neg_logits[:, 0:1]
    - logits, labels = create_classification_inputs(_pos_logits, _neg_logits)
    - loss: torch.Tensor = self.loss_fn(logits, labels)
    + # The gSASRecLoss expects all negative samples for its calculation.
    + # The create_classification_inputs is for metrics, not the gSASRecLoss itself.
    + loss: torch.Tensor = self.loss_fn(pos_logits, neg_logits)

    @haru-256 haru-256 requested a review from Copilot May 26, 2025 15:19

    @Copilot Copilot AI left a comment

    Pull Request Overview

    This PR introduces the gSASRec model—a calibrated extension of SASRec—by adding its implementation, custom loss, and training integration, along with corresponding configuration and documentation updates.

    • Added gSASRecModule and gSASRecLoss classes to implement the new model and loss.
    • Integrated gSASRecModule into fit.py and registered it in models/__init__.py.
    • Created gsasrec.yaml config, updated docs in README.md, and adjusted SASRec docstrings for consistency.

    Reviewed Changes

    Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.

    Show a summary per file:
    • src/models/sasrec.py: Corrected constructor docstring for SASRec (removed outdated params)
    • src/models/gsasrec.py: Added gSASRecModule, gSASRecLoss, and training/evaluation logic
    • src/models/__init__.py: Registered gSASRecModule for package export
    • src/fit.py: Integrated gSASRecModule into the main training pipeline
    • src/config/model/gsasrec.yaml: Added YAML configuration for the gSASRec model
    • README.md: Documented gSASRec model and paper reference
    • .github/workflows/pr_agent.yml: Removed unused Vertex AI workflow variables

    @haru-256 haru-256 merged commit 7fb8ea7 into main May 26, 2025
    2 checks passed
    @haru-256 haru-256 deleted the feat/gsasrec-module branch May 26, 2025 15:27