
Commit 92e4f64

wszczurekhabana authored and bhargaveede committed
Enable Flash Attention in recompute and causal modes (#21)
* Enable Flash Attention in recompute and causal modes
* Add flash_attention_causal_mask to generation utils
* Propagate Flash Attention causal_mask to finetuning example
* Modify README example and provide additional description
* Add flash_attention_causal_mask to FT README
1 parent 50c3d13 commit 92e4f64

File tree

9 files changed: +122 −5 lines


examples/language-modeling/README.md

Lines changed: 2 additions & 1 deletion
@@ -552,7 +552,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
     --lora_rank 4 \
     --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
     --validation_split_percentage 4 \
-    --use_flash_attention True
+    --use_flash_attention True \
+    --flash_attention_causal_mask True
 ```

 - Multi-card finetuning of Falcon-180B:

examples/language-modeling/run_lora_clm.py

Lines changed: 10 additions & 0 deletions
@@ -156,6 +156,15 @@ class ModelArguments:
             )
         },
     )
+    flash_attention_causal_mask: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to enable causal mask in Habana flash attention for fine-tuning."
+                " It is applicable only when use_flash_attention is True.",
+            )
+        },
+    )
     use_fused_rope: bool = field(
         default=True,
         metadata={
@@ -545,6 +554,7 @@ def main():
     if model_args.use_flash_attention:
         model.generation_config.use_flash_attention = True
         model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute
+        model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask
     if not model_args.use_fused_rope:
         model.generation_config.use_fused_rope = False

examples/text-generation/README.md

Lines changed: 24 additions & 0 deletions
@@ -296,6 +296,30 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
 ```
 `--fp8` is required to enable quantization in fp8.

+### Using Habana Flash Attention
+
+Habana Flash Attention addresses large sequence lengths in the prompt stage of inference. Using a causal attention mask in the prompt stage requires all input sequences in a batch to be of the same length, but it can provide memory savings and thus enable higher batch sizes.
+
+The example below uses the `flash_attention_recompute` mode to reduce memory consumption in the prompt stage. Additionally, since all sequences in the batch are of the same length, it uses `flash_attention_causal_mask`, which further improves performance by exploiting the lower-triangular shape of the inputs to the softmax operation.
+
+```bash
+python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
+--model_name_or_path meta-llama/Llama-2-70b-hf \
+--use_hpu_graphs \
+--use_kv_cache \
+--reuse_cache \
+--trim_logits \
+--attn_softmax_bf16 \
+--max_input_tokens 31744 \
+--max_new_tokens 1024 \
+--batch_size=12 \
+--use_flash_attention \
+--flash_attention_recompute \
+--flash_attention_causal_mask \
+--book_source
+```
+
+For more details see the [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).

 ## Language Model Evaluation Harness
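
As a hedged aside on the same-length requirement mentioned in the README text above: with a causal flag the kernel applies a fixed lower-triangular mask, so per-sequence padding carried by an explicit attention mask is no longer consulted. The CPU-side sketch below uses plain PyTorch SDPA (not Habana's FusedSDPA) to illustrate the equivalence this relies on for unpadded, equal-length prompts.

```python
# Plain-PyTorch illustration (assumption: torch >= 2.0 for scaled_dot_product_attention);
# for equal-length, unpadded prompts an explicit lower-triangular mask and is_causal=True agree.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

causal_mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))  # True = attend
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out_masked, out_causal, atol=1e-6)
```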
examples/text-generation/run_generation.py

Lines changed: 54 additions & 0 deletions
@@ -232,6 +232,21 @@ def setup_parser(parser):
         action="store_true",
         help="Whether to enable Habana Flash Attention, provided that the model supports it.",
     )
+    parser.add_argument(
+        "--flash_attention_recompute",
+        action="store_true",
+        help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This allows the graph to be split internally, which helps reduce memory consumption.",
+    )
+    parser.add_argument(
+        "--flash_attention_causal_mask",
+        action="store_true",
+        help="Whether to enable Habana Flash Attention in causal mode on first token generation.",
+    )
+    parser.add_argument(
+        "--book_source",
+        action="store_true",
+        help="Whether to use Project Gutenberg books data as input. Useful for testing large sequence lengths.",
+    )
     parser.add_argument(
         "--torch_compile",
         action="store_true",
@@ -271,6 +286,45 @@ def main():
     # Benchmark over the prompts below
     if args.prompt:
         input_sentences = args.prompt
+    elif args.book_source:
+
+        def download_book(book_id):
+            import os
+
+            import requests
+
+            url = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt"
+            response = requests.get(url)
+            if response.status_code == 200:
+                pid = os.getpid()
+                save_path = f"/tmp/{book_id}_{pid}.txt"
+                with open(save_path, "wb") as file:
+                    file.write(response.content)
+                print(f"Book downloaded and saved to: {save_path}")
+                return save_path
+            else:
+                print("Failed to download book! Exiting...")
+                import sys
+
+                sys.exit()
+
+        def assemble_prompt(prompt_size, book_path):
+            prompt = ""
+            counter = 0
+            book_lines = open(book_path).readlines()
+            for line in book_lines:
+                for word in line.split():
+                    counter += 1
+                    prompt += word + " "
+                    if counter == prompt_size:
+                        return [prompt] * args.batch_size
+
+        book_ids = [
+            2701,  # Moby Dick; Or, The Whale
+            1513,  # Romeo and Juliet
+            1342,  # Pride and Prejudice
+        ]
+        input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0]))
     else:
         input_sentences = [
             "DeepSpeed is a machine learning framework",

examples/text-generation/utils.py

Lines changed: 2 additions & 0 deletions
@@ -347,6 +347,8 @@ def setup_generation_config(args, model, tokenizer):
         assert generation_config.bucket_size > 0
     generation_config.kv_cache_fp8 = args.kv_cache_fp8
     generation_config.use_flash_attention = args.use_flash_attention
+    generation_config.flash_attention_recompute = args.flash_attention_recompute
+    generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
     return generation_config
optimum/habana/transformers/generation/configuration_utils.py

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,8 @@ class GaudiGenerationConfig(GenerationConfig):
         Whether to use flash attention optimization.
     flash_attention_recompute (`bool`, *optional*):
         Whether to enable recompute if use Habana flash attention.
+    flash_attention_causal_mask (`bool`, *optional*):
+        Whether to enable causal_mask if use Habana flash attention.
     """

     def __init__(self, **kwargs):
@@ -49,4 +51,5 @@ def __init__(self, **kwargs):
         self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
         self.use_flash_attention = kwargs.get("use_flash_attention", None)
         self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
+        self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)
         self.use_fused_rope = kwargs.get("use_fused_rope", None)
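
For orientation, here is a minimal sketch of setting the new flag programmatically. The import path and the commented `generate` call are assumptions for illustration; only the three flash-attention flags come from this diff and the existing config.

```python
# Minimal sketch (assumed import path for GaudiGenerationConfig).
from optimum.habana.transformers.generation import GaudiGenerationConfig

generation_config = GaudiGenerationConfig(
    use_flash_attention=True,          # enable Habana fused SDPA (Flash Attention)
    flash_attention_recompute=True,    # recompute mode: lower memory use in the prompt stage
    flash_attention_causal_mask=True,  # causal mode: all prompts in the batch must share one length
)
# outputs = model.generate(**inputs, generation_config=generation_config)
```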

optimum/habana/transformers/generation/utils.py

Lines changed: 1 addition & 0 deletions
@@ -707,6 +707,7 @@ def generate(
         # determine whether flash attention needs to be used
         model_kwargs["use_flash_attention"] = generation_config.use_flash_attention
         model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False
+        model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False
         model_kwargs["use_fused_rope"] = False if not generation_config.use_fused_rope else True

         if not self.config.is_encoder_decoder:

optimum/habana/transformers/models/llama/modeling_llama.py

Lines changed: 22 additions & 4 deletions
@@ -207,6 +207,7 @@ def pre_attn_forward(
         reuse_cache: Optional[bool] = False,
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
+        flash_attention_causal_mask: Optional[bool] = False,
         cache_idx: int = None,
         use_fused_rope: Optional[bool] = True,
         **kwargs,
@@ -220,6 +221,7 @@ def pre_attn_forward(
         - add new args reuse_cache
         - add new args use_flash_attention
         - add new arg flash_attention_recompute
+        - add new arg flash_attention_causal_mask
         """
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -310,10 +312,15 @@ def pre_attn_forward(
                 )
             else:
                 # first token
-                with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
-                    attn_output = FusedSDPA.apply(
-                        query_states, key_states, value_states, attention_mask, 0.0, False, None
-                    )
+                if flash_attention_causal_mask:
+                    # causal masking on first token requires inputs to be of the same length
+                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+                        attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
+                else:
+                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
+                        attn_output = FusedSDPA.apply(
+                            query_states, key_states, value_states, attention_mask, 0.0, False, None
+                        )

         else:
             query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
@@ -455,6 +462,7 @@ def forward(
         reuse_cache: Optional[bool] = False,
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
+        flash_attention_causal_mask: Optional[bool] = False,
         cache_idx: int = None,
         use_fused_rope: Optional[bool] = True,
         **kwargs,
@@ -467,6 +475,7 @@ def forward(
         - add new args reuse_cache
         - add new args use_flash_attention
         - add new arg flash_attention_recompute
+        - add new arg flash_attention_causal_mask
         """
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -486,6 +495,7 @@ def forward(
             reuse_cache,
             use_flash_attention=use_flash_attention,
             flash_attention_recompute=flash_attention_recompute,
+            flash_attention_causal_mask=flash_attention_causal_mask,
             cache_idx=cache_idx,
             use_fused_rope=use_fused_rope,
             **kwargs,
@@ -517,6 +527,7 @@ def pre_attn(
         reuse_cache: Optional[bool] = False,
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
+        flash_attention_causal_mask: Optional[bool] = False,
         cache_idx: int = None,
         use_fused_rope: Optional[bool] = True,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -533,6 +544,7 @@ def pre_attn(
             reuse_cache,
             use_flash_attention,
             flash_attention_recompute,
+            flash_attention_causal_mask,
             cache_idx=cache_idx,
             use_fused_rope=use_fused_rope,
         )
@@ -583,6 +595,7 @@ def forward(
         reuse_cache: Optional[bool] = False,
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
+        flash_attention_causal_mask: Optional[bool] = False,
         cache_idx: int = None,
         use_fused_rope: Optional[bool] = True,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
@@ -594,6 +607,7 @@ def forward(
         - add new args reuse_cache
         - add new args use_flash_attention
         - add new arg flash_attention_recompute
+        - add new arg flash_attention_causal_mask
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -702,6 +716,7 @@ def forward(
                 reuse_cache=reuse_cache,
                 use_flash_attention=use_flash_attention,
                 flash_attention_recompute=flash_attention_recompute,
+                flash_attention_causal_mask=flash_attention_causal_mask,
                 cache_idx=cache_idx,
                 use_fused_rope=use_fused_rope,
             )
@@ -777,6 +792,7 @@ def forward(
         reuse_cache: Optional[bool] = False,
         use_flash_attention: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
+        flash_attention_causal_mask: Optional[bool] = False,
         cache_idx: int = None,
         use_fused_rope: Optional[bool] = True,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
@@ -801,6 +817,7 @@ def forward(
             reuse_cache=reuse_cache,
             use_flash_attention=use_flash_attention,
             flash_attention_recompute=flash_attention_recompute,
+            flash_attention_causal_mask=flash_attention_causal_mask,
             cache_idx=cache_idx,
             use_fused_rope=use_fused_rope,
         )
@@ -914,6 +931,7 @@ def prepare_inputs_for_generation(
             "reuse_cache": reuse_cache,
             "use_flash_attention": kwargs.get("use_flash_attention"),
             "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
+            "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
             "cache_idx": kwargs.get("cache_idx"),
         }
     )
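
To summarize the prompt-stage branch added to `pre_attn_forward` above, here is a hedged, stripped-down sketch. The `FusedSDPA` and `ht` import paths are assumed to match what modeling_llama.py already uses, and the argument order `(q, k, v, attn_mask, dropout_p, is_causal, scale)` is inferred from the diff.

```python
# Hedged sketch of the first-token (prompt-stage) attention branch; not the
# exact optimum-habana code, just the control flow this diff introduces.
import habana_frameworks.torch.hpu as ht  # assumed alias, as in modeling_llama.py
from habana_frameworks.torch.hpex.kernels import FusedSDPA  # assumed import path


def prompt_stage_attention(q, k, v, attention_mask, recompute, causal):
    if causal:
        # Causal mode drops the explicit mask and lets the kernel apply a
        # lower-triangular mask, so every prompt in the batch must have the
        # same length (padding information is not consulted).
        with ht.sdp_kernel(enable_recompute=recompute):
            return FusedSDPA.apply(q, k, v, None, 0.0, True, None)
    # Default path keeps the explicit attention mask and a non-causal kernel call.
    with ht.sdp_kernel(enable_recompute=recompute):
        return FusedSDPA.apply(q, k, v, attention_mask, 0.0, False, None)
```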

optimum/habana/transformers/trainer.py

Lines changed: 4 additions & 0 deletions
@@ -909,6 +909,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
             inputs["use_flash_attention"] = True
         if self.model.generation_config.flash_attention_recompute:
             inputs["flash_attention_recompute"] = True
+        if self.model.generation_config.flash_attention_causal_mask:
+            inputs["flash_attention_causal_mask"] = True
         if not self.model.generation_config.use_fused_rope:
             inputs["use_fused_rope"] = False

@@ -1686,6 +1688,8 @@ def evaluation_loop(
             inputs["use_flash_attention"] = True
         if self.model.generation_config.flash_attention_recompute:
             inputs["flash_attention_recompute"] = True
+        if self.model.generation_config.flash_attention_causal_mask:
+            inputs["flash_attention_causal_mask"] = True
         if not self.model.generation_config.use_fused_rope:
             inputs["use_fused_rope"] = False

0 commit comments
