Delta tracker DMP integration #3064

Open: wants to merge 3 commits into main from export-D76202371
Conversation

@aliafzal (Contributor) commented on Jun 9, 2025

Summary:

This Diff

Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.

Key Components:

ModelTrackerConfig Integration:

  • Added a ModelTrackerConfig parameter to the DMP constructor
  • When provided, it automatically initializes a ModelDeltaTracker
  • Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip (see the sketch after this list)
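
As a rough illustration, enabling the tracker at construction time could look like the sketch below. The field names follow the bullets above, but the import paths, the TrackingMode values, and the constructor keyword are assumptions rather than confirmed API.

```python
# Hedged sketch: import paths, TrackingMode values, and the DMP keyword
# argument are assumptions based on this summary, not confirmed API.
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.model_tracker.types import ModelTrackerConfig, TrackingMode

tracker_config = ModelTrackerConfig(
    tracking_mode=TrackingMode.ID_ONLY,   # assumed enum member: track IDs only
    delete_on_read=True,                  # clear tracked rows once they are read
    auto_compact=False,                   # True overlaps compaction with odist
    fqns_to_skip=["over_arch"],           # illustrative FQN to exclude from tracking
)

# Passing the config makes DMP initialize a ModelDeltaTracker internally.
model = DistributedModelParallel(
    module=my_model,                      # placeholder: the user's TorchRec model
    model_tracker_config=tracker_config,  # assumed parameter name
)
```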

Custom Callables for Tracking:

  • Added a custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This tracks IDs/states natively in TorchRec without registering any nn.Module-specific hooks.
  • Added a post_odist_hook for auto-compaction of tracked data. This custom hook natively supports overlapping compaction with odist.
  • Implemented pre_forward callables in DMP for operations such as incrementing the batch index (a wiring sketch follows this list)
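
To make the hook mechanism concrete, the sketch below shows roughly how such a callable could be attached and what it might receive. Only post_lookup_hook and post_odist_hook as concepts come from this diff; the registration method name, the hook signature, and the tracker and sharded_ebc objects are illustrative assumptions.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

def post_lookup(kjt: KeyedJaggedTensor, embeddings: torch.Tensor) -> None:
    # Called by the sharded module right after the embedding lookup, so the
    # tracker can record IDs (and, depending on the mode, embedding values)
    # without any nn.Module forward hooks.
    tracker.record_lookup(kjt, embeddings)  # tracker: a ModelDeltaTracker

# Assumed registration method name; the diff only states that the embedding
# modules gained hook-registration methods.
sharded_ebc.register_post_lookup_tracking_fn(post_lookup)
```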

Model Parallel API Enhancements:

  • Added a get_model_tracker() method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API gives the flexibility to integrate the model tracker into other components directly, without needing to access the dmp_module.
  • Added a get_delta() method as a convenience API to retrieve delta rows from the dmp_module (usage sketched after this list).
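
A minimal usage sketch of the two accessors, assuming get_delta() returns delta rows keyed by module FQN (consistent with the per-FQN retrieval described for D76094097 below):

```python
# model is a DistributedModelParallel constructed with a ModelTrackerConfig.
tracker = model.get_model_tracker()  # direct handle to the ModelDeltaTracker

delta_per_fqn = model.get_delta()    # convenience wrapper over the tracker
for fqn, rows in delta_per_fqn.items():
    # rows holds the unique IDs (and, in embedding mode, their values)
    # touched since the last read; delete_on_read controls whether the
    # tracker clears them afterwards.
    print(fqn, rows)
```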

Embedding Module Changes:

  • Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks/callables
  • Added hook-registration methods in the embedding modules
  • Implemented tracking support for different optimizer states (momentum, Adam states)

ModelDeltaTracker Context

ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from the embedding modules of a model built with TorchRec. It's particularly useful for:

  1. Identifying which embedding rows were accessed during model execution
  2. Retrieving the latest delta or unique rows for a model
  3. Computing top-k changed embeddings
  4. Supporting streaming of updated embeddings between systems during online training (illustrated in the loop sketch after this list)
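
For example, an online-training loop built on this integration might look like the hedged sketch below; everything except get_delta() (the dataloader, optimizer, and publisher) is a placeholder for user code.

```python
# Hedged sketch of use case 4: streaming updated rows after each step.
for batch in dataloader:             # placeholder data source
    loss = model(batch)              # the post_lookup hook records touched IDs
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    delta = model.get_delta()        # rows changed since the previous call
    publisher.send(delta)            # placeholder: ship deltas to serving
```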

For more details, see diff D75853147 or PR #3057.

Differential Revision: D76202371

@facebook-github-bot added the CLA Signed label on Jun 9, 2025
@facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D76202371

@aliafzal force-pushed the export-D76202371 branch from 818dd82 to 2f65757 on June 9, 2025 at 10:25
aliafzal added a commit to aliafzal/torchrec that referenced this pull request on Jun 9, 2025

aliafzal added a commit to aliafzal/torchrec that referenced this pull request on Jun 9, 2025

@aliafzal force-pushed the export-D76202371 branch from 2f65757 to f726e9e on June 9, 2025 at 10:26

@aliafzal force-pushed the export-D76202371 branch from f726e9e to 1785a5e on June 9, 2025 at 10:28
aliafzal added a commit to aliafzal/torchrec that referenced this pull request on Jun 9, 2025

maliafzal added 2 commits on June 9, 2025 at 16:19
Summary:
Pull Request resolved: pytorch#3059

# This Diff
Adds the implementation of the `fqn_to_feature_names` method, along with an initial testing framework and unit tests for `fqn_to_feature_names`.

Differential Revision: D75908963
Summary:
Pull Request resolved: pytorch#3060

### Diff Summary

This diff introduces the implementation of the tracking logic for the ID and EMBEDDING modes.

1. **Record Functions** (a dispatch sketch follows this list)
* `record_lookup()`: Handles recording of IDs and embeddings based on the tracking mode.
* `record_ids()`: Records IDs from a KeyedJaggedTensor.
* `record_embeddings()`: Records IDs along with embeddings, ensuring size compatibility between the IDs and the embeddings.

2. **Delta Retrieval**
* `get_delta()`: Retrieves per-FQN local IDs for each sparse feature.

3. **Tracked Modules Access**
* `get_tracked_modules()`: Returns a dictionary of the tracked modules.
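
Read together, these methods suggest a dispatch along the lines of this simplified, assumed sketch (the actual implementation in the diff may differ):

```python
# Assumed shape of the mode dispatch described above (simplified).
def record_lookup(self, kjt, states):
    if self._mode == TrackingMode.ID_ONLY:
        self.record_ids(kjt)                 # IDs from the KeyedJaggedTensor
    elif self._mode == TrackingMode.EMBEDDING:
        self.record_embeddings(kjt, states)  # validates ID/embedding sizes match
```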

Differential Revision: D76094097

@aliafzal force-pushed the export-D76202371 branch from 1785a5e to 9505130 on June 9, 2025 at 23:25
aliafzal added a commit to aliafzal/torchrec that referenced this pull request on Jun 9, 2025

aliafzal added a commit to aliafzal/torchrec that referenced this pull request on Jun 9, 2025

@aliafzal force-pushed the export-D76202371 branch from 9505130 to c04f2da on June 9, 2025 at 23:29

@aliafzal force-pushed the export-D76202371 branch from c04f2da to 61b208e on June 9, 2025 at 23:34