During pre-training, using FA2 consumes more memory than using SDPA #172
Comments
So is this a problem with my workflow or a bug?
In general, I think the FA2 support on Windows is not well tested (https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features). We only ever used Linux machines for the pre-training part.
Yes, it's possible.
IIRC, compilation does not work properly on Windows/WSL, but this should not cause such a gap and should affect both paths. As raised by @staghado, we mainly use Linux machines (and WSL is not totally equivalent to a Linux machine), so it is a bit hard for us to debug.
Could you give more information about which setup you are referring to? I believe Jina is the only encoder with FA.
Yes, the data in the main text represents the results after specifying
I may not have expressed it clearly: when training other types of models (such as Qwen-2.5) in the same dependency environment, FA2 does work normally and gives a clear reduction in memory usage. However, I have not tried FA2 when training other BERT-like models.
Below is the most simplified script, with the data logic removed, that reproduces the issue described earlier.
You can switch between FA2 and SDPA by changing the constant ATTN_IMPLEMENTATION at the top of the script.
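The full script is not reproduced here; a minimal sketch of the toggle described above, assuming the model is loaded through transformers with a masked-LM head (the checkpoint name is a placeholder, not taken from the report), might look like:

```python
import torch
from transformers import AutoModelForMaskedLM

# Illustrative toggle at the top of the repro script: switching this constant is the
# only difference between the FA2 and SDPA runs.
ATTN_IMPLEMENTATION = "flash_attention_2"  # or "sdpa"

model = AutoModelForMaskedLM.from_pretrained(
    "answerdotai/ModernBERT-base",            # placeholder checkpoint
    attn_implementation=ATTN_IMPLEMENTATION,
    torch_dtype=torch.bfloat16,               # FA2 requires fp16/bf16
)
print(model.config._attn_implementation)      # confirm which backend is active
```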
There is a new finding:
Still unable to identify the root cause of the issue, but a "silly" workaround has been found:

```python
import os

import torch
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class ClearCacheCallback(TrainerCallback):
    # Note: the class wrapper, its name, and the imports are reconstructed here; the
    # original comment only showed the two methods below.
    def clear_memory(self, threshold: float) -> None:
        # Query total / reserved / used memory (in MiB) of the first GPU via nvidia-smi.
        result = os.popen(
            "nvidia-smi --query-gpu=memory.total,memory.reserved,memory.used --format=csv,noheader,nounits"
        ).readlines()
        result = result[0].strip().split(", ")
        total = int(result[0])
        used = int(result[1]) + int(result[2])
        # Release PyTorch's cached-but-unused blocks once usage crosses the threshold.
        if used / total > threshold:
            torch.cuda.empty_cache()

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        self.clear_memory(0.92)
```

Overall, I believe the abnormal GPU memory usage under the FA2 path does exist, but with the method above, the training process can now proceed relatively normally.
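For completeness, a callback like this would be registered through the Trainer's callbacks argument; the sketch below is hypothetical wiring, where model, training_args, dataset, and collator are assumed to come from the existing training script:

```python
from transformers import Trainer

trainer = Trainer(
    model=model,                       # objects assumed from the existing setup
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
    callbacks=[ClearCacheCallback()],  # the callback sketched above
)
trainer.train()
```

Note that torch.cuda.empty_cache() only hands PyTorch's cached-but-unused blocks back to the driver, so this treats the symptom (allocator growth as reported by nvidia-smi) rather than whatever the FA2 path is actually allocating.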
As described in the title: when performing pre-training, using FA2 consumes more GPU memory than using SDPA.
I am using the Trainer from transformers for training, and the simplified code is roughly as follows:
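The code block itself is not included here; a minimal sketch of such a setup, assuming a masked-LM pre-training run driven by the transformers Trainer (the checkpoint, dataset, and hyperparameters below are illustrative, not taken from the report), might look like this:

```python
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Only this constant differs between the two runs being compared.
ATTN_IMPLEMENTATION = "flash_attention_2"  # or "sdpa"
MODEL_NAME = "answerdotai/ModernBERT-base"  # illustrative checkpoint

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(
    MODEL_NAME,
    attn_implementation=ATTN_IMPLEMENTATION,
    torch_dtype=torch.bfloat16,  # FA2 requires fp16/bf16
)

# Placeholder corpus standing in for the real pre-training data.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
dataset = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=512),
    batched=True,
    remove_columns=dataset.column_names,
)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

training_args = TrainingArguments(
    output_dir="mlm-pretrain",
    per_device_train_batch_size=16,
    num_train_epochs=1,
    bf16=True,
    logging_steps=50,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
)
trainer.train()
```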
When all other parameters are kept consistent and only the `attn_implementation` is changed, the GPU memory usage is 48% with SDPA and 88% with FA2. With FA2, memory usage is significantly higher than with SDPA, and also much higher than with other traditional BERT-like models, with no improvement in speed.
The same phenomenon has been observed on both Windows 11 24H2 and Ubuntu@WSL2.
ENVS:
PyTorch 2.5.1
Python 3.12.8
flash_attn v2.7.2.post1