Commit c207885

ENH Extend usage for OLoRA finetune script (#2308)
- allow DDP
- make it work on CPU
- set seed and dtype

Related: `dequantize_bnb_weight` is updated so that it does not move weights to CUDA when CUDA is not available.

Signed-off-by: jiqing-feng <[email protected]>
1 parent: 3d2bf9a

File tree (3 files changed, +40 −12 lines):
- examples/olora_finetuning/README.md
- examples/olora_finetuning/olora_finetuning.py
- src/peft/utils/integrations.py
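A quick note on the "set seed and dtype" item in the commit message: the script now exposes `--seed` and `--torch_dtype` flags (see the script diff below), and the seed is applied via `transformers.set_seed`. The snippet below is a minimal sketch of what that call does, not code taken from the commit:

```python
import torch
from transformers import set_seed

# set_seed seeds Python's random module, NumPy, and PyTorch (including CUDA when present),
# so two runs of the finetuning script with the same --seed see the same init and shuffling.
set_seed(42)
print(torch.randint(0, 100, (3,)))  # identical output on every run seeded with 42
```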

examples/olora_finetuning/README.md

Lines changed: 13 additions & 0 deletions
````diff
@@ -39,6 +39,19 @@ OLoRA also supports quantization. To use 4-bit quantization try:
 ```bash
 python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m --quantize
 ```
+or you can just pass a quantized model without the quantize flag.
+
+If you want to run DDP by [accelerate](https://huggingface.co/docs/accelerate/en/index), please run `accelerate config` to set your ddp config, and run:
+```bash
+accelerate launch examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m
+```
+please add `--device_map cpu` if you want to run finetune on CPU.
+
+If you want to train a quantized model like AWQ and GPTQ which do not support olora init method, please pass `--init_lora_weights gaussian`. For example:
+```bash
+python3 examples/olora_finetuning/olora_finetuning.py --base_model hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 --init_lora_weights gaussian
+
+```
 
 
 ## Use the model
````
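For context on the `--init_lora_weights gaussian` advice added above: inside the script this flag is forwarded to `LoraConfig`. The following is a minimal standalone sketch of the same configuration, using the AWQ model id from the README example; the `target_modules` choice is illustrative and not taken from the commit:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Load an already-quantized checkpoint (AWQ here); no --quantize / BitsAndBytesConfig is needed.
base = AutoModelForCausalLM.from_pretrained(
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
    device_map="auto",
)

# AWQ/GPTQ layers do not support the "olora" init method, so fall back to Gaussian init.
config = LoraConfig(
    r=32,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # illustrative choice, not taken from the commit
    init_lora_weights="gaussian",
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, config)
model.print_trainable_parameters()
```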

examples/olora_finetuning/olora_finetuning.py

Lines changed: 26 additions & 11 deletions
```diff
@@ -13,12 +13,13 @@
 # limitations under the License.
 
 
-from typing import List
+import os
+from typing import List, Optional
 
 import torch
 import transformers
 from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed
 
 from peft import (
     LoraConfig,
@@ -43,23 +44,33 @@ def train(
     lora_alpha: int = 16,
     lora_dropout: float = 0.05,
     lora_target_modules: List[str] = None,
+    torch_dtype: str = "float16",
     init_lora_weights="olora",
+    seed: Optional[int] = None,
 ):
-    model = AutoModelForCausalLM.from_pretrained(
-        base_model,
-        device_map=device_map,
-        quantization_config=BitsAndBytesConfig(
+    # Set device_map to the right place when enabling DDP.
+    world_size = int(os.environ.get("WORLD_SIZE", 0)) or int(os.environ.get("PMI_SIZE", 0))
+    if world_size > 1 and device_map != "cpu":
+        from accelerate import Accelerator
+
+        device_map = {"": Accelerator().process_index}
+    # Set seed
+    if seed is not None:
+        set_seed(seed)
+    model_kwargs = {"torch_dtype": getattr(torch, torch_dtype), "device_map": device_map}
+    if quantize:
+        model_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_compute_dtype=torch.bfloat16,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
-        if quantize
-        else None,
-        torch_dtype=torch.float16,
-    )
+    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
 
     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+    # For some tokenizer with no pad token like llama
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
 
     def tokenize(prompt, add_eos_token=True):
         result = tokenizer(
@@ -112,7 +123,6 @@ def generate_and_tokenize_prompt(example):
             warmup_steps=100,
             num_train_epochs=num_epochs,
             learning_rate=learning_rate,
-            fp16=True,
             logging_steps=100,
             optim="adamw_torch",
             evaluation_strategy="steps",
@@ -122,6 +132,7 @@ def generate_and_tokenize_prompt(example):
             output_dir=output_dir,
             save_total_limit=3,
             load_best_model_at_end=True,
+            ddp_find_unused_parameters=False if world_size > 1 else None,
         ),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
@@ -159,7 +170,9 @@ def generate_prompt(example):
     parser.add_argument("--lora_alpha", type=int, default=16)
     parser.add_argument("--lora_dropout", type=float, default=0.05)
     parser.add_argument("--lora_target_modules", type=str, default=None)
+    parser.add_argument("--torch_dtype", type=str, default="float16")
    parser.add_argument("--init_lora_weights", type=str, default="olora")
+    parser.add_argument("--seed", type=int, default=None)
 
     args = parser.parse_args()
 
@@ -180,5 +193,7 @@ def generate_prompt(example):
         lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         lora_target_modules=args.lora_target_modules,
+        torch_dtype=args.torch_dtype,
         init_lora_weights=args.init_lora_weights,
+        seed=args.seed,
     )
```
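The central piece of the script changes above is the per-process `device_map` used for DDP: when launched with `accelerate launch`, each process loads a full model replica onto its own device by mapping the root module (the empty string key) to its process index. Below is a standalone sketch of that pattern, with a placeholder base model and dtype; `WORLD_SIZE` is set by the launcher:

```python
import os

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

# Under accelerate/torchrun DDP, WORLD_SIZE > 1 and each rank should own a full model replica.
world_size = int(os.environ.get("WORLD_SIZE", 1))
if world_size > 1:
    device_map = {"": Accelerator().process_index}  # map the whole model ("") to this rank's device
else:
    device_map = "auto"

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",                    # placeholder base model
    device_map=device_map,
    torch_dtype=getattr(torch, "float16"),  # same string-to-dtype resolution the script uses
)
```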

src/peft/utils/integrations.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -88,7 +88,7 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     # BNB requires CUDA weights
     device = weight.device
     is_cpu = device.type == torch.device("cpu").type
-    if is_cpu:
+    if is_cpu and torch.cuda.is_available():
         weight = weight.to(torch.device("cuda"))
 
     cls_name = weight.__class__.__name__
```
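The one-line change above lets `dequantize_bnb_weight` degrade gracefully on machines without CUDA: CPU weights are moved to the GPU only when one is actually present. The same guard shown in isolation as a sketch; the helper name below is hypothetical, not PEFT's:

```python
import torch

def maybe_move_to_cuda(weight: torch.Tensor) -> torch.Tensor:
    """Move a CPU tensor to CUDA only when a CUDA device is actually available."""
    is_cpu = weight.device.type == torch.device("cpu").type
    if is_cpu and torch.cuda.is_available():
        # bitsandbytes' dequantization kernels run on the GPU, so move the weight there.
        weight = weight.to(torch.device("cuda"))
    # Without CUDA the weight stays on CPU and the CPU execution path is used instead.
    return weight
```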
