This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Deprecate convert_to_singleton #691

Open
andrewPoulton opened this issue Mar 24, 2023 · 13 comments
Labels
enhancement New feature or request


@andrewPoulton
Contributor

As noted in #689, convert_to_singleton doesn't produce state dicts with compatible keys (for some unknown reason).

Since reshard_mp can do the same job, without the GPU node requirement of convert_to_singleton, we should deprecate convert_to_singleton.

TODO: Work out dependencies on convert_to_singleton, and identify any special cases it can handle that reshard_mp can't (such as separating out qkv weights, as noted by @tangbinh).

@andrewPoulton andrewPoulton added the enhancement New feature or request label Mar 24, 2023
@andrewPoulton andrewPoulton self-assigned this Mar 24, 2023
@larekrow

larekrow commented Mar 31, 2023

reshard_mp.py --num-output-parts 1 currently does not work with the OPT weights. Please see #695.

@ayeeyecorp

Since reshard_mp can do the same job, without the GPU node requirement of convert_to_singleton, we should deprecate convert_to_singleton.

@andrewPoulton Is this true though? I was unable to convert 8 shards successfully to restored.pt using:

python -m metaseq.scripts.reshard_mp \
--input "opt/shards/reshard-model_part-*.pt" \
--output "opt/pt/reshard_no_os_mp8/reshard-model_part-{i}.pt" \
--num-output-parts 1

@tangbinh
Contributor

tangbinh commented Apr 5, 2023

I was unable to convert 8 shards successfully to restored.pt using:

@ayeeyecorp Can you share the stack trace? I suspect it might be related to the fact that the checkpoints available on the OPT page are flattened, which makes them incompatible with reshard_mp.

@andrewPoulton
Contributor Author

@tangbinh Let's add a flat-param check to reshard_* and raise an error unless the user specifically asks to unflatten. I'll create an issue to track this shortly, and I'm happy to own it as well.
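One way the proposed flat-param check could look, as a minimal sketch: it flags a state dict as flattened when any key's last component starts with an FSDP flat parameter name such as flat_param_0 (the pattern visible in the stack trace later in this thread). The helper names is_flattened and check_unflattened, and the assumption that weights sit under a "model" key, are hypothetical and not part of metaseq.

```python
from typing import Any, Mapping

# Prefix of the flattened FSDP parameter names seen in this thread, e.g.
# "_fpw_module.decoder.layers.0._fsdp_wrapped_module.flat_param_0".
FLAT_PARAM_PREFIX = "flat_param_"


def is_flattened(state_dict: Mapping[str, Any]) -> bool:
    """Return True if any parameter name looks like an FSDP flat parameter."""
    return any(
        key.rsplit(".", 1)[-1].startswith(FLAT_PARAM_PREFIX) for key in state_dict
    )


def check_unflattened(checkpoint: Mapping[str, Any], allow_flat: bool = False) -> None:
    """Raise unless the checkpoint's weights are unflattened or the caller opts in."""
    # Assumption: model weights live under a "model" key; fall back to the dict itself.
    state_dict = checkpoint.get("model", checkpoint)
    if is_flattened(state_dict) and not allow_flat:
        raise ValueError(
            "Checkpoint contains flattened FSDP parameters; unflatten it first "
            "(e.g. reshard_fsdp with --unflatten-weights True)."
        )
```

In a resharding script this would run right after loading each input file, before any key-based processing.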

@tangbinh
Contributor

tangbinh commented Apr 5, 2023

@andrewPoulton I was adding an option to split the KVQ weights in reshard_mp, but I think this is probably not needed, for two reasons:

  1. This weight splitting has already been included in the script convert_opt_original_pytorch_checkpoint_to_pytorch.py. Previously, there was a bug that basically turned this off, but it has been fixed (see Fix convert_opt_original_pytorch_checkpoint_to_pytorch.py typo huggingface/transformers#22526).
  2. Previously, we supported both transformer_lm and transformer_lm_megatron models, but the former has been removed in Unify transformer_lm_megatron and transformer_lm #633. Therefore, there's no need to split KVQ weights within Metaseq.

Once #625 is fixed, I think we can safely remove convert_to_singleton.py, since users will be able to load OPT checkpoints into Hugging Face Transformers using reshard_mp.py followed by convert_opt_original_pytorch_checkpoint_to_pytorch.py.
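As a rough illustration of the KVQ splitting discussed in this comment, the sketch below turns each fused qkv_proj entry into separate q_proj, k_proj and v_proj entries by cutting the first dimension into thirds. It assumes the fused rows are ordered K, V, Q (matching the "KVQ" naming used here) and an unsharded (MP=1) state dict; the helper names are hypothetical, and the Hugging Face conversion script remains the reference for the exact behavior.

```python
def split_kvq(fused):
    """Split a fused projection into (q, k, v) along its first dimension.

    Assumes rows are ordered K, V, Q. Works for anything that supports len()
    and slicing along dim 0, such as a torch.Tensor or a plain list of rows.
    """
    depth = len(fused)
    if depth % 3 != 0:
        raise ValueError(f"fused dimension {depth} is not divisible by 3")
    third = depth // 3
    k, v, q = fused[:third], fused[third:2 * third], fused[2 * third:]
    return q, k, v


def split_qkv_state_dict(state_dict):
    """Replace every '.qkv_proj.' entry with separate q/k/v projection entries."""
    out = {}
    for key, value in state_dict.items():
        if ".qkv_proj." in key:
            q, k, v = split_kvq(value)
            out[key.replace(".qkv_proj.", ".q_proj.")] = q
            out[key.replace(".qkv_proj.", ".k_proj.")] = k
            out[key.replace(".qkv_proj.", ".v_proj.")] = v
        else:
            # Non-attention parameters pass through unchanged.
            out[key] = value
    return out
```

Because it relies only on len() and first-dimension slicing, the same helper applies to both weights ([3h, h]) and biases ([3h]).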

@ayeeyecorp
Copy link

ayeeyecorp commented Apr 5, 2023

I was unable to convert 8 shards successfully to restored.pt using:

@ayeeyecorp Can you share the stack trace? I suspect it might be related to the fact that the checkpoints available on the OPT page are flattened, which makes them incompatible with reshard_mp.

@andrewPoulton - I did not save the stack trace from that particular test - can redo. However, here is the tail end snippet of the stack trace after running ./metaseq/metaseq/scripts/convert_to_singleton on OPT-175B checkpoints:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for FlattenParamsWrapper: Missing key(s) in state_dict: "_fpw_module.decoder.layers.0._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.1._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.2._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.3._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.4._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.5._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.6._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.7._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.8._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.9._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.10._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.11._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.12._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.13._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.14._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.15._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.16._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.17._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.18._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.19._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.20._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.21._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.22._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.23._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.24._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.25._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.26._fsdp_wrapped_module.flat_param_0", 
"_fpw_module.decoder.layers.27._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.28._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.29._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.30._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.31._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.32._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.33._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.34._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.35._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.36._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.37._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.38._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.39._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.40._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.41._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.42._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.43._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.44._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.45._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.46._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.47._fsdp_wrapped_module.flat_param_0". 
Unexpected key(s) in state_dict: "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.out_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.out_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc1.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc1.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc2.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc2.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.out_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.out_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc1.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc1.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc2.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc2.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", 
"_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.decoder.layers.2._fsdp_wrapped

The 992 shards were first converted to reshard-model_part-$j.pt using:

for j in {0..7}; do
    python3 -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "./checkpoint_last-model_part-$j-shard*.pt" \
    --output-shard-name "./reshard/reshard-model_part-$j.pt" \
    --num-output-shards 1 --skip-optimizer-state True --unflatten-weights True
done

Should I have set --unflatten-weights False in order for metaseq.scripts.convert_to_singleton and metaseq.scripts.reshard_mp to work correctly?

@andrewPoulton
Contributor Author

@ayeeyecorp Just so I'm clear - you first ran reshard_fsdp on the shards (with unflatten-weights=true), then tried running convert_to_singleton on the consolidated shards? If that's so, then can you try running reshard_mp on the consolidated shards instead?

@ayeeyecorp

@andrewPoulton

you first ran reshard_fsdp on the shards (with unflatten-weights=true), then tried running convert_to_singleton on the consolidated shards?

Correct, this resulted in the state_dict error.

If that's so, then can you try running reshard_mp on the consolidated shards instead?

Will do that again shortly and post stack trace results.

@tangbinh
Contributor

tangbinh commented Apr 5, 2023

@ayeeyecorp May I ask why you want to convert the 8 MP parts of OPT 175B into a singleton? I don't think you would be able to load the singleton into any GPU considering its size, which is about 350GB.

Should I have set --unflatten-weights False in order for metaseq.scripts.convert_to_singleton and metaseq.scripts.reshard_mp to work correctly?

convert_to_singleton expects flattened weights; that's probably why you got Missing key(s) in state_dict. However, reshard_mp expects unflattened weights. As suggested by @andrewPoulton, please try to use reshard_mp instead, as we're deprecating convert_to_singleton.
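To make the flattened/unflattened distinction concrete, here is a simplified sketch of the idea: flattening concatenates a module's parameters into one 1-D tensor (the flat_param_0 keys in the earlier stack trace), while unflattening splits it back using the recorded names and shapes. This assumes plain PyTorch and is only an illustration, not FairScale's FlattenParamsWrapper implementation; real FSDP checkpoints also carry extra sharding metadata on top of this.

```python
import torch


def flatten_params(params):
    """Concatenate named tensors into one 1-D tensor plus metadata to undo it."""
    names = list(params)
    shapes = [params[name].shape for name in names]
    flat = torch.cat([params[name].reshape(-1) for name in names])
    return flat, names, shapes


def unflatten_params(flat, names, shapes):
    """Split a flat 1-D tensor back into named tensors with their original shapes."""
    numels = [shape.numel() for shape in shapes]  # element count of each parameter
    pieces = torch.split(flat, numels)
    return {name: piece.reshape(shape) for name, piece, shape in zip(names, pieces, shapes)}
```

A loader expecting flat_param_0 keys cannot consume the per-parameter keys produced by the second function, and vice versa, which is the kind of Missing/Unexpected key error reported above.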

@ayeeyecorp

ayeeyecorp commented Apr 6, 2023

@andrewPoulton

I started over earlier today from the 992 shards (resetting my environment per the instructions here, using Python 3.8) and verified that the 8 consolidated FSDP shards had the correct md5sum. Once that was confirmed, I used the reshard_mp.py script to convert the checkpoints to a single model-parallel part (eliminating MP), and this time it ran without issues. The command was:

python -m metaseq.scripts.reshard_mp \
    --input "/path/to/resharded/checkpoints/reshard-model_part-*.pt" \
    --output "/path/to/mp/resharded/checkpoints/reshard-model_part-{i}.pt" \
    --num-output-parts 1
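For intuition about what merging model-parallel parts into one involves, here is a minimal sketch assuming a Megatron-style layout: column-parallel weights are concatenated along dim 0, row-parallel weights along dim 1, and replicated tensors (layer norms, row-parallel biases) are taken from the first part. The suffix lists, the per-rank K, V, Q layout of the fused qkv_proj, and the function name are assumptions for illustration, not reshard_mp's actual logic.

```python
import torch

# Hypothetical layout rules (assumptions, not taken from reshard_mp):
COLUMN_PARALLEL_SUFFIXES = ("fc1.weight", "fc1.bias", "embed_tokens.weight")
ROW_PARALLEL_SUFFIXES = ("out_proj.weight", "fc2.weight")
FUSED_QKV_SUFFIXES = ("qkv_proj.weight", "qkv_proj.bias")


def merge_model_parallel_parts(parts):
    """Merge a list of per-rank state dicts into a single (MP=1) state dict."""
    merged = {}
    for key in parts[0]:
        tensors = [part[key] for part in parts]
        if key.endswith(FUSED_QKV_SUFFIXES):
            # Each rank stores its own [k; v; q] block, so regroup per projection
            # before concatenating; a plain cat would interleave the ranks.
            chunks = [t.chunk(3, dim=0) for t in tensors]
            merged[key] = torch.cat(
                [torch.cat([c[i] for c in chunks], dim=0) for i in range(3)], dim=0
            )
        elif key.endswith(COLUMN_PARALLEL_SUFFIXES):
            merged[key] = torch.cat(tensors, dim=0)  # split over output features
        elif key.endswith(ROW_PARALLEL_SUFFIXES):
            merged[key] = torch.cat(tensors, dim=1)  # split over input features
        else:
            merged[key] = tensors[0]  # replicated on every rank
    return merged
```

The qkv regrouping is the subtle part: concatenating fused blocks directly would produce K0 V0 Q0 K1 V1 Q1 instead of K0 K1 V0 V1 Q0 Q1.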

Not sure what the original problem was. The md5sum of the single checkpoint (325.2 GB) was: 06e7e7ed424db3834ccd1a776d82ff14
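For anyone repeating this verification, a sketch of computing the MD5 of a multi-hundred-GB checkpoint in Python without reading it into memory at once; it should agree with the md5sum command-line tool. The function name and the 64 MiB chunk size are arbitrary choices.

```python
import hashlib


def md5sum(path, chunk_size=64 * 1024 * 1024):
    """Compute a file's MD5 hex digest by streaming it in fixed-size chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel keeps reading until f.read returns b"" at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

Usage would be md5sum("/path/to/mp/resharded/checkpoints/reshard-model_part-0.pt"), compared against the expected digest.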

The subsequent step, converting to the Hugging Face format with:

python3 -m transformers.models.opt.convert_opt_original_pytorch_checkpoint_to_pytorch --pytorch_dump_folder_path ~/opt_meta/hugging/ --hf_config config.json --fairseq_path ~/opt_meta/single_shard/reshard-model_part-0.pt

failed after 1+ hour with the following stack trace:

       size mismatch for decoder.layers.8.final_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.8.final_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn.k_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.9.self_attn.k_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn.v_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.9.self_attn.v_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn.q_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.9.self_attn.q_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn.out_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.9.self_attn.out_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.self_attn_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.fc1.weight: copying a param with shape torch.Size([49152, 12288]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for decoder.layers.9.fc1.bias: copying a param with shape torch.Size([49152]) from checkpoint, the shape in current model is torch.Size([3072]).
        size mismatch for decoder.layers.9.fc2.weight: copying a param with shape torch.Size([12288, 49152]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for decoder.layers.9.fc2.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.final_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.9.final_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn.k_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.10.self_attn.k_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn.v_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.10.self_attn.v_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn.q_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.10.self_attn.q_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn.out_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.10.self_attn.out_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.self_attn_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.fc1.weight: copying a param with shape torch.Size([49152, 12288]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for decoder.layers.10.fc1.bias: copying a param with shape torch.Size([49152]) from checkpoint, the shape in current model is torch.Size([3072]).
        size mismatch for decoder.layers.10.fc2.weight: copying a param with shape torch.Size([12288, 49152]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for decoder.layers.10.fc2.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.final_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.10.final_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn.k_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.11.self_attn.k_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn.v_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.11.self_attn.v_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn.q_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.11.self_attn.q_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn.out_proj.weight: copying a param with shape torch.Size([12288, 12288]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for decoder.layers.11.self_attn.out_proj.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.self_attn_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.fc1.weight: copying a param with shape torch.Size([49152, 12288]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for decoder.layers.11.fc1.bias: copying a param with shape torch.Size([49152]) from checkpoint, the shape in current model is torch.Size([3072]).
        size mismatch for decoder.layers.11.fc2.weight: copying a param with shape torch.Size([12288, 49152]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for decoder.layers.11.fc2.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.final_layer_norm.weight: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for decoder.layers.11.final_layer_norm.bias: copying a param with shape torch.Size([12288]) from checkpoint, the shape in current model is torch.Size([768]).

I followed @patrickvonplaten's conversion instructions (found here) and generated a config.json with the following:

from transformers import OPTConfig
num_layers = 12
num_heads = 12
d_model = 768
config = OPTConfig(hidden_size=d_model, num_attention_heads=num_heads, num_hidden_layers=num_layers, ffn_dim=4*d_model)
config.save_pretrained("./")  # <- this will create a `config.json` in your current folder

Thoughts on what could be going wrong with the HF conversion? I will re-run the operation overnight and log the full failure stack trace.

@tangbinh - thank you for the clarification. I am converting the 8 MP parts of OPT 175B into a singleton so that I can run quantization experiments on it.

@tangbinh
Contributor

tangbinh commented Apr 6, 2023

@ayeeyecorp For OPT 175B, we should have num_layers = 96, num_heads = 96, and d_model = 12288.
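A quick back-of-the-envelope check with these values, as a sketch: it counts parameters for an OPT-style decoder and converts the total to an fp16 size. The vocabulary size of 50272 and the 2050 position slots (2048 plus an offset of 2) are assumptions taken from the Hugging Face OPT defaults rather than from this thread. The per-layer shapes these values imply ([12288, 12288] per attention projection, [49152, 12288] for fc1) match the checkpoint side of the size-mismatch errors above, and the fp16 total lands close to the ~350GB estimate and the 325.2 GB file size mentioned in this thread (the latter if it is measured in GiB).

```python
def opt_decoder_param_count(num_layers, d_model, ffn_dim, vocab_size, max_positions):
    """Rough parameter count for an OPT-style decoder with tied input/output embeddings."""
    attn = 4 * d_model * d_model + 4 * d_model       # q, k, v, out_proj weights + biases
    ffn = 2 * d_model * ffn_dim + ffn_dim + d_model  # fc1 and fc2 weights + biases
    norms = 2 * 2 * d_model                          # two layer norms, weight + bias each
    per_layer = attn + ffn + norms
    embeddings = vocab_size * d_model + max_positions * d_model
    final_norm = 2 * d_model
    return num_layers * per_layer + embeddings + final_norm


# OPT-175B values from this comment; vocab and position sizes are assumptions.
total = opt_decoder_param_count(96, 12288, 4 * 12288, 50272, 2050)
fp16_bytes = 2 * total
print(total)                         # 174604468224
print(round(fp16_bytes / 2**30, 1))  # 325.2
```

With the 12-layer, 768-wide config used earlier, the same function gives a model roughly three orders of magnitude smaller, which is why every layer reported a shape mismatch.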

@ayeeyecorp

ayeeyecorp commented Apr 6, 2023

@ayeeyecorp For OPT 175B, we should have num_layers = 96, num_heads = 96, and d_model = 12288.

@tangbinh that was quick! brilliant, will give that a go now. I blindly used values from HF... thank you

@ayeeyecorp

ayeeyecorp commented Apr 6, 2023

After updating instance to 1TB+ of RAM... I successfully generated a .bin file using src.transformers.models.opt.convert_opt_original_pytorch_checkpoint_to_pytorch!

Thanks for the support.

Development

No branches or pull requests

4 participants