diff --git a/llama_stack/cli/llama.py b/llama_stack/cli/llama.py index 8ca82db81..f0466facd 100644 --- a/llama_stack/cli/llama.py +++ b/llama_stack/cli/llama.py @@ -9,6 +9,7 @@ from .download import Download from .model import ModelParser from .stack import StackParser +from .verify_download import VerifyDownload class LlamaCLIParser: @@ -27,9 +28,10 @@ def __init__(self): subparsers = self.parser.add_subparsers(title="subcommands") # Add sub-commands - Download.create(subparsers) ModelParser.create(subparsers) StackParser.create(subparsers) + Download.create(subparsers) + VerifyDownload.create(subparsers) def parse_args(self) -> argparse.Namespace: return self.parser.parse_args() diff --git a/llama_stack/cli/model/model.py b/llama_stack/cli/model/model.py index 3804bf43c..f59ba8376 100644 --- a/llama_stack/cli/model/model.py +++ b/llama_stack/cli/model/model.py @@ -10,6 +10,7 @@ from llama_stack.cli.model.download import ModelDownload from llama_stack.cli.model.list import ModelList from llama_stack.cli.model.prompt_format import ModelPromptFormat +from llama_stack.cli.model.verify_download import ModelVerifyDownload from llama_stack.cli.subcommand import Subcommand @@ -32,3 +33,4 @@ def __init__(self, subparsers: argparse._SubParsersAction): ModelList.create(subparsers) ModelPromptFormat.create(subparsers) ModelDescribe.create(subparsers) + ModelVerifyDownload.create(subparsers) diff --git a/llama_stack/cli/model/verify_download.py b/llama_stack/cli/model/verify_download.py new file mode 100644 index 000000000..b8e6bf173 --- /dev/null +++ b/llama_stack/cli/model/verify_download.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + +from llama_stack.cli.subcommand import Subcommand + + +class ModelVerifyDownload(Subcommand): + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "verify-download", + prog="llama model verify-download", + description="Verify the downloaded checkpoints' checksums", + formatter_class=argparse.RawTextHelpFormatter, + ) + + from llama_stack.cli.verify_download import setup_verify_download_parser + + setup_verify_download_parser(self.parser) diff --git a/llama_stack/cli/verify_download.py b/llama_stack/cli/verify_download.py new file mode 100644 index 000000000..f86bed6af --- /dev/null +++ b/llama_stack/cli/verify_download.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import hashlib +from dataclasses import dataclass +from functools import partial +from pathlib import Path +from typing import Dict, List, Optional + +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn + +from llama_stack.cli.subcommand import Subcommand + + +@dataclass +class VerificationResult: + filename: str + expected_hash: str + actual_hash: Optional[str] + exists: bool + matches: bool + + +class VerifyDownload(Subcommand): + """Llama cli for verifying downloaded model files""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "verify-download", + prog="llama verify-download", + description="Verify integrity of downloaded model files", + formatter_class=argparse.RawTextHelpFormatter, + ) + setup_verify_download_parser(self.parser) + + +def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--model-id", + required=True, + help="Model ID to verify", + ) + parser.set_defaults(func=partial(run_verify_cmd, parser=parser)) + + +def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str: + md5_hash = hashlib.md5() + with open(filepath, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + md5_hash.update(chunk) + return md5_hash.hexdigest() + + +def load_checksums(checklist_path: Path) -> Dict[str, str]: + checksums = {} + with open(checklist_path, "r") as f: + for line in f: + if line.strip(): + md5sum, filepath = line.strip().split(" ", 1) + # Remove leading './' if present + filepath = filepath.lstrip("./") + checksums[filepath] = md5sum + return checksums + + +def verify_files( + model_dir: Path, checksums: Dict[str, str], console: Console +) -> List[VerificationResult]: + results = [] + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + for filepath, expected_hash in checksums.items(): + full_path = model_dir / filepath + task_id = progress.add_task(f"Verifying {filepath}...", total=None) + + exists = full_path.exists() + actual_hash = None + matches = False + + if exists: + actual_hash = calculate_md5(full_path) + matches = actual_hash == expected_hash + + results.append( + VerificationResult( + filename=filepath, + expected_hash=expected_hash, + actual_hash=actual_hash, + exists=exists, + matches=matches, + ) + ) + + progress.remove_task(task_id) + + return results + + +def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser): + from llama_stack.distribution.utils.model_utils import model_local_dir + + console = Console() + model_dir = Path(model_local_dir(args.model_id)) + checklist_path = model_dir / "checklist.chk" + + if not model_dir.exists(): + parser.error(f"Model directory not found: {model_dir}") + + if not checklist_path.exists(): + parser.error(f"Checklist file not found: {checklist_path}") + + checksums = load_checksums(checklist_path) + results = verify_files(model_dir, checksums, console) + + # Print results + console.print("\nVerification Results:") + + all_good = True + for result in results: + if not result.exists: + console.print(f"[red]❌ {result.filename}: File not found[/red]") + all_good = False + elif not result.matches: + console.print( + f"[red]❌ {result.filename}: Hash mismatch[/red]\n" + f" Expected: {result.expected_hash}\n" + f" Got: {result.actual_hash}" + ) + all_good = False + else: + console.print(f"[green]✓ {result.filename}: Verified[/green]") + + if all_good: + console.print("\n[green]All files verified successfully![/green]")