Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

학습 모델 로드시 질문 #4

Open
robinsongh381 opened this issue Jan 14, 2020 · 1 comment
Open

학습 모델 로드시 질문 #4

robinsongh381 opened this issue Jan 14, 2020 · 1 comment

Comments

@robinsongh381
Copy link

안녕하세요

좋은 자료 공유해주셔서 감사의 말씀을 우선 정합니다.

Inference.py 에서

    convert_keys = {}
    for k, v in checkpoint['model_state_dict'].items():
        new_key_name = k.replace("module.", '')
        if new_key_name not in model_dict:
            print("{} is not int model_dict".format(new_key_name))
            continue
        convert_keys[new_key_name] = v

다음과 같이 convert_keys 를 정의하고 model.load_state_dict(convert_keys)를 하셨는데,
왜 바로 model.load_state_dict(checkpoint['model_state_dict']) 하시지 않았는지 아니면 하면 안되는지 궁금하여 질문을 드립니다

감사합니다

@dave-rtzr
Copy link

아마 모델을 분산학습 시키셔서 모든 weight들의 이름이 module.~ 이런식으로 되어 있을 겁니다.
따라서 그냥 model.load_state_dict(checkpoint['model_state_dict'])를 하면 key error가 발생합니다.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants