
Commit 536c935

update docs
1 parent bcaeb5f commit 536c935

11 files changed: +433 -118 lines changed


docs/howto/model.rst (+3 -3)

@@ -24,7 +24,7 @@ Say we wanted to build a custom bi-encoder model that adds an additional linear
        super().__init__(**kwargs)
        self.additional_linear_layer = additional_linear_layer

-Next, we need to subclass the :py:class:`lightning_ir.bi_encoder.model.BiEncoderModel` and override the :py:class:`lightning_ir.bi_encoder.model.BiEncoderModel._encode` method to include the additional linear layer. We also need to ensure that our new config class is registered with our new model as the :py:meth:`~lightning_ir.bi_encoder.model.BiEncoderModel.config_class` attribute. In the :py:class:`lightning_ir.bi_encoder.model.BiEncoderModel._encode` method, the :py:meth:`~lightning_ir.bi_encoder.model.BiEncoderModel._backbone_forward` method runs the backbone model and returns the contextualized embeddings of the input sequence. We then apply our additional linear layer to the pooled embeddings. Afterwards, the various steps of the processing pipeline for bi-encoders are applied (see :ref:`concepts-model` for more details). For example:
+Next, we need to subclass the :py:class:`lightning_ir.bi_encoder.model.BiEncoderModel` and override the :py:class:`lightning_ir.bi_encoder.model.BiEncoderModel.encode` method to include the additional linear layer. We also need to ensure that our new config class is registered with our new model as the :py:meth:`~lightning_ir.bi_encoder.model.BiEncoderModel.config_class` attribute. In the :py:class:`lightning_ir.bi_encoder.model.BiEncoderModel.encode` method, the :py:meth:`~lightning_ir.bi_encoder.model.BiEncoderModel._backbone_forward` method runs the backbone model and returns the contextualized embeddings of the input sequence. We then apply our additional linear layer to the pooled embeddings. Afterwards, the various steps of the processing pipeline for bi-encoders are applied (see :ref:`concepts-model` for more details). For example:

.. code-block:: python

@@ -46,7 +46,7 @@ Next, we need to subclass the :py:class:`lightning_ir.bi_encoder.model.BiEncoder
            config.hidden_size, config.hidden_size
        )

-    def _encode(
+    def encode(
        self,
        encoding: BatchEncoding,
        expansion: bool = False,

@@ -62,7 +62,7 @@ Next, we need to subclass the :py:class:`lightning_ir.bi_encoder.model.BiEncoder
        embeddings = self._pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
-        scoring_mask = self._scoring_mask(
+        scoring_mask = self.scoring_mask(
            encoding["input_ids"],
            encoding["attention_mask"],
            expansion,
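
Taken together with the config hunk above, the override being renamed here assembles into roughly the following (a condensed, hypothetical sketch: the .last_hidden_state access, the abbreviated parameter list, and the trailing scoring_mask arguments are assumptions, and examples/custom_bi_encoder.py in this same commit remains the authoritative version):

    import torch
    from transformers import BatchEncoding
    from lightning_ir.bi_encoder import BiEncoderConfig, BiEncoderEmbedding, BiEncoderModel


    class CustomBiEncoderConfig(BiEncoderConfig):
        def __init__(self, additional_linear_layer: bool = True, **kwargs):
            super().__init__(**kwargs)
            self.additional_linear_layer = additional_linear_layer


    class CustomBiEncoderModel(BiEncoderModel):
        config_class = CustomBiEncoderConfig  # register the custom config class with the model

        def __init__(self, config, *args, **kwargs):
            super().__init__(config, *args, **kwargs)
            self.additional_linear_layer = None
            if config.additional_linear_layer:
                self.additional_linear_layer = torch.nn.Linear(config.hidden_size, config.hidden_size)

        def encode(self, encoding: BatchEncoding, expansion: bool = False, pooling_strategy=None) -> BiEncoderEmbedding:
            # the backbone returns the contextualized embeddings of the input sequence
            embeddings = self._backbone_forward(**encoding).last_hidden_state
            # custom step: the extra linear layer (see the full example for its exact
            # position relative to the other pipeline steps)
            if self.additional_linear_layer is not None:
                embeddings = self.additional_linear_layer(embeddings)
            embeddings = self._pooling(embeddings, encoding["attention_mask"], pooling_strategy)
            if self.config.normalize:
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            # the rename in this commit: public scoring_mask instead of the old _scoring_mask
            scoring_mask = self.scoring_mask(encoding["input_ids"], encoding["attention_mask"], expansion)
            return BiEncoderEmbedding(embeddings, scoring_mask)

Whatever the exact signatures, the takeaway of the hunk is the renamed hooks: encode and scoring_mask are now the public surface that custom bi-encoder models override.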

examples/custom_bi_encoder.py (+2 -2)

@@ -36,7 +36,7 @@ def __init__(self, config, *args, **kwargs):
        if config.additional_linear_layer:
            self.additional_linear_layer = torch.nn.Linear(config.hidden_size, config.hidden_size)

-    def _encode(
+    def encode(
        self,
        encoding: BatchEncoding,
        expansion: bool = False,

@@ -52,7 +52,7 @@ def _encode(
        embeddings = self._pooling(embeddings, encoding["attention_mask"], pooling_strategy)
        if self.config.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
-        scoring_mask = self._scoring_mask(
+        scoring_mask = self.scoring_mask(
            encoding["input_ids"],
            encoding["attention_mask"],
            expansion,

lightning_ir/base/model.py (+16 -5)

@@ -8,7 +8,7 @@
from dataclasses import dataclass
from functools import partial, wraps
from pathlib import Path
-from typing import Any, Callable, Literal, Mapping, Sequence, Type, TypeVar
+from typing import Any, Callable, Literal, Mapping, Protocol, Sequence, Type, TypeVar

import torch
from transformers import MODEL_MAPPING, BatchEncoding, BertModel

@@ -232,10 +232,21 @@ def _cat_outputs(
    return OutputClass(**{key: _cat_outputs(value, types[key]) for key, value in agg.items()})


-def _batch_encoding(
-    func: Callable[[LightningIRModel, BatchEncoding, ...], Any]
-) -> Callable[[LightningIRModel, BatchEncoding, ...], Any]:
-    """Decorator to enable sub-batching for models that support it."""
+class BatchEncodingWrapper(Protocol):
+    def __call__(self, encoding: BatchEncoding, *args, **kwargs) -> Any: ...
+
+
+def batch_encoding_wrapper(func: BatchEncodingWrapper) -> BatchEncodingWrapper:
+    """Decorator to enable sub-batching for models that support it. Lowers the batch size of the input batch encoding
+    if the model runs out of memory.
+
+    :param func: Function to wrap that takes a batch encoding
+    :type func: BatchEncodingWrapper
+    :raises e: If CUDA runs out of memory even after lowering the batch size to 1
+    :raises ValueError: If no output was generated
+    :return: Wrapped function
+    :rtype: BatchEncodingWrapper
+    """

    @wraps(func)
    def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
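
The new docstring describes an out-of-memory fallback; the wrapped body is not part of this hunk, but the behaviour it documents can be pictured with a rough, stand-alone sketch like the one below (illustrative only, assuming the familiar catch-the-OOM-error-and-halve strategy; lightning-ir's actual wrapper may split the encoding and merge the partial outputs differently):

    from functools import wraps
    from typing import Any

    from transformers import BatchEncoding


    def sub_batching_sketch(func):
        """Halve the sub-batch size on CUDA OOM and stitch the partial results back together."""

        @wraps(func)
        def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
            batch_size = encoding["input_ids"].shape[0]
            sub_batch_size = batch_size
            outputs: list = []
            while True:
                try:
                    outputs = []
                    for start in range(0, batch_size, sub_batch_size):
                        # slice every field of the batch encoding to the current sub-batch
                        sub_encoding = BatchEncoding(
                            {key: value[start : start + sub_batch_size] for key, value in encoding.items()}
                        )
                        outputs.append(func(self, sub_encoding, *args, **kwargs))
                    break
                except RuntimeError as e:
                    # give up once the sub-batch size cannot be lowered any further
                    if "out of memory" not in str(e) or sub_batch_size == 1:
                        raise e
                    sub_batch_size //= 2  # lower the batch size and retry
            if not outputs:
                raise ValueError("No output was generated.")
            # a real implementation would concatenate the partial outputs (e.g. torch.cat per field)
            return outputs[0] if len(outputs) == 1 else outputs

        return wrapper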

lightning_ir/base/module.py (+4 -11)

@@ -165,15 +165,8 @@ def prepare_input(
            encodings[key] = encodings[key].to(self.device)
        return encodings

-    def compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> List[torch.Tensor]:
-        """Computes the losses for the batch.
-
-        :param batch: Batch of training data
-        :type batch: TrainBatch
-        :raises NotImplementedError: Must be implemented by derived class
-        :return: List of losses, one for each loss function
-        :rtype: List[torch.Tensor]
-        """
+    def _compute_losses(self, batch: TrainBatch, output: LightningIROutput) -> List[torch.Tensor]:
+        """Computes the losses for a training batch."""
        raise NotImplementedError

    def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:

@@ -190,7 +183,7 @@ def training_step(self, batch: TrainBatch, batch_idx: int) -> torch.Tensor:
        if self.loss_functions is None:
            raise ValueError("Loss functions are not set")
        output = self.forward(batch)
-        losses = self.compute_losses(batch, output)
+        losses = self._compute_losses(batch, output)
        total_loss = torch.tensor(0)
        assert len(losses) == len(self.loss_functions)
        for (loss_function, loss_weight), loss in zip(self.loss_functions, losses):

@@ -205,7 +198,7 @@ def validation_step(
        """Handles the validation step for the model.

        :param batch: Batch of validation or testing data
-        :type batch: TrainBatch | RankBatch
+        :type batch: TrainBatch | RankBatch | SearchBatch
        :param batch_idx: Index of the batch
        :type batch_idx: int
        :param dataloader_idx: Index of the dataloader, defaults to 0
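
The behavioural contract of training_step is unchanged by the rename: each loss returned by the now-private _compute_losses is paired with its loss function's weight and summed. In isolation, that aggregation reduces to the following pattern (a self-contained illustration with made-up numbers, not the module's actual code):

    import torch

    # stand-ins for the losses returned by _compute_losses and their configured weights
    losses = [torch.tensor(0.8), torch.tensor(0.2)]
    loss_weights = [1.0, 0.5]

    total_loss = torch.tensor(0.0)
    for loss_weight, loss in zip(loss_weights, losses):
        total_loss = total_loss + loss_weight * loss  # weighted sum over all loss functions

    print(total_loss)  # tensor(0.9000)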

lightning_ir/bi_encoder/__init__.py (+5 -0)

@@ -1,3 +1,8 @@
+"""Base module for bi-encoder models.
+
+This module provides the main classes and functions for bi-encoder models, including configurations, models,
+modules, and tokenizers."""
+
from .config import BiEncoderConfig
from .model import BiEncoderEmbedding, BiEncoderModel, BiEncoderOutput, ScoringFunction
from .module import BiEncoderModule

lightning_ir/bi_encoder/config.py (+38 -2)

@@ -1,3 +1,9 @@
+"""
+Configuration module for bi-encoder models.
+
+This module defines the configuration class used to instantiate bi-encoder models.
+"""
+
import json
import os
from os import PathLike

@@ -109,22 +115,52 @@ def __init__(
        self.projection = projection

    def to_dict(self) -> Dict[str, Any]:
+        """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments, the backbone
+        model type, and remove the mask scoring tokens.
+
+        .. _transformers.PretrainedConfig.to_dict: \
+            https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict
+
+        :return: Configuration dictionary
+        :rtype: Dict[str, Any]
+        """
+
        output = super().to_dict()
        if "query_mask_scoring_tokens" in output:
            output.pop("query_mask_scoring_tokens")
        if "doc_mask_scoring_tokens" in output:
            output.pop("doc_mask_scoring_tokens")
        return output

-    def save_pretrained(self, save_directory: str | PathLike, push_to_hub: bool = False, **kwargs):
+    def save_pretrained(self, save_directory: str | PathLike, **kwargs) -> None:
+        """Overrides the transformers.PretrainedConfig.save_pretrained_ method to additionally save the tokens which
+        should be masked during scoring.
+
+        .. _transformers.PretrainedConfig.save_pretrained: \
+            https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.save_pretrained
+
+        :param save_directory: Directory to save the configuration
+        :type save_directory: str | PathLike
+        """
        with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
            json.dump({"query": self.query_mask_scoring_tokens, "doc": self.doc_mask_scoring_tokens}, f)
-        return super().save_pretrained(save_directory, push_to_hub, **kwargs)
+        return super().save_pretrained(save_directory, **kwargs)

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: str | PathLike, **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        """Overrides the transformers.PretrainedConfig.get_config_dict_ method to load the tokens that should be masked
+        during scoring.
+
+        .. _transformers.PretrainedConfig.get_config_dict: \
+            https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.get_config_dict
+
+        :param pretrained_model_name_or_path: Name or path of the pretrained model
+        :type pretrained_model_name_or_path: str | PathLike
+        :return: Configuration dictionary and additional keyword arguments
+        :rtype: Tuple[Dict[str, Any], Dict[str, Any]]
+        """
        config_dict, kwargs = super().get_config_dict(pretrained_model_name_or_path, **kwargs)
        mask_scoring_tokens = None
        mask_scoring_tokens_path = os.path.join(pretrained_model_name_or_path, "mask_scoring_tokens.json")
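
The three overridden methods all revolve around one side-car file, mask_scoring_tokens.json, written next to the configuration on save and read back on load. Isolated from the config class, the round trip looks like this (the token values below are made up; the real logic runs inside BiEncoderConfig.save_pretrained and BiEncoderConfig.get_config_dict as shown in the hunk above):

    import json
    import os
    import tempfile

    # stand-ins for BiEncoderConfig.query_mask_scoring_tokens / doc_mask_scoring_tokens
    query_mask_scoring_tokens = ["[MASK]"]
    doc_mask_scoring_tokens = None

    save_directory = tempfile.mkdtemp()

    # what save_pretrained additionally writes next to the serialized configuration
    with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
        json.dump({"query": query_mask_scoring_tokens, "doc": doc_mask_scoring_tokens}, f)

    # what get_config_dict reads back when loading from that directory
    mask_scoring_tokens = None
    mask_scoring_tokens_path = os.path.join(save_directory, "mask_scoring_tokens.json")
    if os.path.exists(mask_scoring_tokens_path):
        with open(mask_scoring_tokens_path) as f:
            mask_scoring_tokens = json.load(f)

    print(mask_scoring_tokens)  # {'query': ['[MASK]'], 'doc': None}

Since the to_dict override drops query_mask_scoring_tokens and doc_mask_scoring_tokens from the configuration dictionary, this side-car file appears to be how the token lists survive a save/load cycle.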

0 commit comments
