Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

运行example/basic_language_model_gpt2_ml.py生成时报错ValueError: Error when checking model input #496

Open
nameless0704 opened this issue Nov 15, 2022 · 1 comment

Comments

@nameless0704
Copy link

nameless0704 commented Nov 15, 2022

提问时请尽可能提供如下信息:

基本信息

  • 你使用的操作系统: Win10
  • 你使用的Python版本: 3.7.9
  • 你使用的Tensorflow版本: 1.14.0
  • 你使用的Keras版本: 2.3.1
  • 你使用的bert4keras版本: 0.11.4
  • 你使用纯keras还是tf.keras: 纯keras
  • 你加载的预训练模型: roberta_zh_L-6-H-768_A-12,来自https://github.com/brightmart/roberta_zh

核心代码

#使用basic_language_model_gpt2_ml.py原文,仅model的model参数改为‘roberta’
class ArticleCompletion(AutoRegressiveDecoder):
    """基于随机采样的文章续写
    """
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        token_ids = np.concatenate([inputs[0], output_ids], 1)
        return self.last_token(model).predict(token_ids)

    def generate(self, text, n=1, topp=0.95):
        token_ids, _ = tokenizer.encode(text)
        results = self.random_sample([token_ids], n, topp=topp)  # 基于随机采样
        return [text + tokenizer.decode(ids) for ids in results]

article_completion = ArticleCompletion(
    start_id=None,
    end_id=511,  # 511是中文句号
    maxlen=256,
    minlen=128
)
print(article_completion.generate(u'今天天气不错'))

输出信息

ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 2 array(s), but instead got the following list of 1 arrays: [array([[ 791, 1921, 1921, 3698,  679, 7231]])]

自我尝试

看到了 Issue #446 里面写tf2.0有问题,但是我降到了1.15或者1.14试了都还是报错了,所以求救,谢谢。

@bojone
Copy link
Owner

bojone commented Nov 27, 2022

报错的原因就是roberta不能用来替换gpt,所以【仅model的model参数改为‘roberta’】就是错误原因所在。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants