The `ChatGLMForConditionalGeneration` class needs to additionally inherit from transformers' `GenerationMixin`:
```python
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel, GenerationMixin):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length  # maximum sequence length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)  # use the ChatGLMModel class
        self.config = config

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        _, model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs
        )
        # ... (the rest of the original method is unchanged)
```
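For context, recent transformers releases no longer expose `.generate()` through `PreTrainedModel` alone, which is why the explicit `GenerationMixin` base is needed. A minimal smoke test for the patched class, assuming a local ChatGLM checkpoint at a hypothetical path `./chatglm3-6b`:

```python
from transformers import AutoModel, AutoTokenizer

# Hypothetical local path to the patched ChatGLM checkpoint; adjust to your setup.
model_path = "./chatglm3-6b"

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).eval()

# .generate() comes from GenerationMixin; without the extra base class,
# newer transformers versions no longer provide it on PreTrainedModel.
inputs = tokenizer("你好", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```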
The `_extract_past_from_model_output` method also differs from before: drop `standardize_cache_format` from the call, since newer transformers versions no longer accept this parameter, and the return value is now a `(cache_name, past_key_values)` tuple. In other words, change

```python
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
    outputs, standardize_cache_format=standardize_cache_format
)
```

to

```python
_, model_kwargs["past_key_values"] = self._extract_past_from_model_output(
    outputs
)
```

and it works.
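If the same modeling file has to run on both old and new transformers, one option is to branch on the helper's signature. This is only a sketch, assuming `_extract_past_from_model_output` still exists in the installed release (the helper has been refactored upstream more than once):

```python
# inside ChatGLMForConditionalGeneration
import inspect
from typing import Any, Dict

from transformers.modeling_outputs import ModelOutput


def _update_model_kwargs_for_generation(
    self,
    outputs: ModelOutput,
    model_kwargs: Dict[str, Any],
    is_encoder_decoder: bool = False,
    standardize_cache_format: bool = False,
) -> Dict[str, Any]:
    extract = self._extract_past_from_model_output
    # Older transformers: extract(outputs, standardize_cache_format=...) -> past_key_values
    # Newer transformers: extract(outputs) -> (cache_name, past_key_values)
    if "standardize_cache_format" in inspect.signature(extract).parameters:
        model_kwargs["past_key_values"] = extract(
            outputs, standardize_cache_format=standardize_cache_format
        )
    else:
        _, model_kwargs["past_key_values"] = extract(outputs)
    # ... (attention_mask / position_ids updates from the original method go here)
    return model_kwargs
```

Branching on `inspect.signature` avoids pinning to a single transformers version, at the cost of relying on a private helper that may disappear entirely in a future release.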