Significant time spent on dense_to_jagged operations contradicts compute-intensive claims in paper #192
Hi - we actually have two copies of code in this repo that we haven't fully integrated yet. Please check https://github.com/facebookresearch/generative-recommenders?tab=readme-ov-file#efficiency-experiments for the notes. Integration will likely be done in the next couple of weeks.
Thank you for your previous reply. I have some further questions and points of confusion regarding the code and the model in your repository, and I'm hoping you can help clarify them.

1. Tensor shape transformation issue: I've noticed that a significant amount of time is spent on tensor shape transformations, specifically between dense and jagged tensors. I'm curious about the rationale behind this design. Are there any plans to optimize these transformations? They currently consume a considerable amount of resources and might affect overall system efficiency.

2. User sequence inference issue: Currently, in the existing tests, the model's input during each inference is the full user sequence. Will the upcoming merged version support inference on incremental user sequences? If so, could you provide some insight into how this will be implemented and what the potential benefits are?

I truly appreciate your time and effort in answering these questions. I'm looking forward to your response.

Best regards
The general idea for what we described in the paper is to keep all intermediate tensors in jagged/ragged formats. We provide Triton/CUDA kernels in this repo that should enable you to do that. For the code on public datasets, for fairness with Transformer/SASRec-based baselines (where we didn't reimplement everything in jagged), we omitted this step. Not quite sure what you meant by "user incremental sequences" - is that microbatching/KV caching in M-FALCON?
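For concreteness, here is a minimal sketch of the dense/jagged round trip being discussed, assuming `fbgemm_gpu` is installed; the op names below are fbgemm's, not necessarily the exact calls this repo uses:

```python
import torch
import fbgemm_gpu  # noqa: F401  (registers the torch.ops.fbgemm jagged ops)

# Jagged layout: a flat values tensor plus per-sequence offsets; no padding stored.
lengths = torch.tensor([3, 1, 2])                                 # per-user lengths
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)  # [0, 3, 4, 6]
values = torch.randn(int(lengths.sum()), 8)                       # (total_tokens, d)

# The round trip that shows up in profiles: pad to dense, then un-pad again.
dense = torch.ops.fbgemm.jagged_to_padded_dense(values, [offsets], [3], 0.0)
jagged, _ = torch.ops.fbgemm.dense_to_jagged(dense, [offsets])

# Keeping values/offsets end-to-end through jagged-aware kernels avoids
# materializing `dense` (and its memory copies) at every layer boundary.
```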
Currently, the input for encoding is the user's full sequence, which is used to generate the user's current embedding representation. If the user produces new behavior events, should the next input to the model be the user's current embedding representation along with the new events, or the full sequence including the new events? Also, will the dlrm-v3 code to be integrated keep all intermediate tensors in jagged/ragged formats?
I do not fully understand this question. GR is not a typical embedding model (due to the target-aware formulation discussed in the paper), so you need either the full sequence or KV caching to do partial re-computation. The jagged optimizations should be done in https://github.com/facebookresearch/generative-recommenders/blob/main/generative_recommenders/modules/stu.py (but that code is slightly WIP).
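To illustrate the kind of partial re-computation KV caching enables, here is a toy sketch; all names are illustrative rather than the repo's actual M-FALCON implementation, and the causal mask among the new tokens is omitted for brevity:

```python
import torch

# K/V projections for the already-encoded prefix are cached, so on new user
# events only the new tokens need to be projected and attended from.
d = 64
w_q, w_k, w_v = (torch.nn.Linear(d, d) for _ in range(3))

cached_k = torch.randn(100, d)  # cached K for a 100-event user prefix
cached_v = torch.randn(100, d)  # cached V for the same prefix

new_x = torch.randn(3, d)                     # embeddings of 3 new events
k = torch.cat([cached_k, w_k(new_x)], dim=0)  # project only the new tokens
v = torch.cat([cached_v, w_v(new_x)], dim=0)
q = w_q(new_x)                                # queries for new positions only

out = torch.softmax(q @ k.t() / d ** 0.5, dim=-1) @ v  # (3, d) new outputs
```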
Thank you for your explanation. I took another look at the code in the encode stage of the research section. Regarding your mention of "utilize KV caching to do partial re-computation", are you referring to the key point of using
Issue Report: Significant time spent on `dense_to_jagged` operations contradicts compute-intensive claims in paper
Description

Hi maintainers,

When profiling the `eval_metrics_v2_from_tensors` module under the research code, I observed that the majority of the execution time is spent on `dense_to_jagged` memory-copy operations rather than on compute-intensive operations as described in the paper. This seems to contradict the paper's emphasis on compute-bound workloads.
Steps to Reproduce

1. Run `eval_metrics_v2_from_tensors` with the `ml-20m` dataset.
2. Profile the run and inspect the operator time breakdown, as in the sketch below.
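A minimal profiling harness along these lines, where `run_eval` is a hypothetical stand-in for the actual call into `eval_metrics_v2_from_tensors` and only the `torch.profiler` usage is the point:

```python
import torch
from torch.profiler import profile, ProfilerActivity

def run_eval():
    # Placeholder: invoke eval_metrics_v2_from_tensors on ml-20m batches here.
    ...

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    run_eval()

# Sorted by CUDA time; in my runs, dense_to_jagged copies dominate this table.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```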
Expected Behavior

Expect the dominant time to be spent on compute-intensive operations, as suggested by the paper's computational complexity analysis.
Environment

Dataset: `ml-20m`.

Additional Context

Profiling trace screenshot attached.
Questions

Would appreciate your insights on whether this dense/jagged conversion overhead is expected in the current code, and whether there are plans to optimize it.
Thank you for your work on this important research!