@@ -1,3 +1,9 @@
+"""
+Configuration module for bi-encoder models.
+
+This module defines the configuration class used to instantiate bi-encoder models.
+"""
+
 import json
 import os
 from os import PathLike
@@ -109,22 +115,52 @@ def __init__(
         self.projection = projection

     def to_dict(self) -> Dict[str, Any]:
| 118 | + """Overrides the transformers.PretrainedConfig.to_dict_ method to include the added arguments, the backbone |
| 119 | + model type, and remove the mask scoring tokens. |
| 120 | +
|
| 121 | + .. _transformers.PretrainedConfig.to_dict: \ |
| 122 | +https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict |
| 123 | +
|
| 124 | + :return: Configuration dictionary |
| 125 | + :rtype: Dict[str, Any] |
| 126 | + """ |
| 127 | + |
         output = super().to_dict()
         if "query_mask_scoring_tokens" in output:
             output.pop("query_mask_scoring_tokens")
         if "doc_mask_scoring_tokens" in output:
             output.pop("doc_mask_scoring_tokens")
         return output

-    def save_pretrained(self, save_directory: str | PathLike, push_to_hub: bool = False, **kwargs):
+    def save_pretrained(self, save_directory: str | PathLike, **kwargs) -> None:
| 136 | + """Overrides the transformers.PretrainedConfig.save_pretrained_ method to addtionally save the tokens which |
| 137 | + should be maksed during scoring. |
| 138 | +
|
| 139 | + .. _transformers.PretrainedConfig.save_pretrained: \ |
| 140 | +https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.save_pretrained |
| 141 | +
|
| 142 | + :param save_directory: Directory to save the configuration |
| 143 | + :type save_directory: str | PathLike |
| 144 | + """ |
         with open(os.path.join(save_directory, "mask_scoring_tokens.json"), "w") as f:
             json.dump({"query": self.query_mask_scoring_tokens, "doc": self.doc_mask_scoring_tokens}, f)
-        return super().save_pretrained(save_directory, push_to_hub, **kwargs)
+        return super().save_pretrained(save_directory, **kwargs)

     @classmethod
     def get_config_dict(
         cls, pretrained_model_name_or_path: str | PathLike, **kwargs
     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
| 153 | + """Overrides the transformers.PretrainedConfig.get_config_dict_ method to load the tokens that should be masked |
| 154 | + during scoring. |
| 155 | +
|
| 156 | + .. _transformers.PretrainedConfig.get_config_dict: \ |
| 157 | +https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.get_config_dict |
| 158 | +
|
| 159 | + :param pretrained_model_name_or_path: Name or path of the pretrained model |
| 160 | + :type pretrained_model_name_or_path: str | PathLike |
| 161 | + :return: Configuration dictionary and additional keyword arguments |
| 162 | + :rtype: Tuple[Dict[str, Any], Dict[str, Any]] |
| 163 | + """ |
         config_dict, kwargs = super().get_config_dict(pretrained_model_name_or_path, **kwargs)
         mask_scoring_tokens = None
         mask_scoring_tokens_path = os.path.join(pretrained_model_name_or_path, "mask_scoring_tokens.json")
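Taken together, the three overrides keep the mask scoring token lists out of config.json (to_dict pops them) and instead persist them in a separate mask_scoring_tokens.json file next to it. Below is a minimal round-trip sketch of how the saved artifacts relate; "BiEncoderConfig" stands in for the class defined in this file and the example token lists are made up, since neither the class name nor the token format is visible in this diff:

```python
# Hypothetical usage sketch; "BiEncoderConfig" and the token lists are
# placeholders, as the diff shows neither the class name nor the token format.
import os

from bi_encoder_config import BiEncoderConfig  # hypothetical import

config = BiEncoderConfig(
    query_mask_scoring_tokens=["[CLS]", "[SEP]"],
    doc_mask_scoring_tokens=["[SEP]"],
)

# to_dict() pops both token lists, so they never end up in config.json.
assert "query_mask_scoring_tokens" not in config.to_dict()

# save_pretrained() writes mask_scoring_tokens.json before calling the
# superclass method (which normally creates the directory), so the target
# directory has to exist up front.
os.makedirs("my-bi-encoder", exist_ok=True)
config.save_pretrained("my-bi-encoder")

# get_config_dict() loads config.json via the superclass and then reads the
# token lists back from mask_scoring_tokens.json in the same directory.
config_dict, unused_kwargs = BiEncoderConfig.get_config_dict("my-bi-encoder")
```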