Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a verify-download command to llama CLI #457

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llama_stack/cli/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .download import Download
from .model import ModelParser
from .stack import StackParser
from .verify_download import VerifyDownload


class LlamaCLIParser:
Expand All @@ -27,9 +28,10 @@ def __init__(self):
subparsers = self.parser.add_subparsers(title="subcommands")

# Add sub-commands
Download.create(subparsers)
ModelParser.create(subparsers)
StackParser.create(subparsers)
Download.create(subparsers)
VerifyDownload.create(subparsers)

def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()
Expand Down
2 changes: 2 additions & 0 deletions llama_stack/cli/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.verify_download import ModelVerifyDownload

from llama_stack.cli.subcommand import Subcommand

Expand All @@ -32,3 +33,4 @@ def __init__(self, subparsers: argparse._SubParsersAction):
ModelList.create(subparsers)
ModelPromptFormat.create(subparsers)
ModelDescribe.create(subparsers)
ModelVerifyDownload.create(subparsers)
24 changes: 24 additions & 0 deletions llama_stack/cli/model/verify_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse

from llama_stack.cli.subcommand import Subcommand


class ModelVerifyDownload(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama model verify-download",
description="Verify the downloaded checkpoints' checksums",
formatter_class=argparse.RawTextHelpFormatter,
)

from llama_stack.cli.verify_download import setup_verify_download_parser

setup_verify_download_parser(self.parser)
144 changes: 144 additions & 0 deletions llama_stack/cli/verify_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import hashlib
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional

from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn

from llama_stack.cli.subcommand import Subcommand


@dataclass
class VerificationResult:
filename: str
expected_hash: str
actual_hash: Optional[str]
exists: bool
matches: bool


class VerifyDownload(Subcommand):
"""Llama cli for verifying downloaded model files"""

def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama verify-download",
description="Verify integrity of downloaded model files",
formatter_class=argparse.RawTextHelpFormatter,
)
setup_verify_download_parser(self.parser)


def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--model-id",
required=True,
help="Model ID to verify",
)
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))


def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
md5_hash = hashlib.md5()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
md5_hash.update(chunk)
return md5_hash.hexdigest()


def load_checksums(checklist_path: Path) -> Dict[str, str]:
checksums = {}
with open(checklist_path, "r") as f:
for line in f:
if line.strip():
md5sum, filepath = line.strip().split(" ", 1)
# Remove leading './' if present
filepath = filepath.lstrip("./")
checksums[filepath] = md5sum
return checksums


def verify_files(
model_dir: Path, checksums: Dict[str, str], console: Console
) -> List[VerificationResult]:
results = []

with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
for filepath, expected_hash in checksums.items():
full_path = model_dir / filepath
task_id = progress.add_task(f"Verifying {filepath}...", total=None)

exists = full_path.exists()
actual_hash = None
matches = False

if exists:
actual_hash = calculate_md5(full_path)
matches = actual_hash == expected_hash

results.append(
VerificationResult(
filename=filepath,
expected_hash=expected_hash,
actual_hash=actual_hash,
exists=exists,
matches=matches,
)
)

progress.remove_task(task_id)

return results


def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_stack.distribution.utils.model_utils import model_local_dir

console = Console()
model_dir = Path(model_local_dir(args.model_id))
checklist_path = model_dir / "checklist.chk"

if not model_dir.exists():
parser.error(f"Model directory not found: {model_dir}")

if not checklist_path.exists():
parser.error(f"Checklist file not found: {checklist_path}")

checksums = load_checksums(checklist_path)
results = verify_files(model_dir, checksums, console)

# Print results
console.print("\nVerification Results:")

all_good = True
for result in results:
if not result.exists:
console.print(f"[red]❌ {result.filename}: File not found[/red]")
all_good = False
elif not result.matches:
console.print(
f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
f" Expected: {result.expected_hash}\n"
f" Got: {result.actual_hash}"
)
all_good = False
else:
console.print(f"[green]✓ {result.filename}: Verified[/green]")

if all_good:
console.print("\n[green]All files verified successfully![/green]")