ChatGLM4 model: transformers version issue partially resolved #52

@3186218763

Description

`ChatGLMForConditionalGeneration` additionally needs to inherit from transformers' `GenerationMixin`:

```python
from typing import Any, Dict

from transformers import GenerationMixin
from transformers.utils import ModelOutput


class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel, GenerationMixin):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length  # maximum sequence length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)  # the ChatGLMModel backbone
        self.config = config

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # Update past_key_values (new return shape: a (cache_name, past_key_values) tuple)
        _, model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs
        )
```
The `_extract_past_from_model_output` method differs from before: drop the `standardize_cache_format` argument when calling it (higher transformers versions no longer have this parameter), and the return value is now a `(cache_name, past_key_values)` tuple.

Change

```python
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
    outputs, standardize_cache_format=standardize_cache_format
)
```

to

```python
_, model_kwargs["past_key_values"] = self._extract_past_from_model_output(
    outputs
)
```

and it works.
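If the modeling code has to run on both old and new transformers versions, one option is to branch on the method's signature instead of hard-coding either call. A minimal sketch, assuming the two signatures described above; the helper name `extract_past_compat` is hypothetical, not part of the repo:

```python
import inspect


def extract_past_compat(model, outputs, standardize_cache_format=False):
    """Call model._extract_past_from_model_output across transformers versions."""
    fn = model._extract_past_from_model_output
    if "standardize_cache_format" in inspect.signature(fn).parameters:
        # Older transformers: accepts the flag and returns past_key_values directly.
        return fn(outputs, standardize_cache_format=standardize_cache_format)
    # Newer transformers: no flag; returns a (cache_name, past_key_values) tuple.
    _, past_key_values = fn(outputs)
    return past_key_values
```

`_update_model_kwargs_for_generation` can then assign `model_kwargs["past_key_values"] = extract_past_compat(self, outputs)` and work under either version.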
