diff --git a/recipes/3p_integrations/modal/many-llamas-human-eval/README.md b/recipes/3p_integrations/modal/many-llamas-human-eval/README.md new file mode 100644 index 000000000..1c3c1b661 --- /dev/null +++ b/recipes/3p_integrations/modal/many-llamas-human-eval/README.md @@ -0,0 +1,71 @@ +# Many-Llamas Human-Eval + +In this directory, we run an experiment answering the question: + +*If we run enough Llama models in parallel, can they outperform GPT-4o on HumanEval?* + +It seeks to increase model performance not through scaling parameters, but by scaling compute time. + +### Technical Blog + +This experiment built by the team at [Modal](https://modal.com), and is described in the following blog post: + +[Beat GPT-4o at Python by searching with 100 small Llamas](https://modal.com/blog/llama-human-eval) + +The experiment has since been upgraded to use the [Llama 3.2 3B Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) model, and runnable end-to-end using the Modal serverless platform. + +## Run it yourself + +### Install the Modal CLI +From within your virtual environment, run: +```bash +pip install modal +``` +And if you're new to Modal, authenticate with: +```bash +modal setup +# or if that doesn't work, try +# python -m modal setup +``` + +That's all! + +This CLI will execute your modal apps, which build and run containers on the cloud, on your GPU of choice. + +### HuggingFace Pull Access + +To download the model, you'll first need to accept the [Llama 3.2 License](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) on HuggingFace and be approved for access. + +Then, create a [modal secret](https://modal.com/secrets) named `huggingface`, to which you'll add your HF_TOKEN as an environment variable. + +### Run The Experiment + +This command will run every step for you: +```bash +bash run_e2e.sh +``` + +Or if you prefer to run it manually, you can step through each of the modal commands in [the script](./run_e2e.sh). + +This will execute: +1. Downloading the Llama 3.2 3B Instruct model to a cloud volume +2. Deploying a vLLM inference server to GPUs +3. Running hundreds of parallel generations on the HumanEval test set +4. Running the evaluation script to compute pass@k and fail@k +5. Generating graphs of pass@k and fail@k + +### Results + +The resulting plots of the evals will be saved locally to: +- `/tmp/plot-pass-k.jpeg` +- `/tmp/plot-fail-k.jpeg` + +`/tmp/plot-pass-k.jpeg` shows pass@k for the Llama 3.2 3B Instruct model vs pass@1 for GPT-4o. + +![plot-pass-k](https://github.com/user-attachments/assets/11e9dc6e-4322-4d44-b928-4ed7c4ce8262) + +You'll see that at 100 generations, the Llama model is able to perform on-par with GPT-4o. At higher scale, the Llama model will outperform GPT-4o. + +`/tmp/plot-fail-k.jpeg` shows fail@k across a log-scale, showing smooth scaling of this method. + +![plot-fail-k](https://github.com/user-attachments/assets/7286e4ff-5090-4288-bd62-8a078c6dc5a1) diff --git a/recipes/3p_integrations/modal/many-llamas-human-eval/download.py b/recipes/3p_integrations/modal/many-llamas-human-eval/download.py new file mode 100644 index 000000000..d96f36537 --- /dev/null +++ b/recipes/3p_integrations/modal/many-llamas-human-eval/download.py @@ -0,0 +1,64 @@ +# ## Downloading Llama 3.2 3B Instruct Model +# This script uses a Modal Function to download the model into a cloud Volume. +# +# Run it with: +# modal run download + +import modal + +MODELS_DIR = "/llamas" +DEFAULT_NAME = "meta-llama/Llama-3.2-3B-Instruct" + +MINUTES = 60 +HOURS = 60 * MINUTES + +# Create a modal Volume to store the model +volume = modal.Volume.from_name("llamas", create_if_missing=True) + +# This defines the image to use for the modal function +image = ( + modal.Image.debian_slim(python_version="3.10") + .pip_install( + [ + "huggingface_hub", # download models from the Hugging Face Hub + "hf-transfer", # download models faster with Rust + ] + ) + .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) +) + +# We run the function from a modal App, which will have our HF_SECRET env var set. +# Add your HuggingFace secret access token here: https://modal.com/secrets +# secret name: huggingface +# env var name: HF_TOKEN +app = modal.App(image=image, secrets=[modal.Secret.from_name("huggingface")]) + +# This function will be ran in the cloud, with the volume mounted. +@app.function(volumes={MODELS_DIR: volume}, timeout=4 * HOURS) +def download_model(model_name, force_download=False): + from huggingface_hub import snapshot_download + + volume.reload() + + snapshot_download( + model_name, + local_dir=MODELS_DIR + "/" + model_name, + ignore_patterns=[ + "*.pt", + "*.bin", + "*.pth", + "original/*", + ], # Ensure safetensors + force_download=force_download, + ) + + volume.commit() + + print("Model successfully downloaded") + +@app.local_entrypoint() +def main( + model_name: str = DEFAULT_NAME, + force_download: bool = False, +): + download_model.remote(model_name, force_download) diff --git a/recipes/3p_integrations/modal/many-llamas-human-eval/eval.py b/recipes/3p_integrations/modal/many-llamas-human-eval/eval.py new file mode 100644 index 000000000..5d2c135be --- /dev/null +++ b/recipes/3p_integrations/modal/many-llamas-human-eval/eval.py @@ -0,0 +1,96 @@ +# ## Evaluating HumanEval Results using Modal Sandboxes +# This script will take generated results and evaluate them. +# We use Modal Sandboxes to safely evaluate LLM-generated results. +# +# Run it with: +# modal run eval + +from pathlib import Path + +import modal + +app = modal.App("many-llamas-human-eval") + +volume = modal.Volume.from_name("humaneval", create_if_missing=True) + +sandbox_image = ( + modal.Image.debian_slim() + .apt_install("git") + .run_commands( + "git clone https://github.com/modal-labs/human-eval.git", + "pip install -e human-eval", + ) +) + +MINUTES = 60 + +@app.function(volumes={"/humaneval": volume}, timeout=10 * MINUTES) +def eval_single_task(sample_file_path: str, problem_file_path: str): + with modal.Volume.ephemeral() as vol: + with vol.batch_upload() as batch: + batch.put_file(sample_file_path, "samples.jsonl") + batch.put_file(problem_file_path, "problems.jsonl") + + print(f"Starting sandbox for {sample_file_path}") + sandbox = modal.Sandbox.create( + "bash", + "-c", + "evaluate_functional_correctness vol/samples.jsonl --problem_file=vol/problems.jsonl --n_workers=32", + image=sandbox_image, + volumes={"/vol": vol}, + timeout=10 * MINUTES, + cpu=32, + ) + + try: + sandbox.wait() + print(f"Finished sandbox for {sample_file_path}") + except FunctionTimeoutError: + print("Sandbox timed out") + + if sandbox.returncode == 0: + print(sandbox.stdout.read()) + data = b"" + for chunk in vol.read_file("samples.jsonl_results.jsonl"): + data += chunk + with open(f"{sample_file_path}_results.jsonl", "wb") as f: + f.write(data) + else: + print(f"Tests failed with code {sandbox.returncode}") + print(sandbox.stderr.read()) + + +@app.function(volumes={"/humaneval": volume}, timeout=10 * MINUTES) +def eval_all_tasks(): + import os + + volume.reload() + + # Find all files matching /humaneval/{env}/{run}/{id}.jsonl + envs = [element for element in Path("/humaneval").iterdir() if element.is_dir()] + for env in envs: + print(f"looking in {env}") + problem_file = env / "data.jsonl" + + pattern = "*/*.jsonl" + handles = [] + for file_path in env.glob(pattern): + # Skip files that end with _results.jsonl + if str(file_path).endswith("_results.jsonl"): + continue + + print(f"Checking {file_path}") + # Check if the corresponding results file exists + results_file = f"{file_path}_results.jsonl" + if not os.path.exists(results_file): + # If it doesn't exist, run do_eval + print("Spawning on", file_path, problem_file) + handles.append(eval_single_task.spawn(file_path, problem_file)) + + for handle in handles: + handle.get() + + +@app.local_entrypoint() +def main(): + eval_all_tasks.remote() diff --git a/recipes/3p_integrations/modal/many-llamas-human-eval/generate.py b/recipes/3p_integrations/modal/many-llamas-human-eval/generate.py new file mode 100644 index 000000000..4ea6cd9ce --- /dev/null +++ b/recipes/3p_integrations/modal/many-llamas-human-eval/generate.py @@ -0,0 +1,248 @@ +# ## Generating HumanEval Results with our Llama 3.2 3B Instruct Model +# This app starts many parallel clients to send requests to the vLLM server. +# +# For each of the tasks in the HumanEval test set, we'll run a client to request 1000 completions. +# Results are saved to our mounted volume. +# +# Run it with: +# modal run generate --data-dir test --no-dry-run --n 1000 --subsample 100 + +from datetime import datetime +import json +from pathlib import Path +from dataclasses import dataclass, asdict + +import modal + +# This defines the image to use for running openai clients in parallel +image = modal.Image.debian_slim(python_version="3.11").pip_install( + "openai==1.38.0", "datasets==2.20.0" +) + +app = modal.App("many-llamas-human-eval", image=image) + +volume = modal.Volume.from_name("humaneval", create_if_missing=True) +DATA_DIR = Path("/mnt/humaneval") + +default_system_prompt = "Write the body for the Python function provided in the prompt below. Do not write anything else. Your output will be directly concatenated with the prompt and the resulting function executed against tests." + +MINUTES = 60 # seconds +HOURS = 60 * MINUTES + +@dataclass +class CompletionParams: + model: str = None + max_tokens: int = 1024 + temperature: float = 0.7 + top_p: float = 0.9 + frequency_penalty: float = 0 + presence_penalty: float = 0 + n: int = 1 + stop: str = None + seed: int = None + +@dataclass +class ClientParams: + app_name: str = "many-llamas-human-eval" + workspace: str = None + api_key: str = "super-secret-token" # match the secret in inference.py + + @property + def url(self): + return f"https://{self.workspace}--{self.app_name}-serve.modal.run/v1" + + +@app.local_entrypoint() +def main( + app_name: str = "many-llamas-human-eval", + workspace: str = None, + api_key: str = "super-secret-token", + model: str = None, + max_tokens: int = 1024, + temperature: float = 0.7, + top_p: float = 0.9, + frequency_penalty: float = 0, + presence_penalty: float = 0, + n: int = 1, + stop: str = None, + seed: int = None, + data_dir: str = "dev-llm", + subsample: int = 1, # percent of the test split to read + system_prompt: str = default_system_prompt, + dry_run: bool = True, +): + if workspace is None: + workspace = modal.config._profile + + client_params = ClientParams(app_name, workspace, api_key) + + completion_params = CompletionParams( + model=model, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + n=n, + stop=stop, + seed=seed, + ) + + # Run a remote download function to save the HumanEval dataset in the cloud volume + save_dataset.remote(path=data_dir, subsample=subsample) + + # Run a remote generation function + results = run_human_eval.remote( + client_params=client_params, + completion_params=completion_params, + system_prompt=system_prompt, + data_dir=data_dir, + dry_run=dry_run, + ) + if results: + with open("/tmp/results.jsonl", "w") as f: + f.writelines(json.dumps(result) + "\n" for result in results) + print(f"results saved locally to {f.name}") + +# This is the parent function that spawns a client for each eval task +@app.function(volumes={DATA_DIR: volume}, timeout=1 * HOURS) +def run_human_eval( + client_params: ClientParams, + completion_params: CompletionParams, + data_dir="dev-llm", + system_prompt: str = default_system_prompt, + dry_run=True, +): + dataset = load_dataset(data_dir) + + timestamp = datetime.utcnow().isoformat() + "Z" + output_dir = Path(DATA_DIR) / data_dir / f"run-{timestamp}" + output_dir.mkdir(parents=True, exist_ok=True) + handles = [] + print(f"Eval set contains {len(dataset)} items") + + # For each eval item in the dataset, spawn a parallel openAI client worker that generates n completions each + print(Colors.BOLD, f"Spawning clients for each eval item. You may notice a brief wait while the inference server(s) boot.", Colors.END, sep="") + for i, item in enumerate(dataset): + handles.append( + run_item.spawn( + item, + client_params, + completion_params, + system_prompt, + output_dir, + dry_run, + ) + ) + + for handle in handles: + result = handle.get() + + if not dry_run: + return result + +# This function is responsible for generating n completions for a single eval item +# It calls into our deployed vLLM server and saves results to the cloud volume +@app.function(volumes={DATA_DIR: volume}, timeout=1 * HOURS) +def run_item( + item: dict, + client_params: ClientParams, + completion_params: CompletionParams, + system_prompt: str, + output_dir: Path, + dry_run: bool, +): + client = create_client(client_params) + if not completion_params.model: + model = client.models.list().data[0] + model = model.id + completion_params.model = model + + prompt = item["prompt"] + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ] + + per_request = 250 + ct, completions = completion_params.n, [] + if not dry_run: + while ct > 0: + response = get_completion( + client, + messages=messages, + **asdict(completion_params) | dict(n=min(ct, per_request)), + ) + if response: + completions += [ + { + "task_id": item["task_id"], + "completion": choice.message.content, + } + for choice in response.choices + ] + ct -= per_request + + index = item["task_id"].split("/")[-1] + output_path = output_dir / f"{index}.jsonl" + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + f.writelines(json.dumps(completion) + "\n" for completion in completions) + + print(Colors.GREEN + f"Completions saved to {output_path}" + Colors.END) + + +class Colors: + """ANSI color codes""" + + GREEN = "\033[0;32m" + RED = "\033[0;31m" + BLUE = "\033[0;34m" + GRAY = "\033[0;90m" + BOLD = "\033[1m" + END = "\033[0m" + + +def get_completion(client, **kwargs): + try: + response = client.chat.completions.create(**kwargs) + return response + except Exception as e: + print(Colors.RED, f"Error during API call: {e}", Colors.END, sep="") + return None + + +def create_client(client_params: ClientParams): + from openai import OpenAI + + client = OpenAI(api_key=client_params.api_key) + client.base_url = client_params.url + + return client + +# This function downloads the HumanEval dataset +@app.function(volumes={DATA_DIR: volume}) +def save_dataset(path="dev-llm", subsample: int = 1): + import datasets + + path = DATA_DIR / path + + ds = datasets.load_dataset( + "openai/openai_humaneval", + # reads 0% to subsample% of the test split + split=datasets.ReadInstruction("test", to=subsample, unit="%"), + ) + + ds.to_json(path / "data.jsonl") + + volume.commit() + + +def load_dataset(path="dev-llm"): + import datasets + + path = DATA_DIR / path + + ds = datasets.load_dataset(path=str(path), data_files="data.jsonl") + + return ds["train"] diff --git a/recipes/3p_integrations/modal/many-llamas-human-eval/inference.py b/recipes/3p_integrations/modal/many-llamas-human-eval/inference.py new file mode 100644 index 000000000..45bb60420 --- /dev/null +++ b/recipes/3p_integrations/modal/many-llamas-human-eval/inference.py @@ -0,0 +1,149 @@ +# ## Serving Llama 3.2 3B Instruct Model With vLLM +# This app runs a vLLM server on an A100 GPU. +# +# Run it with: +# modal deploy inference + +import modal + +# This defines the image to use for the vLLM server container +vllm_image = modal.Image.debian_slim(python_version="3.10").pip_install( + "vllm==0.5.3post1" +) + + +MODELS_DIR = "/llamas" +MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct" + +# Ensure the model is downloaded and the volume exists +try: + volume = modal.Volume.lookup("llamas", create_if_missing=False) +except modal.exception.NotFoundError: + raise Exception("Download models first with modal run download") + +app = modal.App("many-llamas-human-eval") + +N_GPU = 1 # tip: for best results, first upgrade to more powerful GPUs, and only then increase GPU count +TOKEN = ( + "super-secret-token" # auth token. for production use, replace with a modal.Secret +) + +MINUTES = 60 # seconds +HOURS = 60 * MINUTES + +@app.function( + image=vllm_image, + gpu=modal.gpu.A100(count=N_GPU, size="40GB"), + container_idle_timeout=5 * MINUTES, + timeout=24 * HOURS, + allow_concurrent_inputs=20, # VLLM will batch requests so many can be received at once + volumes={MODELS_DIR: volume}, + concurrency_limit=10, # max 10 GPUs +) +@modal.asgi_app() +def serve(): + import fastapi + import vllm.entrypoints.openai.api_server as api_server + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.entrypoints.logger import RequestLogger + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + from vllm.entrypoints.openai.serving_completion import ( + OpenAIServingCompletion, + ) + from vllm.usage.usage_lib import UsageContext + + volume.reload() # ensure we have the latest version of the weights + + # create a fastAPI app that uses vLLM's OpenAI-compatible router + web_app = fastapi.FastAPI( + title=f"OpenAI-compatible {MODEL_NAME} server", + description="Run an OpenAI-compatible LLM server with vLLM on modal.com", + version="0.0.1", + docs_url="/docs", + ) + + # security: CORS middleware for external requests + http_bearer = fastapi.security.HTTPBearer( + scheme_name="Bearer Token", + description="See code for authentication details.", + ) + web_app.add_middleware( + fastapi.middleware.cors.CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # security: inject dependency on authed routes + async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): + if api_key.credentials != TOKEN: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + ) + return {"username": "authenticated_user"} + + router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) + + # wrap vllm's router in auth router + router.include_router(api_server.router) + # add authed vllm to our fastAPI app + web_app.include_router(router) + + engine_args = AsyncEngineArgs( + model=MODELS_DIR + "/" + MODEL_NAME, + tensor_parallel_size=N_GPU, + gpu_memory_utilization=0.90, + max_model_len=2048, + enforce_eager=False, # capture the graph for faster inference, but slower cold starts (30s > 20s) + ) + + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER + ) + + model_config = get_model_config(engine) + + request_logger = RequestLogger(max_log_len=2048) + + api_server.openai_serving_chat = OpenAIServingChat( + engine, + model_config=model_config, + served_model_names=[MODEL_NAME], + chat_template=None, + response_role="assistant", + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + ) + api_server.openai_serving_completion = OpenAIServingCompletion( + engine, + model_config=model_config, + served_model_names=[MODEL_NAME], + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + ) + + return web_app + + +def get_model_config(engine): + import asyncio + + try: # adapted from vLLM source -- https://github.com/vllm-project/vllm/blob/507ef787d85dec24490069ffceacbd6b161f4f72/vllm/entrypoints/openai/api_server.py#L235C1-L247C1 + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running(): + # If the current is instanced by Ray Serve, + # there is already a running event loop + model_config = event_loop.run_until_complete(engine.get_model_config()) + else: + # When using single vLLM without engine_use_ray + model_config = asyncio.run(engine.get_model_config()) + + return model_config diff --git a/recipes/3p_integrations/modal/many-llamas-human-eval/plot.py b/recipes/3p_integrations/modal/many-llamas-human-eval/plot.py new file mode 100644 index 000000000..db225fb13 --- /dev/null +++ b/recipes/3p_integrations/modal/many-llamas-human-eval/plot.py @@ -0,0 +1,194 @@ +# ## Plotting HumanEval Results +# This script will calculate pass@k and fail@k for our experiment and plot them. +# +# Run it with: +# modal run plot + +import io +import json +from pathlib import Path +from typing import List, Union +import itertools + +import modal + +try: + volume = modal.Volume.lookup("humaneval", create_if_missing=False) +except modal.exception.NotFoundError: + raise Exception("Generate results first with modal run generate --data-dir test --no-dry-run --n 1000 --subsample 100") + + +image = modal.Image.debian_slim(python_version="3.11").pip_install( + "numpy==1.26.4", + "pandas==2.2.3", + "matplotlib==3.9.2", + "seaborn==0.13.2", +) + +app = modal.App("many-llamas-human-eval", image=image) + +DATA_DIR = Path("/mnt/humaneval") + +with image.imports(): + import numpy as np + import pandas as pd + import matplotlib.pyplot as plt + import seaborn as sns + +@app.function(volumes={DATA_DIR: volume}) +def render_plots(): + run_dirs = list(sorted((DATA_DIR / "test").glob("run-*"))) + + for run_dir in reversed(run_dirs): + if len(list(run_dir.iterdir())) < 150: + print(f"skipping incomplete run {run_dir}") + else: + break + + all_result_paths = list(run_dir.glob("*.jsonl_results.jsonl")) + + data = [] + for path in all_result_paths: + data += [json.loads(line) for line in path.read_text(encoding='utf-8').splitlines()] + + for element in data: + del element["completion"] + + df = pd.DataFrame.from_records(data) + + gb = df.groupby("task_id") + passes = gb["passed"].sum() + + def estimate_pass_at_k( + num_samples: Union[int, List[int], np.ndarray], + num_correct: Union[List[int], np.ndarray], + k: int + ) -> np.ndarray: + """ + Estimates pass@k of each problem and returns them in an array. + """ + + def estimator(n: int, c: int, k: int) -> float: + """ + Calculates 1 - comb(n - c, k) / comb(n, k). + """ + if n - c < k: + return 1.0 + return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)) + + if isinstance(num_samples, int): + num_samples_it = itertools.repeat(num_samples, len(num_correct)) + else: + assert len(num_samples) == len(num_correct) + num_samples_it = iter(num_samples) + + return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]) + + pass_at_ks = {} + + for k in [1, 10, 100, 1000]: + pass_at_ks[k] = estimate_pass_at_k(1000, passes, k) + + pass_at_k = {k: np.mean(v) for k, v in pass_at_ks.items()} + + plot_df = pd.DataFrame( + {"k": pass_at_k.keys(), + "pass@k": pass_at_k.values()} + ) + plot_df["fail@k"] = 1 - plot_df["pass@k"] + + sns.set_theme(style='dark') + plt.style.use("dark_background") + + plt.rcParams['font.sans-serif'] = ["Inter", "Arial", "DejaVu Sans", "Liberation Sans", "Bitstream Vera Sans", "sans-serif"] + + sns.despine() + + sns.set_context("talk", rc={"lines.linewidth": 2.5}) + + gpt4o_benchmark = 0.902 + + # First plot + plt.figure(figsize=(10, 6)) + fg = sns.lineplot( + x="k", + y="pass@k", + data=plot_df, + color="#7FEE64", + linewidth=6, + alpha=0.9, + label="Llama 3.2 3B Instruct pass@k" + ) + + initial_lim = fg.axes.get_xlim() + fg.axes.hlines( + gpt4o_benchmark, *initial_lim, + linestyle="--", + alpha=0.6, + zorder=-1, + label="GPT-4o fail@1" + ) + fg.axes.set_xlim(*initial_lim) + fg.axes.set_ylabel("") + fg.axes.set_ylim(0, 1) + plt.tight_layout(pad=1.2) + plt.legend() + + # Save the first plot as bytes + img_buffer = io.BytesIO() + plt.savefig(img_buffer, format='jpeg') + plot_1_img_bytes = img_buffer.getvalue() + plt.close() + + # Second plot + plt.figure(figsize=(10, 6)) + fg = sns.lineplot( + x="k", + y="fail@k", + data=plot_df, + color="#7FEE64", + linewidth=6, + alpha=0.9, + label="Llama 3.2 3B Instruct fail@k" + ) + + initial_lim = fg.axes.get_xlim() + fg.axes.hlines( + 1 - gpt4o_benchmark, *initial_lim, + linestyle="--", + alpha=0.6, + zorder=-1, + label="GPT-4o fail@1" + ) + fg.axes.set_xlim(*initial_lim) + fg.axes.set_ylabel("") + fg.axes.set_yscale("log") + fg.axes.set_xscale("log") + fg.axes.set_xlim(0.5, 2000) + fg.axes.set_ylim(1e-2, 1e0) + plt.tight_layout(pad=1.2) + plt.legend() + + # Save the second plot as bytes + img_buffer = io.BytesIO() + plt.savefig(img_buffer, format='jpeg') + plot_2_img_bytes = img_buffer.getvalue() + plt.close() + + return [plot_1_img_bytes, plot_2_img_bytes] + +@app.local_entrypoint() +def main(): + plots = render_plots.remote() + + assert len(plots) == 2 + + with open ("/tmp/plot-pass-k.jpeg", "wb") as f: + f.write(plots[0]) + + with open ("/tmp/plot-fail-k.jpeg", "wb") as f: + f.write(plots[1]) + + print("Plots saved to:") + print(" /tmp/plot-pass-k.jpeg") + print(" /tmp/plot-fail-k.jpeg") \ No newline at end of file diff --git a/recipes/3p_integrations/modal/many-llamas-human-eval/run_e2e.sh b/recipes/3p_integrations/modal/many-llamas-human-eval/run_e2e.sh new file mode 100644 index 000000000..d544425b5 --- /dev/null +++ b/recipes/3p_integrations/modal/many-llamas-human-eval/run_e2e.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -euo pipefail +IFS=$'\n\t' + +command -v modal >/dev/null 2>&1 || { echo >&2 "modal command not found. Install modal first! Aborting."; exit 1; } + +echo 'downloading LLaMA 3.2 3B Instruct model' +echo 'make sure to create a Secret called huggingface on Modal and accept the LLaMA 3.2 license' +modal run download.py + +echo 'deploying vLLM inference server' +modal deploy inference.py + +echo 'running HumanEval generation' +modal run generate.py --data-dir test --no-dry-run --n 1000 --subsample 100 + +echo 'running HumanEval evaluation' +modal run eval.py + +echo 'generating graphs for pass@k and fail@k' +modal run plot.py \ No newline at end of file