Model inference speed is too slow (positively related to max_new_tokens length) #580

Open
@SunriseB

Description

Inference is too slow, and the generation time grows with max_new_tokens. For example, with max_new_tokens=1000, a single generation takes roughly 30-40 s on an A100.
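For reference, a quick back-of-the-envelope check of what those figures mean per token. This is only a sketch and assumes that all 1000 tokens are actually generated (no early stop at EOS); `per_token_stats` is a hypothetical helper name, not part of any library.

```python
def per_token_stats(total_seconds, n_tokens):
    """Return (milliseconds per token, tokens per second) for one generation."""
    ms_per_token = total_seconds * 1000 / n_tokens
    tokens_per_second = n_tokens / total_seconds
    return ms_per_token, tokens_per_second


# The two ends of the observed range for max_new_tokens=1000.
for total_seconds in (30.0, 40.0):
    ms, tps = per_token_stats(total_seconds, 1000)
    print(f"{total_seconds:.0f} s total: {ms:.0f} ms/token, {tps:.1f} tokens/s")
```

So the observed range corresponds to about 30-40 ms per generated token, or about 25-33 tokens per second.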

Background

After loading a fine-tuned model (based on phi-2), I tested its accuracy. The results were good, but I found another serious issue: inference was too slow.

Hardware

1 A100

model

phi-2 (about 2.7B parameters)

fine-tune method

PEFT
from transformers import TrainingArguments

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    warmup_steps=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_steps=800,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    logging_steps=25,
    logging_dir="./logs",
    save_strategy="steps",
    save_steps=25,
    evaluation_strategy="steps",
    eval_steps=25,
    do_eval=True,
    gradient_checkpointing=True,
    report_to="none",
    overwrite_output_dir=True,  # boolean, not the string 'True'
    group_by_length=True,
)

Inference Code

  1. load model

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_id = "microsoft/phi-2"
# Pass the variable, not the string literal 'base_model_id'.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map='auto',
    trust_remote_code=True,
    use_auth_token=True,
)

eval_tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True, use_fast=False)
eval_tokenizer.pad_token = eval_tokenizer.eos_token

# Attach the fine-tuned PEFT adapter to the base model (inference only).
ft_model = PeftModel.from_pretrained(base_model, checkpoint_path, torch_dtype=torch.float16, is_trainable=False)

  2. generate method

def gen(model, p, maxlen=100, sample=True):
    # Tokenize the prompt and move the tensors to the GPU.
    toks = eval_tokenizer(p, return_tensors="pt")
    res = model.generate(**toks.to("cuda"), max_new_tokens=maxlen, do_sample=sample, num_return_sequences=1, temperature=0.01, num_beams=1, top_p=0.99)
    # Decode the generated token ids back to text.
    return eval_tokenizer.batch_decode(res, skip_special_tokens=True)
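To put numbers on how the time scales with max_new_tokens, a small timing wrapper can be used. This is only a minimal sketch: `time_call` is a hypothetical helper name, and the `prompt` variable in the commented usage is a placeholder for whatever test prompt is used.

```python
import time


def time_call(fn, *args, **kwargs):
    """Run fn once and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


# Intended usage with the gen() function above (needs the loaded model and a GPU):
#
# for n in (100, 250, 500, 1000):
#     _, secs = time_call(gen, ft_model, prompt, maxlen=n)
#     print(n, round(secs, 2), "s,", round(secs / n * 1000, 1), "ms per requested token")
```

Because gen() decodes the output on the CPU before returning, the measured time covers the complete generation. If generation stops early at EOS, fewer than n tokens are produced, so the per-token figure is only meaningful relative to the number of tokens requested.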
