Enable Flash Attention in recompute and causal modes (#21)
* Enable Flash Attention in recompute and causal modes

* Add flash_attention_causal_mask to generation utils

* Propagate Flash Attention causal_mask to finetuning example

* Modify README example and provide additional description

* Add flash_attention_causal_mask to FT README
wszczurekhabana authored and bhargaveede committed Feb 19, 2024
1 parent 50c3d13 commit 92e4f64
Showing 9 changed files with 122 additions and 5 deletions.
3 changes: 2 additions & 1 deletion examples/language-modeling/README.md
@@ -552,7 +552,8 @@ python3 ../gaudi_spawn.py --use_deepspeed --world_size 8 run_lora_clm.py \
--lora_rank 4 \
--lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
--validation_split_percentage 4 \
--use_flash_attention True
--use_flash_attention True \
--flash_attention_causal_mask True
```

- Multi-card finetuning of Falcon-180B:
10 changes: 10 additions & 0 deletions examples/language-modeling/run_lora_clm.py
@@ -156,6 +156,15 @@ class ModelArguments:
)
},
)
flash_attention_causal_mask: bool = field(
default=False,
metadata={
"help": (
"Whether to enable causal mask in Habana flash attention for fine-tuning."
" It is applicable only when use_flash_attention is True.",
)
},
)
use_fused_rope: bool = field(
default=True,
metadata={
@@ -545,6 +554,7 @@ def main():
if model_args.use_flash_attention:
model.generation_config.use_flash_attention = True
model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute
model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask
if not model_args.use_fused_rope:
model.generation_config.use_fused_rope = False

24 changes: 24 additions & 0 deletions examples/text-generation/README.md
@@ -296,6 +296,30 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
```
`--fp8` is required to enable quantization in fp8.

### Using Habana Flash Attention

Habana Flash Attention addresses large sequence lengths during the prompt stage of inference. Using a causal attention mask during the prompt stage requires all input sequences in a batch to be of the same length, but it can provide memory savings, thus enabling higher batch sizes.

The example below uses the `flash_attention_recompute` mode to reduce memory consumption during the prompt stage. Additionally, since all sequences in the batch have the same length, it uses `flash_attention_causal_mask`, which further improves performance by taking advantage of the lower-triangular shape of the inputs to the softmax operation.

```bash
python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--use_hpu_graphs \
--use_kv_cache \
--reuse_cache \
--trim_logits \
--attn_softmax_bf16 \
--max_input_tokens 31744 \
--max_new_tokens 1024 \
--batch_size=12 \
--use_flash_attention \
--flash_attention_recompute \
--flash_attention_causal_mask \
--book_source
```

For more details, see the [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).
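
For reference, the two flags translate into the fused SDPA call in the Llama attention forward roughly as sketched below. This is a minimal sketch distilled from the `modeling_llama.py` changes in this commit; the import paths and the standalone `prompt_attention` helper are illustrative assumptions, not part of the public API.

```python
import torch

# Assumed import paths for the Habana fused SDPA kernel and the HPU backend.
from habana_frameworks.torch.hpex.kernels import FusedSDPA
import habana_frameworks.torch.hpu as ht


def prompt_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    flash_attention_recompute: bool,
    flash_attention_causal_mask: bool,
) -> torch.Tensor:
    """Hypothetical helper showing how the flags are applied on the prompt stage."""
    # --flash_attention_recompute toggles recompute through the sdp_kernel context manager.
    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
        if flash_attention_causal_mask:
            # --flash_attention_causal_mask drops the explicit attention mask and asks the
            # kernel for causal (lower-triangular) masking; this is why all prompts in the
            # batch must have the same length.
            return FusedSDPA.apply(query, key, value, None, 0.0, True, None)
        # Default path: explicit padded attention mask, non-causal kernel.
        return FusedSDPA.apply(query, key, value, attention_mask, 0.0, False, None)
```

The `FusedSDPA.apply` argument order used here (query, key, value, attention_mask, dropout_p, is_causal, scale) mirrors the calls added in this commit.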

## Language Model Evaluation Harness

54 changes: 54 additions & 0 deletions examples/text-generation/run_generation.py
@@ -232,6 +232,21 @@ def setup_parser(parser):
action="store_true",
help="Whether to enable Habana Flash Attention, provided that the model supports it.",
)
parser.add_argument(
"--flash_attention_recompute",
action="store_true",
help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.",
)
parser.add_argument(
"--flash_attention_causal_mask",
action="store_true",
help="Whether to enable Habana Flash Attention in causal mode on first token generation.",
)
parser.add_argument(
"--book_source",
action="store_true",
help="Whether to use project Guttenberg books data as input. Usefull for testing large sequence lenghts.",
)
parser.add_argument(
"--torch_compile",
action="store_true",
@@ -271,6 +286,45 @@ def main():
# Benchmark over the prompts below
if args.prompt:
input_sentences = args.prompt
elif args.book_source:

def download_book(book_id):
import os

import requests

url = f"https://www.gutenberg.org/cache/epub/{book_id}/pg{book_id}.txt"
response = requests.get(url)
if response.status_code == 200:
pid = os.getpid()
save_path = f"/tmp/{book_id}_{pid}.txt"
with open(save_path, "wb") as file:
file.write(response.content)
print(f"Book downloaded and saved to: {save_path}")
return save_path
else:
print("Failed to download book! Exiting...")
import sys

sys.exit()

def assemble_prompt(prompt_size, book_path):
prompt = ""
counter = 0
book_lines = open(book_path).readlines()
for line in book_lines:
for word in line.split():
counter += 1
prompt += word + " "
if counter == prompt_size:
return [prompt] * args.batch_size

book_ids = [
2701, # Moby Dick; Or, The Whale
1513, # Romeo and Juliet
1342, # Pride and Prejudice
]
input_sentences = assemble_prompt(prompt_size=args.max_input_tokens, book_path=download_book(book_ids[0]))
else:
input_sentences = [
"DeepSpeed is a machine learning framework",
2 changes: 2 additions & 0 deletions examples/text-generation/utils.py
@@ -347,6 +347,8 @@ def setup_generation_config(args, model, tokenizer):
assert generation_config.bucket_size > 0
generation_config.kv_cache_fp8 = args.kv_cache_fp8
generation_config.use_flash_attention = args.use_flash_attention
generation_config.flash_attention_recompute = args.flash_attention_recompute
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
return generation_config


3 changes: 3 additions & 0 deletions optimum/habana/transformers/generation/configuration_utils.py
@@ -33,6 +33,8 @@ class GaudiGenerationConfig(GenerationConfig):
Whether to use flash attention optimization.
flash_attention_recompute (`bool`, *optional*):
Whether to enable recompute when using Habana flash attention.
flash_attention_causal_mask (`bool`, *optional*):
Whether to enable the causal mask when using Habana flash attention.
"""

def __init__(self, **kwargs):
@@ -49,4 +51,5 @@ def __init__(self, **kwargs):
self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
self.use_flash_attention = kwargs.get("use_flash_attention", None)
self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
1 change: 1 addition & 0 deletions optimum/habana/transformers/generation/utils.py
@@ -707,6 +707,7 @@ def generate(
# determine whether flash attention needs to be used
model_kwargs["use_flash_attention"] = generation_config.use_flash_attention
model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False
model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False
model_kwargs["use_fused_rope"] = False if not generation_config.use_fused_rope else True

if not self.config.is_encoder_decoder:
26 changes: 22 additions & 4 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -207,6 +207,7 @@ def pre_attn_forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
**kwargs,
@@ -220,6 +221,7 @@ def pre_attn_forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
"""
if "padding_mask" in kwargs:
warnings.warn(
@@ -310,10 +312,15 @@ def pre_attn_forward(
)
else:
# first token
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
if flash_attention_causal_mask:
# causal masking on the first token requires all inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
@@ -455,6 +462,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
**kwargs,
@@ -467,6 +475,7 @@ def forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
"""
if "padding_mask" in kwargs:
warnings.warn(
@@ -486,6 +495,7 @@ def forward(
reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
**kwargs,
@@ -517,6 +527,7 @@ def pre_attn(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -533,6 +544,7 @@ def pre_attn(
reuse_cache,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
@@ -583,6 +595,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
@@ -594,6 +607,7 @@ def forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -702,6 +716,7 @@ def forward(
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
@@ -777,6 +792,7 @@ def forward(
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Union[Tuple, CausalLMOutputWithPast]:
@@ -801,6 +817,7 @@ def forward(
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
@@ -914,6 +931,7 @@ def prepare_inputs_for_generation(
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
"cache_idx": kwargs.get("cache_idx"),
}
)
4 changes: 4 additions & 0 deletions optimum/habana/transformers/trainer.py
@@ -909,6 +909,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
inputs["use_flash_attention"] = True
if self.model.generation_config.flash_attention_recompute:
inputs["flash_attention_recompute"] = True
if self.model.generation_config.flash_attention_causal_mask:
inputs["flash_attention_causal_mask"] = True
if not self.model.generation_config.use_fused_rope:
inputs["use_fused_rope"] = False

@@ -1686,6 +1688,8 @@ def evaluation_loop(
inputs["use_flash_attention"] = True
if self.model.generation_config.flash_attention_recompute:
inputs["flash_attention_recompute"] = True
if self.model.generation_config.flash_attention_causal_mask:
inputs["flash_attention_causal_mask"] = True
if not self.model.generation_config.use_fused_rope:
inputs["use_fused_rope"] = False

