
During pre-training, using FA2 consumes more memory than using SDPA #172

Open
neavo opened this issue Jan 6, 2025 · 7 comments

Comments

neavo commented Jan 6, 2025

As described in the title:

When performing pre-training, using FA2 consumes more GPU memory than using SDPA.

I am using the Trainer from transformers for training; the simplified code is roughly as follows:

model = AutoModelForMaskedLM.from_pretrained(
    MODEL_PATH,
    torch_dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16,
    attn_implementation = "flash_attention_2" # or "sdpa"
).to("cuda" if torch.cuda.is_available() else "cpu")

training_args = TrainingArguments(
    bf16 = True,
    optim = "paged_adamw_8bit",
    warmup_ratio = 0.1,
    weight_decay = 5e-5,
    learning_rate = 5e-5,
    num_train_epochs = 1,
    per_device_eval_batch_size = 16,
    per_device_train_batch_size = 8,
    gradient_checkpointing = False,
)

trainer = Trainer(
    args = training_args,
    model = model,
    data_collator = DataCollatorForLanguageModeling(
        tokenizer = tokenizer,
        mlm = True,
        mlm_probability = 0.30,
        pad_to_multiple_of = 8,
    ),
    eval_dataset = eval_dataset,
    train_dataset = train_dataset,
    processing_class = tokenizer,
)

With all other parameters kept identical and only attn_implementation changed, GPU memory usage is 48% with SDPA and 88% with FA2.

With FA2, GPU memory usage is significantly higher than with SDPA, and also much higher than with other traditional BERT-like models, with no improvement in speed.

The same behavior occurs on both Windows 11 24H2 and Ubuntu under WSL2.
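
A minimal sketch of how the peak memory comparison could be measured programmatically (run_one_epoch is a hypothetical helper that rebuilds the model with the given attn_implementation and runs trainer.train() as in the snippet above; it is not part of my script):

import torch

for impl in ("sdpa", "flash_attention_2"):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    run_one_epoch(attn_implementation = impl)  # hypothetical helper, see note above
    print(f"{impl}: peak allocated {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GiB")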

ENVS:
PyTorch 2.5.1
Python 3.12.8
flash_attn v2.7.2.post1

[Screenshots: GPU memory usage with FA2 and with SDPA]

neavo (Author) commented Jan 7, 2025

So is this a problem with my workflow or a bug?

staghado (Collaborator) commented Jan 7, 2025

In general I think the FA2 support on Windows is not well tested (https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features). We only ever used Linux machines for the pre-training part.

neavo (Author) commented Jan 7, 2025

In general I think the FA2 support on Windows is not well tested (https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features). We only ever used Linux machines for the pre-training part.

Yes, that's possible.
But as described in the original post, I got the same results when running the training script in a WSL2 environment.
Since FA2 works correctly when training other models, I suspect this might be a bug specific to the HF Trainer or a configuration error on my side. I wonder if anyone else can verify this.

NohTow (Collaborator) commented Jan 7, 2025

IIRC, compilation does not work properly on Windows/WSL, but this should not cause such a gap and should affect both paths.
Could you try verifying that using attn_implementation = "flash_attention_2" indeed uses the FA2 path?

As raised by @staghado, we mainly use Linux machines (and WSL is not totally equivalent to Linux machines), so it is a bit hard for us to debug.

Since FA2 works correctly when training other models,

Could you give more information about which setup you are referring to? I believe Jina is the only encoder with FA.
Also, could you share the full boilerplate? If I recall correctly, @tomaarsen uses WSL sometimes, so maybe he can try running it and see if he experiences the same behavior.

neavo (Author) commented Jan 7, 2025

Could you try verifying that using attn_implementation = "flash_attention_2" indeed uses the FA2 path?

Yes, the numbers in the original post are the results after specifying attn_implementation = "flash_attention_2".
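
(A quick way to double-check which implementation was actually resolved, assuming a recent transformers version that records it on the model config:)

# Should print "flash_attention_2" when the FA2 path is active, "sdpa" otherwise.
print(model.config._attn_implementation)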

Could you give more information about which setup you are referring to? I believe Jina is the only encoder with FA.

I may not have expressed this clearly: when training other types of models (such as Qwen-2.5) in the same dependency environment, FA2 does work normally, and a significant reduction in memory usage is observable. However, I have not tried FA2 when training other BERT-like models.

Also, could you share the full boilerplate? If I recall correctly, @tomaarsen uses WSL sometimes, so maybe he can try running it and see if he experiences the same behavior.

This is the most simplified script, with the data logic removed, that reproduces the issue described above.
Place the plain-text corpus file sample.txt and the model folder modern_bert in the same directory, then run python sample.py.

.
├── modern_bert
│   ├── config.json
│   ├── model.safetensors
│   ├── special_tokens_map.json
│   ├── tokenizer.json
│   └── tokenizer_config.json
├── sample.py
└── sample.txt

Switch between FA2 and SDPA by modifying the constant ATTN_IMPLEMENTATION at the top of the script; a rough outline of the script is sketched below.
The script contains some Chinese comments and log messages, but they should not have any actual impact :)
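
For reference, a rough outline of what sample.py does; this is a hypothetical sketch rather than the attached script, and the dataset handling and max_length are illustrative:

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

ATTN_IMPLEMENTATION = "flash_attention_2"  # or "sdpa"
MODEL_PATH = "./modern_bert"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForMaskedLM.from_pretrained(
    MODEL_PATH,
    torch_dtype = torch.bfloat16,
    attn_implementation = ATTN_IMPLEMENTATION,
).to("cuda")

# Build an MLM dataset from the plain-text corpus
dataset = load_dataset("text", data_files = "sample.txt")["train"]
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation = True, max_length = 512),
    batched = True,
    remove_columns = ["text"],
)

trainer = Trainer(
    model = model,
    args = TrainingArguments(
        output_dir = "output",
        bf16 = True,
        num_train_epochs = 1,
        per_device_train_batch_size = 8,
    ),
    data_collator = DataCollatorForLanguageModeling(
        tokenizer = tokenizer,
        mlm = True,
        mlm_probability = 0.30,
        pad_to_multiple_of = 8,
    ),
    train_dataset = dataset,
    processing_class = tokenizer,
)
trainer.train()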

neavo (Author) commented Jan 10, 2025

IIRC, compilation does not work properly on Windows/WSL, but this should not cause such a gap and should affect both paths. Could you try verifying that using attn_implementation = "flash_attention_2" indeed uses the FA2 path?

As raised by @staghado, we mainly use Linux machines (and WSL is not totally equivalent to Linux machines), so it is a bit hard for us to debug.

Since FA2 works correctly when training other models,

Could you give more information about which setup you are referring to? I believe Jina is the only encoder with FA. Also, could you share the full boilerplate? If I recall correctly, @tomaarsen uses WSL sometimes, so maybe he can try running it and see if he experiences the same behavior.

There is a new finding:
With the same weights and system environment, when training a downstream task (NER), FA2 shows a significant speed increase and memory reduction.
Compared with SDPA, speed roughly doubles (+100%) and memory usage drops by about half (-50%); I think this is FA2 working as expected.
My guess is that some setting or step that differs between the MLM and TokenClassification tasks is the cause of this difference. @NohTow

neavo (Author) commented Jan 13, 2025

I still can't identify the root cause of the issue, but I found a "silly" workaround.
I observed that when FA2 is enabled and causes the abnormal GPU memory usage, the memory is not maxed out immediately on the first step.
Instead, there are several abnormal memory spikes over the following few steps.
I used the following code (a Trainer callback) to manually clear the GPU cache, and it works: after clearing the cache a few times at the start of training, the subsequent memory usage stabilizes.

import os

import torch
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

class ClearMemoryCallback(TrainerCallback):

    def clear_memory(self, threshold: float) -> None:
        # Query total / reserved / used GPU memory (in MiB) from nvidia-smi
        result = os.popen("nvidia-smi --query-gpu=memory.total,memory.reserved,memory.used --format=csv,noheader,nounits").readlines()
        total, reserved, used = (int(v) for v in result[0].strip().split(", "))

        # Release cached blocks back to the driver once usage crosses the threshold
        if (reserved + used) / total > threshold:
            torch.cuda.empty_cache()

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        self.clear_memory(0.92)
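
Assuming the snippet is wrapped in a TrainerCallback subclass as above (here called ClearMemoryCallback), it is registered on the existing Trainer like this:

trainer.add_callback(ClearMemoryCallback())
# or, equivalently, pass callbacks = [ClearMemoryCallback()] when constructing the Trainer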

Overall, I believe the abnormal GPU memory usage on the FA2 path is real, but with the above method the training process can now proceed relatively normally.
I can continue to provide information to help you fully resolve the issue.
Alternatively, if you consider this approach sufficient and no further follow-up is needed, feel free to close this issue.
