
Commit fa532c0

Merge pull request #139 from georgian-io/flash-attn
Flash Attention Implementation & Fuller Config Options
2 parents 30ec177 + 0f683f6 commit fa532c0

5 files changed: +102 -18 lines changed


README.md

+24

@@ -51,6 +51,30 @@ Then the second command initiates the fine-tuning process using the settings spe
 
 The configuration file is the central piece that defines the behavior of the toolkit. It is written in YAML format and consists of several sections that control different aspects of the process, such as data ingestion, model definition, training, inference, and quality assurance. We highlight some of the critical sections.
 
+#### Flash Attention 2
+
+To enable Flash Attention 2 for [supported models](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2), first install `flash-attn`:
+
+**pipx**
+
+```shell
+pipx inject llm-toolkit flash-attn --pip-args=--no-build-isolation
+```
+
+**pip**
+
+```shell
+pip install flash-attn --no-build-isolation
+```
+
+Then add it to the config file:
+
+```yaml
+model:
+  torch_dtype: "bfloat16"  # or "float16" if using an older GPU
+  attn_implementation: "flash_attention_2"
+```
+
 #### Data Ingestion
 
 An example of what the data ingestion may look like:
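For context, these two `model` keys map onto keyword arguments of Hugging Face `from_pretrained`. A minimal sketch of the equivalent plain `transformers` call, assuming a recent `transformers` release with `attn_implementation` support and `flash-attn` installed (the checkpoint name is just the example from the sample config):

```python
import torch
from transformers import AutoModelForCausalLM

# Rough equivalent of torch_dtype: "bfloat16" + attn_implementation: "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-hf",              # example checkpoint from config.yml
    torch_dtype=torch.bfloat16,                # use torch.float16 on GPUs without bf16 support
    attn_implementation="flash_attention_2",   # requires the flash-attn package
    device_map="auto",
)
```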

config.yml

+2

@@ -24,6 +24,8 @@ data:
 # Model Definition -------------------
 model:
   hf_model_ckpt: "NousResearch/Llama-2-7b-hf"
+  torch_dtype: "bfloat16"
+  attn_implementation: "flash_attention_2"
   quantize: true
   bitsandbytes:
     load_in_4bit: true
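Flash Attention 2 requires a half-precision dtype, and bf16 in turn requires an Ampere-or-newer GPU, so the right `torch_dtype` value depends on the hardware. A small check you can run to choose between the two values (standard PyTorch API):

```python
import torch

# bf16 needs a GPU with compute capability >= 8.0 (Ampere); otherwise fall back to fp16.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    print('set model.torch_dtype to "bfloat16"')
else:
    print('set model.torch_dtype to "float16"')
```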

llmtune/finetune/lora.py

+2

@@ -74,6 +74,8 @@ def _get_model(self):
             ),
             use_cache=False,
             device_map=self.device_map,
+            torch_dtype=self._model_config.casted_torch_dtype,
+            attn_implementation=self._model_config.attn_implementation,
         )
 
         model.config.pretraining_tp = 1
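Because the sample config also enables 4-bit quantization, the fine-tuning path now combines the bitsandbytes settings with the new dtype and attention arguments. A rough standalone sketch of such a combined load, with illustrative values rather than the toolkit's exact call (the inference path below makes the analogous call through PEFT's `AutoPeftModelForCausalLM`):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                       # mirrors the bitsandbytes section of config.yml
    bnb_4bit_compute_dtype=torch.bfloat16,   # compute dtype for the quantized layers
)

model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-hf",
    quantization_config=bnb_config,
    use_cache=False,                         # caching is disabled during training
    device_map="auto",
    torch_dtype=torch.bfloat16,              # what ModelConfig.casted_torch_dtype resolves to
    attn_implementation="flash_attention_2",
)
```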

llmtune/inference/lora.py

+3 -7

@@ -40,17 +40,13 @@ def _get_merged_model(self, weights_path: str):
         torch.cuda.empty_cache()
 
         # Load from path
-        dtype = (
-            torch.float16
-            if self.config.training.training_args.fp16
-            else (torch.bfloat16 if self.config.training.training_args.bf16 else torch.float32)
-        )
 
         self.model = AutoPeftModelForCausalLM.from_pretrained(
             weights_path,
-            torch_dtype=dtype,
-            device_map=self.device_map,
+            torch_dtype=self.config.model.casted_torch_dtype,
             quantization_config=(BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump())),
+            device_map=self.device_map,
+            attn_implementation=self.config.model.attn_implementation,
         )
 
         """TODO: figure out multi-gpu

llmtune/pydantic_models/config_model.py

+71 -11

@@ -77,7 +77,13 @@ class ModelConfig(BaseModel):
         description="Path to the model (huggingface repo or local path)",
     )
     device_map: Optional[str] = Field("auto", description="device onto which to load the model")
+    torch_dtype: Optional[str] = Field("auto", description="torch dtype to use for model weights")
+    attn_implementation: Optional[str] = Field(
+        None,
+        description="set desired attention implementation; leave None for default. E.g. `flash_attention_2` (please ensure `torch_dtype` is either float16 or bfloat16).",
+    )
 
+    # Quantization Config
     quantize: Optional[bool] = Field(False, description="Flag to enable quantization")
     bitsandbytes: BitsAndBytesConfig = Field(None, description="Bits and Bytes configuration")
 
@@ -99,6 +105,18 @@ def set_device_map_to_none(cls, v, values, **kwargs):
             return None
         return v
 
+    @property
+    def casted_torch_dtype(self) -> Union[str, torch.dtype]:
+        if self.torch_dtype == "auto":
+            return self.torch_dtype
+
+        try:
+            torch_dtype = getattr(torch, self.torch_dtype)
+        except AttributeError:
+            raise ValueError(f"{self.torch_dtype} is not a valid torch data type")
+
+        return torch_dtype
+
 
 class LoraConfig(BaseModel):
     r: Optional[int] = Field(8, description="Lora rank")
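This new `casted_torch_dtype` property is what `finetune/lora.py` and `inference/lora.py` now pass as `torch_dtype`. A standalone mirror of its behavior (a plain function rather than the pydantic property), showing how the config string resolves to a `torch.dtype`:

```python
import torch

def casted_torch_dtype(torch_dtype: str):
    """Standalone mirror of ModelConfig.casted_torch_dtype."""
    if torch_dtype == "auto":
        return torch_dtype  # "auto" is passed through; transformers resolves it from the checkpoint
    try:
        return getattr(torch, torch_dtype)
    except AttributeError:
        raise ValueError(f"{torch_dtype} is not a valid torch data type")

print(casted_torch_dtype("bfloat16"))  # torch.bfloat16
print(casted_torch_dtype("float16"))   # torch.float16
print(casted_torch_dtype("auto"))      # auto
```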
@@ -126,7 +144,6 @@ class LoraConfig(BaseModel):
     # )
 
 
-# TODO: Get comprehensive Args!
 class TrainingArgs(BaseModel):
     num_train_epochs: Optional[int] = Field(1, description="Number of training epochs")
     per_device_train_batch_size: Optional[int] = Field(1, description="Batch size per training device")
@@ -141,9 +158,12 @@ class TrainingArgs(BaseModel):
     max_grad_norm: Optional[float] = Field(0.3, description="Maximum gradient norm")
     warmup_ratio: Optional[float] = Field(0.03, description="Warmup ratio")
     lr_scheduler_type: Optional[str] = Field("constant", description="Learning rate scheduler type")
+    save_steps: Optional[Union[int, float]] = Field(
+        500,
+        description="Number of update steps before checkpoint saves. Should be an integer or a float in range [0,1). If smaller than 1, will be interpreted as a ratio of total training steps.",
+    )
 
 
-# TODO: Get comprehensive Args!
 class SftArgs(BaseModel):
     max_seq_length: Optional[int] = Field(None, description="Maximum sequence length")
     neftune_noise_alpha: Optional[float] = Field(
@@ -157,16 +177,56 @@ class TrainingConfig(BaseModel):
     sft_args: SftArgs
 
 
-# TODO: Get comprehensive Args!
 class InferenceConfig(BaseModel):
-    max_new_tokens: Optional[int] = Field(None, description="Maximum new tokens")
-    use_cache: Optional[bool] = Field(True, description="Flag to enable cache usage")
-    do_sample: Optional[bool] = Field(True, description="Flag to enable sampling")
-    top_p: Optional[float] = Field(1.0, description="Top p value")
-    temperature: Optional[float] = Field(0.1, description="Temperature value")
-    epsilon_cutoff: Optional[float] = Field(0.0, description="epsilon cutoff value")
-    eta_cutoff: Optional[float] = Field(0.0, description="eta cutoff value")
-    top_k: Optional[int] = Field(50, description="top-k sampling")
+    # Length
+    max_length: Optional[int] = Field(None, description="The maximum length the generated tokens can have.")
+    max_new_tokens: Optional[int] = Field(None, description="The maximum number of tokens to generate.")
+    min_length: Optional[int] = Field(0, description="The minimum length of the sequence to be generated.")
+    min_new_tokens: Optional[int] = Field(None, description="The minimum number of tokens to generate.")
+    early_stopping: Optional[Union[bool, str]] = Field(
+        False, description="Controls the stopping condition for beam search."
+    )
+    max_time: Optional[float] = Field(None, description="The maximum amount of time for the computation in seconds.")
+
+    # Generation Strategy
+    do_sample: Optional[bool] = Field(False, description="Whether or not to use sampling.")
+    num_beams: Optional[int] = Field(1, description="Number of beams for beam search.")
+    num_beam_groups: Optional[int] = Field(1, description="Number of groups for diversity among beams.")
+    penalty_alpha: Optional[float] = Field(None, description="Balances model confidence and degeneration penalty.")
+    use_cache: Optional[bool] = Field(
+        True,
+        description="Whether to use past key/values attentions to speed up decoding.",
+    )
+
+    # Manipulation of Model Output Logits
+    temperature: Optional[float] = Field(1.0, description="Modulates the next token probabilities.")
+    top_k: Optional[int] = Field(
+        50,
+        description="Number of highest probability tokens to keep for top-k-filtering.",
+    )
+    top_p: Optional[float] = Field(
+        1.0,
+        description="Keeps the smallest set of most probable tokens summing up to top_p.",
+    )
+    typical_p: Optional[float] = Field(1.0, description="Local typicality measure.")
+    epsilon_cutoff: Optional[float] = Field(0.0, description="Minimum conditional probability for token sampling.")
+    eta_cutoff: Optional[float] = Field(0.0, description="Hybrid of locally typical sampling and epsilon sampling.")
+    diversity_penalty: Optional[float] = Field(
+        0.0, description="Penalty for token repetition across different beam groups."
+    )
+    repetition_penalty: Optional[float] = Field(1.0, description="Penalty for token repetition.")
+    encoder_repetition_penalty: Optional[float] = Field(
+        1.0, description="Penalty on sequences not in the original input."
+    )
+    length_penalty: Optional[float] = Field(1.0, description="Exponential penalty to the length for beam search.")
+    no_repeat_ngram_size: Optional[int] = Field(0, description="Size of ngrams that cannot occur more than once.")
+    bad_words_ids: Optional[List[List[int]]] = Field(None, description="Tokens that are not allowed to be generated.")
+    force_words_ids: Optional[List[Union[List[int], List[List[int]]]]] = Field(
+        None, description="Tokens that must be generated."
+    )
+    renormalize_logits: Optional[bool] = Field(
+        False, description="Whether to renormalize logits after all processors."
+    )
 
 
 class AblationConfig(BaseModel):
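The expanded `InferenceConfig` field names line up with the parameters of Hugging Face's `GenerationConfig` and `model.generate()`. A hedged sketch of how such settings can be applied at generation time (the diff does not show exactly how the toolkit forwards them; the model and prompt here are placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-hf", device_map="auto")

# Field names in InferenceConfig match GenerationConfig keyword arguments.
gen_config = GenerationConfig(
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.1,
)

inputs = tokenizer("### Instruction: Summarize the dialogue ...", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, generation_config=gen_config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```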

0 commit comments
