diff --git a/start.py b/start.py index 16efdf4ac3..317419e344 100644 --- a/start.py +++ b/start.py @@ -30,6 +30,13 @@ def parse_arge(): type=str, help="eval tasks, separated by comma, example: hellaswag,mmlu", ) + + parser.add_argument( + "--num_fewshot", + type=int, + default=0, + help="number of fewshot examples to use for each task", + ) parser.add_argument( "--is_lora", type=bool, @@ -55,7 +62,7 @@ def parse_arge(): return args -def run_vllm(model_id_or_path, tasks): +def run_vllm(model_id_or_path, tasks, num_fewshot=0): model_args = { "pretrained": model_id_or_path, # required: taken from UI, no default value "tensor_parallel_size": 8, @@ -69,12 +76,13 @@ def run_vllm(model_id_or_path, tasks): --model_args={model_args_str} \ --tasks={tasks} \ --batch_size=auto \ + --num_fewshot={num_fewshot} \ --output_path=/opt/ml/model/" print(f"Running command: {cmd}") return os.system(cmd) -def run_hf(model_id_path, peft_model_id_or_path, tasks): +def run_hf(model_id_path, peft_model_id_or_path, tasks, num_fewshot=0): model_args = { "pretrained": model_id_path, # required: taken from UI, no default value "peft": peft_model_id_or_path, @@ -86,6 +94,7 @@ def run_hf(model_id_path, peft_model_id_or_path, tasks): --model_args {model_args_str} \ --tasks {tasks} \ --batch_size=auto \ + --num_fewshot={num_fewshot} \ --output_path=/opt/ml/model/" print(f"Running command: {cmd}") return os.system(cmd) @@ -145,9 +154,13 @@ def main(): # future.add_done_callback(lambda p: print(f"Uploaded to {p.result()}")) # model_id = merged_model_path if peft_model_id is not None and len(peft_model_id) > 0: - code = run_hf(model_id, peft_model_id, script_args.tasks) + code = run_hf( + model_id_path=model_id, + peft_model_id_or_path=peft_model_id, + tasks=script_args.tasks, + num_fewshot=script_args.num_fewshot) else: - code = run_vllm(model_id_or_path=model_id, tasks=script_args.tasks) + code = run_vllm(model_id_or_path=model_id, tasks=script_args.tasks, num_fewshot=script_args.num_fewshot) if code != 0: raise Exception("Evaluation job has failed") @@ -155,5 +168,3 @@ def main(): if __name__ == "__main__": main() - - diff --git a/start.sh b/start.sh index e78c5884bd..da6959d26b 100644 --- a/start.sh +++ b/start.sh @@ -1,6 +1,7 @@ export HUGGINGFACE_HUB_CACHE=/tmp/.cache export HF_HUB_ENABLE_HF_TRANSFER=1 export NUMEXPR_MAX_THREADS=96 +export OMP_NUM_THREADS=96 pip install -e . pip install -e ".[vllm]" echo "$@"