Skip to content

Commit

Permalink
adding num_fewshot
Browse files Browse the repository at this point in the history
  • Loading branch information
proserve committed Dec 27, 2023
1 parent cdb8450 commit cd858b8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
23 changes: 17 additions & 6 deletions start.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ def parse_arge():
type=str,
help="eval tasks, separated by comma, example: hellaswag,mmlu",
)

parser.add_argument(
"--num_fewshot",
type=int,
default=0,
help="number of fewshot examples to use for each task",
)
parser.add_argument(
"--is_lora",
type=bool,
Expand All @@ -55,7 +62,7 @@ def parse_arge():
return args


def run_vllm(model_id_or_path, tasks):
def run_vllm(model_id_or_path, tasks, num_fewshot=0):
model_args = {
"pretrained": model_id_or_path, # required: taken from UI, no default value
"tensor_parallel_size": 8,
Expand All @@ -69,12 +76,13 @@ def run_vllm(model_id_or_path, tasks):
--model_args={model_args_str} \
--tasks={tasks} \
--batch_size=auto \
--num_fewshot={num_fewshot} \
--output_path=/opt/ml/model/"
print(f"Running command: {cmd}")
return os.system(cmd)


def run_hf(model_id_path, peft_model_id_or_path, tasks):
def run_hf(model_id_path, peft_model_id_or_path, tasks, num_fewshot=0):
model_args = {
"pretrained": model_id_path, # required: taken from UI, no default value
"peft": peft_model_id_or_path,
Expand All @@ -86,6 +94,7 @@ def run_hf(model_id_path, peft_model_id_or_path, tasks):
--model_args {model_args_str} \
--tasks {tasks} \
--batch_size=auto \
--num_fewshot={num_fewshot} \
--output_path=/opt/ml/model/"
print(f"Running command: {cmd}")
return os.system(cmd)
Expand Down Expand Up @@ -145,15 +154,17 @@ def main():
# future.add_done_callback(lambda p: print(f"Uploaded to {p.result()}"))
# model_id = merged_model_path
if peft_model_id is not None and len(peft_model_id) > 0:
code = run_hf(model_id, peft_model_id, script_args.tasks)
code = run_hf(
model_id_path=model_id,
peft_model_id_or_path=peft_model_id,
tasks=script_args.tasks,
num_fewshot=script_args.num_fewshot)
else:
code = run_vllm(model_id_or_path=model_id, tasks=script_args.tasks)
code = run_vllm(model_id_or_path=model_id, tasks=script_args.tasks, num_fewshot=script_args.num_fewshot)

if code != 0:
raise Exception("Evaluation job has failed")


if __name__ == "__main__":
main()


1 change: 1 addition & 0 deletions start.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export HUGGINGFACE_HUB_CACHE=/tmp/.cache
export HF_HUB_ENABLE_HF_TRANSFER=1
export NUMEXPR_MAX_THREADS=96
export OMP_NUM_THREADS=96
pip install -e .
pip install -e ".[vllm]"
echo "$@"
Expand Down

0 comments on commit cd858b8

Please sign in to comment.