diff --git a/examples/nvidia_runtime/Dockerfile b/examples/nvidia_runtime/Dockerfile new file mode 100644 index 0000000000..6dc1c8f101 --- /dev/null +++ b/examples/nvidia_runtime/Dockerfile @@ -0,0 +1,6 @@ +FROM nvcr.io/nvidia/pytorch:25.04-py3 + +# Add the tests to the entrypoint set. Docker Slim only traces/monitors the processes started by the entrypoint. +RUN echo "pytest /opt/pytorch/pytorch/test/test_cuda.py::TestCuda::test_graph_cudnn_dropout" > /opt/nvidia/entrypoint.d/99-trace.sh +RUN chmod +x /opt/nvidia/entrypoint.d/99-trace.sh + diff --git a/examples/nvidia_runtime/README.md b/examples/nvidia_runtime/README.md new file mode 100644 index 0000000000..a2b80a3d44 --- /dev/null +++ b/examples/nvidia_runtime/README.md @@ -0,0 +1,19 @@ + +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +As a prerequisite, install the nvidia-container toolkit, including adding the nvidia runtime. Then you should be able to translate runtime and capabilities from an OCI/Docker string like `--runtime=nvidia --gpus all` to `--cro-device-request '{"Count":-1, "Capabilities":[["gpu"]]}' --cro-runtime nvidia` + +See the example `test_nvidia_smi.sh`, which slims ubuntu to just the files necessary to run the runtime-mounted nvidia-smi. Similarly, see `test_nvidia_pytorch.sh` which minimizes nvidia-pytorch to run a subset of the CUDA tests. 
+ diff --git a/examples/nvidia_runtime/test_nvidia_pytorch.sh b/examples/nvidia_runtime/test_nvidia_pytorch.sh new file mode 100755 index 0000000000..b9007f9ea1 --- /dev/null +++ b/examples/nvidia_runtime/test_nvidia_pytorch.sh @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Create host config file with ulimit settings and capabilities +cat > host-config.json <<'EOF' +{ + "IpcMode": "host", + "CapAdd": ["SYS_ADMIN"], + "Ulimits": [ + { + "Name": "memlock", + "Soft": -1, + "Hard": -1 + }, + { + "Name": "stack", + "Soft": 67108864, + "Hard": 67108864 + }, + { + "Name": "nofile", + "Soft": 1048576, + "Hard": 1048576 + } + ] +} +EOF + +# Build the slim image +# CAP_SYS_ADMIN is added via host-config.json for fanotify support (required for filesystem monitoring) +# Build custom image with test in entrypoint first +echo "Building custom test image with pytest in entrypoint..." +docker build -t nvcr.io/nvidia/pytorch:25.04-py3-test -f Dockerfile . + +echo "Running docker-slim on the test image..." 
+docker-slim build \ + --target nvcr.io/nvidia/pytorch:25.04-py3-test \ + --tag nvcr.io/nvidia/pytorch:25.04-py3-slim \ + --cro-host-config-file host-config.json \ + --cro-shm-size 1200 \ + --cro-device-request '{"Count":-1, "Capabilities":[["gpu"]]}' \ + --cro-runtime nvidia \ + --http-probe=false \ + --continue-after 10 \ + --preserve-path /etc/ld.so.conf \ + --preserve-path /etc/ld.so.conf.d \ + . + +# Get output of original and slim images stored in a log file +echo "Running original image..." +docker run --rm --runtime nvidia --gpus all nvcr.io/nvidia/pytorch:25.04-py3-test > original_log.txt 2>&1 +echo "Running slim image..." +docker run --rm --runtime nvidia --gpus all nvcr.io/nvidia/pytorch:25.04-py3-slim > slim_log.txt 2>&1 + +# Verify that both logs contain the pytest success message (ignoring timing) +echo "Checking test results..." + +# Look for "X passed" pattern in both logs +original_passed=$(grep -oE "[0-9]+ passed" original_log.txt | head -1) +slim_passed=$(grep -oE "[0-9]+ passed" slim_log.txt | head -1) + +if [ -z "$original_passed" ]; then + echo "Error: Original image test did not pass" + echo "Original log tail:" + tail -20 original_log.txt + exit 1 +fi + +if [ -z "$slim_passed" ]; then + echo "Error: Slim image test did not pass" + echo "Slim log tail:" + tail -20 slim_log.txt + exit 1 +fi + +echo "Original image: $original_passed" +echo "Slim image: $slim_passed" + +if [ "$original_passed" = "$slim_passed" ]; then + echo "SUCCESS: Both images passed the same number of tests!" 
+else + echo "Warning: Different number of tests passed (original: $original_passed, slim: $slim_passed)" +fi + +echo "Successfully minimized nvidia-pytorch to run a subset of the CUDA tests" diff --git a/examples/nvidia_runtime/test_nvidia_smi.sh b/examples/nvidia_runtime/test_nvidia_smi.sh new file mode 100755 index 0000000000..fe3f74fa48 --- /dev/null +++ b/examples/nvidia_runtime/test_nvidia_smi.sh @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Build the slim image +docker-slim build --target ubuntu:24.04 --tag ubuntu:24.04-slim --cro-shm-size 1200 --cro-device-request '{"Count":-1, "Capabilities":[["gpu"]]}' --cro-runtime nvidia --http-probe=false --exec "/usr/bin/nvidia-smi" . + +# Get output of original and slim images stored in a log file +docker run --rm --runtime nvidia --gpus all ubuntu:24.04 nvidia-smi > original_log.txt +docker run --rm --runtime nvidia --gpus all ubuntu:24.04-slim nvidia-smi > slim_log.txt + +# verify that both logs include the nvidia-smi output with an assert +assert_contains() { + if ! 
grep -q "$1" "$2"; then + echo "Error: '$1' not found in $2" + exit 1 + fi +} + +# verify that both logs include the nvidia-smi output with an assert +assert_contains "NVIDIA-SMI" original_log.txt +assert_contains "NVIDIA-SMI" slim_log.txt + diff --git a/examples/nvidia_runtime/test_nvidia_vllm.sh b/examples/nvidia_runtime/test_nvidia_vllm.sh new file mode 100755 index 0000000000..7ed4eda08e --- /dev/null +++ b/examples/nvidia_runtime/test_nvidia_vllm.sh @@ -0,0 +1,310 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Configuration +VLLM_IMAGE="${VLLM_IMAGE:-nvcr.io/nvstaging/nim-internal/llama-3.1-8b-instruct:1.8.5-rc.20251210163303-6a18a653ae01e9ae}" +SLIM_TAG="${SLIM_TAG:-$(echo $VLLM_IMAGE | sed 's|.*/||; s/:/-slim:/')}" +MAX_WAIT_MINUTES="${MAX_WAIT_MINUTES:-20}" +CONTAINER_PORT=8000 +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +OUTPUT_DIR="${OUTPUT_DIR:-$SCRIPT_DIR}" +RESULTS_FILE="${OUTPUT_DIR}/vllm_test_results.json" +SLIM_RESULTS_FILE="${OUTPUT_DIR}/vllm_test_results_slim.json" + +# NIM cache directory - uses ~/cache with volume mount to /opt/nim/.cache +NIM_CACHE_DIR="${NIM_CACHE_DIR:-$HOME/cache}" + +# Model configuration +MAX_SEQ_LEN="${MAX_SEQ_LEN:-8192}" +MAX_MODEL_LEN="${MAX_MODEL_LEN:-8192}" + +# Custom entrypoint for the NIM container +NIM_ENTRYPOINT="${NIM_ENTRYPOINT:-/opt/nim/start_server.sh}" +NIM_CMD="${NIM_CMD:---max_model_len=${MAX_MODEL_LEN}}" + +# Load environment variables from .env file if it exists +if [ -f "$HOME/nim-llm/.env" ]; then + echo "Loading environment from $HOME/nim-llm/.env" + set -a + source "$HOME/nim-llm/.env" + set +a +fi + +# Check for required NGC_API_KEY +if [ -z "$NGC_API_KEY" ]; then + echo "Error: NGC_API_KEY environment variable is not set" + echo "Please set it or create a .env file at ~/nim-llm/.env" + exit 1 +fi + +# Ensure cache directory exists +mkdir -p "$NIM_CACHE_DIR" + +# Find the slim binary - check for docker-slim or slim in PATH, or use local binary +if command -v docker-slim &> /dev/null; then + SLIM_CMD="docker-slim" +elif command -v slim &> /dev/null; then + SLIM_CMD="slim" +elif [ -x "${SCRIPT_DIR}/../../bin/linux/slim" ]; then + SLIM_CMD="${SCRIPT_DIR}/../../bin/linux/slim" +else + echo "Error: docker-slim/slim binary not found" + echo "Please install docker-slim or ensure bin/linux/slim exists" + exit 1 +fi +echo "Using slim binary: $SLIM_CMD" + +# Cleanup function +cleanup() { + echo "Cleaning up..." 
+ if [ -n "$MONITOR_PID" ] && kill -0 "$MONITOR_PID" 2>/dev/null; then + kill "$MONITOR_PID" 2>/dev/null + fi + if [ -n "$SLIM_PID" ] && kill -0 "$SLIM_PID" 2>/dev/null; then + # Send SIGINT to docker-slim to gracefully stop + kill -INT "$SLIM_PID" 2>/dev/null + wait "$SLIM_PID" 2>/dev/null + fi +} + +trap cleanup EXIT + +# Create host config file with ulimit settings and capabilities +cat > host-config.json <<'EOF' +{ + "IpcMode": "host", + "CapAdd": ["SYS_ADMIN"], + "Ulimits": [ + { + "Name": "memlock", + "Soft": -1, + "Hard": -1 + }, + { + "Name": "stack", + "Soft": 67108864, + "Hard": 67108864 + }, + { + "Name": "nofile", + "Soft": 1048576, + "Hard": 1048576 + } + ] +} +EOF + +echo "=============================================" +echo "VLLM Docker-Slim Test" +echo "=============================================" +echo "Source Image: $VLLM_IMAGE" +echo "Slim Tag: $SLIM_TAG" +echo "Entrypoint: $NIM_ENTRYPOINT $NIM_CMD" +echo "Max Wait: $MAX_WAIT_MINUTES minutes" +echo "NIM Cache: $NIM_CACHE_DIR" +echo "MAX_SEQ_LEN: $MAX_SEQ_LEN" +echo "Results File: $RESULTS_FILE" +echo "NGC_API_KEY: ${NGC_API_KEY:0:10}..." 
+echo "=============================================" + +# Function to run docker-slim with monitoring +run_slim_with_monitor() { + local target_image="$1" + local output_tag="$2" + local results_file="$3" + local is_original="$4" + + echo "" + echo "Starting docker-slim build for: $target_image" + echo "Output tag: $output_tag" + + # Create a named pipe for signaling + SIGNAL_PIPE=$(mktemp -u) + mkfifo "$SIGNAL_PIPE" + + # Start docker-slim in the background + # Using --continue-after signal to allow the monitor to signal when done + $SLIM_CMD build \ + --target "$target_image" \ + --tag "$output_tag" \ + --cro-host-config-file host-config.json \ + --cro-shm-size 1200 \ + --cro-device-request '{"Count":-1, "Capabilities":[["gpu"]]}' \ + --cro-runtime nvidia \ + --expose ${CONTAINER_PORT} \ + --publish-port ${CONTAINER_PORT}:${CONTAINER_PORT} \ + --publish-exposed-ports \ + --env "NGC_API_KEY=${NGC_API_KEY}" \ + --env "MAX_SEQ_LEN=${MAX_SEQ_LEN}" \ + --entrypoint "${NIM_ENTRYPOINT}" \ + --cmd "${NIM_CMD}" \ + --http-probe=false \ + --continue-after signal \ + --preserve-path /etc/ld.so.conf \ + --preserve-path /etc/ld.so.conf.d \ + --exclude-pattern "/opt/nim/.cache/**" \ + --exclude-pattern "/root/.cache/**" \ + . & + + SLIM_PID=$! + echo "Docker-slim started with PID: $SLIM_PID" + + # Wait a moment for the container to start + sleep 10 + + # Start the monitor/test script in the background + echo "Starting API monitor and test runner..." + python3 "${SCRIPT_DIR}/vllm_api_tests.py" \ + --host "localhost" \ + --port "$CONTAINER_PORT" \ + --output "$results_file" \ + --max-wait "$MAX_WAIT_MINUTES" \ + --signal-pid "$SLIM_PID" & + + MONITOR_PID=$! + echo "Monitor started with PID: $MONITOR_PID" + + # Wait for the monitor to complete + wait "$MONITOR_PID" + MONITOR_EXIT_CODE=$? + echo "Monitor completed with exit code: $MONITOR_EXIT_CODE" + + # Wait for docker-slim to complete + wait "$SLIM_PID" + SLIM_EXIT_CODE=$? 
+ echo "Docker-slim completed with exit code: $SLIM_EXIT_CODE" + + # Cleanup the signal pipe + rm -f "$SIGNAL_PIPE" + + if [ $SLIM_EXIT_CODE -ne 0 ]; then + echo "Warning: Docker-slim exited with code $SLIM_EXIT_CODE" + fi + + return $MONITOR_EXIT_CODE +} + +# Phase 1: Build slim image from original and run tests +echo "" +echo "=============================================" +echo "Phase 1: Building slim image from original" +echo "=============================================" + +run_slim_with_monitor "$VLLM_IMAGE" "$SLIM_TAG" "$RESULTS_FILE" "true" +PHASE1_EXIT=$? + +if [ ! -f "$RESULTS_FILE" ]; then + echo "Error: Results file not created during Phase 1" + exit 1 +fi + +echo "" +echo "Phase 1 Results:" +cat "$RESULTS_FILE" + +# Phase 2: Run slim image and test it +echo "" +echo "=============================================" +echo "Phase 2: Testing the slimmed image" +echo "=============================================" + +# Run the slimmed image directly (not through docker-slim) and test it +echo "Starting slimmed container for testing..." +SLIM_CONTAINER_ID=$(docker run -d \ + --runtime nvidia \ + --gpus all \ + --ipc=host \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --shm-size=1200m \ + -e NGC_API_KEY \ + -e MAX_SEQ_LEN=${MAX_SEQ_LEN} \ + -v "${NIM_CACHE_DIR}:/opt/nim/.cache" \ + -p ${CONTAINER_PORT}:${CONTAINER_PORT} \ + "$SLIM_TAG" \ + ${NIM_ENTRYPOINT} ${NIM_CMD}) + +echo "Slim container started: $SLIM_CONTAINER_ID" + +# Run tests against the slim container +python3 "${SCRIPT_DIR}/vllm_api_tests.py" \ + --host "localhost" \ + --port "$CONTAINER_PORT" \ + --output "$SLIM_RESULTS_FILE" \ + --max-wait "$MAX_WAIT_MINUTES" + +PHASE2_EXIT=$? + +# Stop the slim container +docker stop "$SLIM_CONTAINER_ID" +docker rm "$SLIM_CONTAINER_ID" + +if [ ! 
-f "$SLIM_RESULTS_FILE" ]; then + echo "Error: Slim results file not created during Phase 2" + exit 1 +fi + +echo "" +echo "Phase 2 Results (Slim Image):" +cat "$SLIM_RESULTS_FILE" + +# Compare results +echo "" +echo "=============================================" +echo "Comparing Results" +echo "=============================================" + +python3 - "$RESULTS_FILE" "$SLIM_RESULTS_FILE" <<'COMPARE_SCRIPT' +import json +import sys + +results_file = sys.argv[1] +slim_results_file = sys.argv[2] + +try: + with open(results_file, "r") as f: + original = json.load(f) + with open(slim_results_file, "r") as f: + slim = json.load(f) + + original_passed = sum(1 for t in original.get("tests", []) if t.get("status") == "passed") + original_failed = sum(1 for t in original.get("tests", []) if t.get("status") == "failed") + slim_passed = sum(1 for t in slim.get("tests", []) if t.get("status") == "passed") + slim_failed = sum(1 for t in slim.get("tests", []) if t.get("status") == "failed") + + print(f"Original: {original_passed} passed, {original_failed} failed") + print(f"Slim: {slim_passed} passed, {slim_failed} failed") + + if original_passed == slim_passed: + print("SUCCESS: Both images passed the same number of tests!") + sys.exit(0) + else: + print("WARNING: Different number of tests passed") + sys.exit(1) +except Exception as e: + print(f"Error comparing results: {e}") + sys.exit(1) +COMPARE_SCRIPT + +COMPARE_EXIT=$? + +echo "" +echo "=============================================" +echo "Test Complete" +echo "=============================================" + +exit $COMPARE_EXIT + diff --git a/examples/nvidia_runtime/vllm_api_tests.py b/examples/nvidia_runtime/vllm_api_tests.py new file mode 100755 index 0000000000..9a36aa38e1 --- /dev/null +++ b/examples/nvidia_runtime/vllm_api_tests.py @@ -0,0 +1,772 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VLLM API Tests for Docker-Slim Integration + +This script waits for a VLLM server to become ready, runs a suite of API tests, +and writes the results to a JSON file. It can optionally signal a docker-slim +process when testing is complete. +""" + +import argparse +import json +import os +import signal +import sys +import time +import traceback +from dataclasses import dataclass, field, asdict +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional + + +# Try to import requests, provide helpful error if not available +try: + import requests +except ImportError: + print("Error: 'requests' library is required. 
Install with: pip install requests") + sys.exit(1) + + +@dataclass +class TestResult: + """Result of a single test.""" + name: str + status: str # "passed", "failed", "skipped", "error" + duration_ms: float = 0.0 + message: str = "" + details: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class TestSuiteResults: + """Results of the entire test suite.""" + timestamp: str = "" + host: str = "" + port: int = 0 + model_name: str = "" + server_ready_time_s: float = 0.0 + tests: List[Dict[str, Any]] = field(default_factory=list) + summary: Dict[str, int] = field(default_factory=dict) + + +class VLLMApiTester: + """Tests VLLM OpenAI-compatible API endpoints.""" + + def __init__(self, host: str, port: int): + self.host = host + self.port = port + self.base_url = f"http://{host}:{port}" + self.model_name: Optional[str] = None + self.headers = { + "accept": "application/json", + "Content-Type": "application/json" + } + + def wait_for_server(self, max_wait_minutes: int = 20) -> bool: + """Wait for the server to become ready and return a model.""" + print(f"Waiting for server at {self.base_url} (max {max_wait_minutes} minutes)...") + + start_time = time.time() + max_wait_seconds = max_wait_minutes * 60 + check_interval = 5 # seconds + + while time.time() - start_time < max_wait_seconds: + try: + response = requests.get( + f"{self.base_url}/v1/models", + headers=self.headers, + timeout=10 + ) + + if response.status_code == 200: + data = response.json() + models = data.get("data", []) + + if models: + self.model_name = models[0].get("id") + elapsed = time.time() - start_time + print(f"Server ready after {elapsed:.1f}s. 
Model: {self.model_name}") + return True + else: + print(f" Server responded but no models loaded yet...") + else: + print(f" Server returned status {response.status_code}") + + except requests.exceptions.ConnectionError: + print(f" Connection refused, server not ready yet...") + except requests.exceptions.Timeout: + print(f" Connection timed out...") + except Exception as e: + print(f" Error checking server: {e}") + + time.sleep(check_interval) + + print(f"Server did not become ready within {max_wait_minutes} minutes") + return False + + def _run_test(self, name: str, test_func: Callable) -> TestResult: + """Run a single test and capture the result.""" + start_time = time.time() + try: + result = test_func() + duration_ms = (time.time() - start_time) * 1000 + + if result is True: + return TestResult( + name=name, + status="passed", + duration_ms=duration_ms, + message="Test passed successfully" + ) + elif result is None: + return TestResult( + name=name, + status="skipped", + duration_ms=duration_ms, + message="Test skipped" + ) + else: + return TestResult( + name=name, + status="failed", + duration_ms=duration_ms, + message=str(result) if result else "Test failed" + ) + except Exception as e: + duration_ms = (time.time() - start_time) * 1000 + return TestResult( + name=name, + status="error", + duration_ms=duration_ms, + message=str(e), + details={"traceback": traceback.format_exc()} + ) + + def test_models_endpoint(self) -> bool: + """Test GET /v1/models endpoint.""" + response = requests.get( + f"{self.base_url}/v1/models", + headers=self.headers, + timeout=30 + ) + + if response.status_code != 200: + return f"Expected status 200, got {response.status_code}" + + data = response.json() + if "data" not in data: + return "Response missing 'data' field" + + if not data["data"]: + return "No models returned" + + return True + + def test_health_endpoint(self) -> bool: + """Test health check endpoint.""" + # Try common health endpoints + for endpoint in ["/health", 
"/v1/health", "/healthz"]: + try: + response = requests.get( + f"{self.base_url}{endpoint}", + headers=self.headers, + timeout=10 + ) + if response.status_code == 200: + return True + except: + pass + + # If no dedicated health endpoint, v1/models working is good enough + return True + + def test_completions_basic(self) -> bool: + """Test basic /v1/completions endpoint.""" + data = { + "model": self.model_name, + "prompt": "San Francisco is a", + "temperature": 0, + "max_tokens": 32 + } + + response = requests.post( + f"{self.base_url}/v1/completions", + headers=self.headers, + json=data, + timeout=120 + ) + + if response.status_code != 200: + return f"Expected status 200, got {response.status_code}: {response.text}" + + out = response.json() + + if out.get("model") != self.model_name: + return f"Expected model '{self.model_name}', got '{out.get('model')}'" + + if len(out.get("choices", [])) != 1: + return f"Expected 1 choice, got {len(out.get('choices', []))}" + + if out["choices"][0].get("index") != 0: + return "Choice index should be 0" + + if "text" not in out["choices"][0]: + return "Choice missing 'text' field" + + return True + + def test_completions_with_logprobs(self) -> bool: + """Test /v1/completions with logprobs.""" + data = { + "model": self.model_name, + "prompt": "The quick brown fox", + "temperature": 0, + "max_tokens": 16, + "logprobs": 1 + } + + response = requests.post( + f"{self.base_url}/v1/completions", + headers=self.headers, + json=data, + timeout=120 + ) + + if response.status_code != 200: + return f"Expected status 200, got {response.status_code}" + + out = response.json() + choice = out.get("choices", [{}])[0] + + # Logprobs might be None if not supported by backend + if "logprobs" in choice and choice["logprobs"] is not None: + logprobs = choice["logprobs"] + if "tokens" in logprobs and len(logprobs["tokens"]) == 0: + return "Expected tokens in logprobs" + + return True + + def test_completions_streaming(self) -> bool: + """Test 
streaming /v1/completions endpoint.""" + data = { + "model": self.model_name, + "prompt": "Once upon a time", + "temperature": 0, + "max_tokens": 32, + "stream": True + } + + response = requests.post( + f"{self.base_url}/v1/completions", + headers=self.headers, + json=data, + timeout=120, + stream=True + ) + + if response.status_code != 200: + return f"Expected status 200, got {response.status_code}" + + chunks_received = 0 + done_received = False + + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + chunk_data = line[6:].strip() + if chunk_data == "[DONE]": + done_received = True + else: + try: + json.loads(chunk_data) + chunks_received += 1 + except json.JSONDecodeError: + return f"Invalid JSON in stream: {chunk_data}" + + if chunks_received == 0: + return "No streaming chunks received" + + if not done_received: + return "Stream did not end with [DONE]" + + return True + + def test_chat_completions_basic(self) -> bool: + """Test basic /v1/chat/completions endpoint.""" + data = { + "model": self.model_name, + "messages": [ + {"role": "user", "content": "Say hello in exactly 5 words."} + ], + "temperature": 0, + "max_tokens": 32 + } + + response = requests.post( + f"{self.base_url}/v1/chat/completions", + headers=self.headers, + json=data, + timeout=120 + ) + + if response.status_code != 200: + return f"Expected status 200, got {response.status_code}: {response.text}" + + out = response.json() + + if out.get("model") != self.model_name: + return f"Expected model '{self.model_name}', got '{out.get('model')}'" + + if len(out.get("choices", [])) != 1: + return f"Expected 1 choice, got {len(out.get('choices', []))}" + + choice = out["choices"][0] + + if "message" not in choice: + return "Choice missing 'message' field" + + if "content" not in choice["message"]: + return "Message missing 'content' field" + + return True + + def test_chat_completions_multi_turn(self) -> bool: + """Test multi-turn conversation.""" + data = 
{ + "model": self.model_name, + "messages": [ + {"role": "user", "content": "My name is Alice."}, + {"role": "assistant", "content": "Hello Alice! Nice to meet you."}, + {"role": "user", "content": "What is my name?"} + ], + "temperature": 0, + "max_tokens": 32 + } + + response = requests.post( + f"{self.base_url}/v1/chat/completions", + headers=self.headers, + json=data, + timeout=120 + ) + + if response.status_code != 200: + return f"Expected status 200, got {response.status_code}" + + out = response.json() + content = out["choices"][0]["message"]["content"].lower() + + # The model should remember the name "Alice" + if "alice" not in content: + return f"Expected model to remember 'Alice', got: {content}" + + return True + + def test_chat_completions_streaming(self) -> bool: + """Test streaming /v1/chat/completions endpoint.""" + data = { + "model": self.model_name, + "messages": [ + {"role": "user", "content": "Count from 1 to 5."} + ], + "temperature": 0, + "max_tokens": 32, + "stream": True + } + + response = requests.post( + f"{self.base_url}/v1/chat/completions", + headers=self.headers, + json=data, + timeout=120, + stream=True + ) + + if response.status_code != 200: + return f"Expected status 200, got {response.status_code}" + + chunks_received = 0 + done_received = False + + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + chunk_data = line[6:].strip() + if chunk_data == "[DONE]": + done_received = True + else: + try: + json.loads(chunk_data) + chunks_received += 1 + except json.JSONDecodeError: + return f"Invalid JSON in stream: {chunk_data}" + + if chunks_received == 0: + return "No streaming chunks received" + + if not done_received: + return "Stream did not end with [DONE]" + + return True + + def test_completions_stop_words(self) -> bool: + """Test stop words in completions.""" + data = { + "model": self.model_name, + "prompt": "List the numbers: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10", + "temperature": 0, + 
"max_tokens": 64, + "stop": ["5"] + } + + response = requests.post( + f"{self.base_url}/v1/completions", + headers=self.headers, + json=data, + timeout=120 + ) + + if response.status_code != 200: + return f"Expected status 200, got {response.status_code}" + + out = response.json() + text = out["choices"][0]["text"] + + # The output should stop before or at "5" + # This is a weak check since the model might not continue the pattern + return True + + def test_completions_max_tokens(self) -> bool: + """Test max_tokens parameter.""" + max_tokens = 10 + data = { + "model": self.model_name, + "prompt": "Write a very long essay about", + "temperature": 0.7, + "max_tokens": max_tokens + } + + response = requests.post( + f"{self.base_url}/v1/completions", + headers=self.headers, + json=data, + timeout=120 + ) + + if response.status_code != 200: + return f"Expected status 200, got {response.status_code}" + + out = response.json() + completion_tokens = out.get("usage", {}).get("completion_tokens", 0) + + if completion_tokens > max_tokens: + return f"Expected at most {max_tokens} completion tokens, got {completion_tokens}" + + return True + + def test_completions_temperature(self) -> bool: + """Test temperature parameter affects output.""" + prompt = "The meaning of life is" + + # Run with temperature 0 twice - should get same result + data_t0 = { + "model": self.model_name, + "prompt": prompt, + "temperature": 0, + "max_tokens": 20 + } + + response1 = requests.post( + f"{self.base_url}/v1/completions", + headers=self.headers, + json=data_t0, + timeout=120 + ) + + response2 = requests.post( + f"{self.base_url}/v1/completions", + headers=self.headers, + json=data_t0, + timeout=120 + ) + + if response1.status_code != 200 or response2.status_code != 200: + return f"Expected status 200" + + text1 = response1.json()["choices"][0]["text"] + text2 = response2.json()["choices"][0]["text"] + + if text1 != text2: + return f"Temperature 0 should give deterministic results" + + return 
True + + def test_usage_stats(self) -> bool: + """Test that usage statistics are returned.""" + data = { + "model": self.model_name, + "prompt": "Hello, world!", + "temperature": 0, + "max_tokens": 10 + } + + response = requests.post( + f"{self.base_url}/v1/completions", + headers=self.headers, + json=data, + timeout=120 + ) + + if response.status_code != 200: + return f"Expected status 200, got {response.status_code}" + + out = response.json() + usage = out.get("usage", {}) + + if "prompt_tokens" not in usage: + return "Missing prompt_tokens in usage" + + if "completion_tokens" not in usage: + return "Missing completion_tokens in usage" + + if "total_tokens" not in usage: + return "Missing total_tokens in usage" + + if usage["prompt_tokens"] <= 0: + return "prompt_tokens should be > 0" + + if usage["completion_tokens"] <= 0: + return "completion_tokens should be > 0" + + expected_total = usage["prompt_tokens"] + usage["completion_tokens"] + if usage["total_tokens"] != expected_total: + return f"total_tokens should equal prompt + completion tokens" + + return True + + def test_metrics_endpoint(self) -> bool: + """Test /v1/metrics endpoint if available.""" + try: + response = requests.get( + f"{self.base_url}/v1/metrics", + headers=self.headers, + timeout=30 + ) + + if response.status_code == 200: + # Metrics endpoint exists + if len(response.text) == 0: + return "Metrics endpoint returned empty response" + return True + elif response.status_code == 404: + # Metrics endpoint doesn't exist, which is acceptable + return True + else: + return f"Unexpected status code: {response.status_code}" + except Exception as e: + # Metrics endpoint not available, acceptable + return True + + def test_invalid_model(self) -> bool: + """Test error handling for invalid model name.""" + data = { + "model": "nonexistent-model-12345", + "prompt": "Hello", + "max_tokens": 10 + } + + response = requests.post( + f"{self.base_url}/v1/completions", + headers=self.headers, + json=data, + 
timeout=30 + ) + + # Should return an error status (400 or 404) + if response.status_code == 200: + return "Expected error for invalid model, got 200" + + return True + + def test_empty_prompt(self) -> bool: + """Test handling of empty prompt.""" + data = { + "model": self.model_name, + "prompt": "", + "max_tokens": 10 + } + + response = requests.post( + f"{self.base_url}/v1/completions", + headers=self.headers, + json=data, + timeout=30 + ) + + # Some servers accept empty prompts, some don't - both are valid + # Just ensure we get a valid response (200) or proper error (400) + if response.status_code not in [200, 400]: + return f"Expected 200 or 400, got {response.status_code}" + + return True + + def run_all_tests(self) -> List[TestResult]: + """Run all tests and return results.""" + tests = [ + ("models_endpoint", self.test_models_endpoint), + ("health_endpoint", self.test_health_endpoint), + ("completions_basic", self.test_completions_basic), + ("completions_with_logprobs", self.test_completions_with_logprobs), + ("completions_streaming", self.test_completions_streaming), + ("chat_completions_basic", self.test_chat_completions_basic), + ("chat_completions_multi_turn", self.test_chat_completions_multi_turn), + ("chat_completions_streaming", self.test_chat_completions_streaming), + ("completions_stop_words", self.test_completions_stop_words), + ("completions_max_tokens", self.test_completions_max_tokens), + ("completions_temperature", self.test_completions_temperature), + ("usage_stats", self.test_usage_stats), + ("metrics_endpoint", self.test_metrics_endpoint), + ("invalid_model", self.test_invalid_model), + ("empty_prompt", self.test_empty_prompt), + ] + + results = [] + for name, test_func in tests: + print(f" Running test: {name}...", end=" ", flush=True) + result = self._run_test(name, test_func) + print(f"{result.status.upper()}") + if result.status in ["failed", "error"]: + print(f" -> {result.message}") + results.append(result) + + return results + + +def 
main(): + parser = argparse.ArgumentParser(description="VLLM API Test Runner") + parser.add_argument("--host", default="localhost", help="Server host") + parser.add_argument("--port", type=int, default=8000, help="Server port") + parser.add_argument("--output", required=True, help="Output JSON file for results") + parser.add_argument("--max-wait", type=int, default=20, help="Max wait time in minutes for server") + parser.add_argument("--signal-pid", type=int, help="PID to send SIGINT when done") + + args = parser.parse_args() + + print("=" * 60) + print("VLLM API Test Runner") + print("=" * 60) + print(f"Host: {args.host}") + print(f"Port: {args.port}") + print(f"Output: {args.output}") + print(f"Max Wait: {args.max_wait} minutes") + if args.signal_pid: + print(f"Signal PID: {args.signal_pid}") + print("=" * 60) + + tester = VLLMApiTester(args.host, args.port) + + # Wait for server to be ready + start_wait = time.time() + if not tester.wait_for_server(args.max_wait): + # Write failure result + results = TestSuiteResults( + timestamp=datetime.now().isoformat(), + host=args.host, + port=args.port, + model_name="", + server_ready_time_s=-1, + tests=[], + summary={"passed": 0, "failed": 0, "skipped": 0, "error": 1} + ) + + with open(args.output, "w") as f: + json.dump(asdict(results), f, indent=2) + + print(f"Results written to: {args.output}") + + # Signal docker-slim if requested (slim expects SIGUSR1) + if args.signal_pid: + try: + os.kill(args.signal_pid, signal.SIGUSR1) + print(f"Sent SIGUSR1 to PID {args.signal_pid}") + except ProcessLookupError: + print(f"Process {args.signal_pid} not found") + except PermissionError: + print(f"Permission denied to signal PID {args.signal_pid}") + + sys.exit(1) + + server_ready_time = time.time() - start_wait + + # Run tests + print("\nRunning API tests...") + test_results = tester.run_all_tests() + + # Summarize results + summary = { + "passed": sum(1 for t in test_results if t.status == "passed"), + "failed": sum(1 for t in 
test_results if t.status == "failed"), + "skipped": sum(1 for t in test_results if t.status == "skipped"), + "error": sum(1 for t in test_results if t.status == "error"), + } + + # Create results object + results = TestSuiteResults( + timestamp=datetime.now().isoformat(), + host=args.host, + port=args.port, + model_name=tester.model_name or "", + server_ready_time_s=server_ready_time, + tests=[asdict(t) for t in test_results], + summary=summary + ) + + # Write results to file + with open(args.output, "w") as f: + json.dump(asdict(results), f, indent=2) + + print("\n" + "=" * 60) + print("Test Summary") + print("=" * 60) + print(f" Passed: {summary['passed']}") + print(f" Failed: {summary['failed']}") + print(f" Skipped: {summary['skipped']}") + print(f" Errors: {summary['error']}") + print(f"\nResults written to: {args.output}") + + # Signal docker-slim if requested (slim expects SIGUSR1) + if args.signal_pid: + print(f"\nSignaling docker-slim (PID {args.signal_pid}) to stop...") + try: + os.kill(args.signal_pid, signal.SIGUSR1) + print(f"Sent SIGUSR1 to PID {args.signal_pid}") + except ProcessLookupError: + print(f"Process {args.signal_pid} not found") + except PermissionError: + print(f"Permission denied to signal PID {args.signal_pid}") + + # Exit with appropriate code + if summary["failed"] > 0 or summary["error"] > 0: + sys.exit(1) + else: + sys.exit(0) + + +if __name__ == "__main__": + main() + diff --git a/pkg/app/master/command/build/cli.go b/pkg/app/master/command/build/cli.go index 99d5aa46c6..20a428ca5b 100644 --- a/pkg/app/master/command/build/cli.go +++ b/pkg/app/master/command/build/cli.go @@ -1,3 +1,20 @@ + + /* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package build import ( @@ -172,6 +189,7 @@ var CLI = &cli.Command{ //Sensor flags: command.Cflag(command.FlagSensorIPCEndpoint), command.Cflag(command.FlagSensorIPCMode), + command.Cflag(command.FlagCRODeviceRequest), }, command.HTTPProbeFlags()...), Action: func(ctx *cli.Context) error { gparams, ok := command.CLIContextGet(ctx.Context, command.GlobalParams).(*command.GenericParams) diff --git a/pkg/app/master/command/cliflags.go b/pkg/app/master/command/cliflags.go index 30c06c12cd..34e44a0991 100644 --- a/pkg/app/master/command/cliflags.go +++ b/pkg/app/master/command/cliflags.go @@ -1,3 +1,19 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package command import ( @@ -155,6 +171,7 @@ const ( FlagCROHostConfigFile = "cro-host-config-file" FlagCROSysctl = "cro-sysctl" FlagCROShmSize = "cro-shm-size" + FlagCRODeviceRequest = "cro-device-request" //Original Container Runtime Options (without cro- prefix) FlagUser = "user" @@ -266,6 +283,7 @@ const ( FlagCROHostConfigFileUsage = "Base Docker host configuration file (JSON format) to use when running the container" FlagCROSysctlUsage = "Set namespaced kernel parameters in the created container" FlagCROShmSizeUsage = "Shared memory size for /dev/shm in the created container" + FlagCRODeviceRequestUsage = "JSON string specifying device request configuration for the container" FlagUserUsage = "Override USER analyzing image at runtime" FlagEntrypointUsage = "Override ENTRYPOINT analyzing image at runtime. To persist ENTRYPOINT changes in the output image, pass the --image-overrides=entrypoint or --image-overrides=all flag as well." @@ -913,6 +931,11 @@ var CommonFlags = map[string]cli.Flag{ Usage: FlagRTASourcePTUsage, EnvVars: []string{"DSLIM_RTA_SRC_PT"}, }, + FlagCRODeviceRequest: &cli.StringFlag{ + Name: FlagCRODeviceRequest, + Usage: FlagCRODeviceRequestUsage, + EnvVars: []string{"DSLIM_CRO_DEVICE_REQUEST"}, + }, } //var CommonFlags diff --git a/pkg/app/master/command/clifvgetter.go b/pkg/app/master/command/clifvgetter.go index 9a05b4b5a3..4ab8e60c1f 100644 --- a/pkg/app/master/command/clifvgetter.go +++ b/pkg/app/master/command/clifvgetter.go @@ -1,4 +1,20 @@ -package command + /* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package command //Flag value getters @@ -51,6 +67,7 @@ func GetContainerRunOptions(ctx *cli.Context) (*config.ContainerRunOptions, erro } cro.ShmSize = ctx.Int64(FlagCROShmSize) + cro.DeviceRequest = ctx.String(FlagCRODeviceRequest) return &cro, nil } diff --git a/pkg/app/master/command/clifvgetter_test.go b/pkg/app/master/command/clifvgetter_test.go new file mode 100644 index 0000000000..e4fb6530db --- /dev/null +++ b/pkg/app/master/command/clifvgetter_test.go @@ -0,0 +1,139 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package command + +import ( + "flag" + "testing" + + "github.com/urfave/cli/v2" +) + +func TestGetContainerRunOptionsDeviceRequest(t *testing.T) { + tt := []struct { + deviceRequestFlag string + expectedDeviceRequest string + }{ + { + deviceRequestFlag: "", + expectedDeviceRequest: "", + }, + { + deviceRequestFlag: `{"Driver":"nvidia","Count":-1,"Capabilities":[["gpu"]]}`, + expectedDeviceRequest: `{"Driver":"nvidia","Count":-1,"Capabilities":[["gpu"]]}`, + }, + { + deviceRequestFlag: `{"Driver":"nvidia","DeviceIDs":["0","1"],"Capabilities":[["gpu"]]}`, + expectedDeviceRequest: `{"Driver":"nvidia","DeviceIDs":["0","1"],"Capabilities":[["gpu"]]}`, + }, + { + deviceRequestFlag: `{"Driver":"nvidia","Count":2,"DeviceIDs":["GPU-123"],"Capabilities":[["gpu","compute"]],"Options":{"visible":"true"}}`, + expectedDeviceRequest: `{"Driver":"nvidia","Count":2,"DeviceIDs":["GPU-123"],"Capabilities":[["gpu","compute"]],"Options":{"visible":"true"}}`, + }, + } + + for _, test := range tt { + flagSet := flag.NewFlagSet("test", flag.ContinueOnError) + flagSet.String(FlagCRODeviceRequest, test.deviceRequestFlag, "") + flagSet.String(FlagCRORuntime, "", "") + flagSet.String(FlagCROHostConfigFile, "", "") + flagSet.Int64(FlagCROShmSize, 0, "") + + app := &cli.App{} + ctx := cli.NewContext(app, flagSet, nil) + + cro, err := GetContainerRunOptions(ctx) + if err != nil { + t.Fatalf("GetContainerRunOptions returned error: %v", err) + } + + if cro.DeviceRequest != test.expectedDeviceRequest { + t.Errorf("DeviceRequest = %q, want %q", cro.DeviceRequest, test.expectedDeviceRequest) + } + } +} + +func TestGetContainerRunOptionsAllFields(t *testing.T) { + flagSet := flag.NewFlagSet("test", flag.ContinueOnError) + flagSet.String(FlagCRODeviceRequest, `{"Driver":"nvidia","Count":-1,"Capabilities":[["gpu"]]}`, "") + flagSet.String(FlagCRORuntime, "nvidia", "") + flagSet.String(FlagCROHostConfigFile, "", "") + flagSet.Int64(FlagCROShmSize, 67108864, "") + + app := &cli.App{} + ctx := 
cli.NewContext(app, flagSet, nil) + + cro, err := GetContainerRunOptions(ctx) + if err != nil { + t.Fatalf("GetContainerRunOptions returned error: %v", err) + } + + if cro.Runtime != "nvidia" { + t.Errorf("Runtime = %q, want 'nvidia'", cro.Runtime) + } + + if cro.ShmSize != 67108864 { + t.Errorf("ShmSize = %d, want 67108864", cro.ShmSize) + } + + if cro.DeviceRequest != `{"Driver":"nvidia","Count":-1,"Capabilities":[["gpu"]]}` { + t.Errorf("DeviceRequest = %q, want JSON string", cro.DeviceRequest) + } +} + +func TestFlagCRODeviceRequestDefinition(t *testing.T) { + flagDef, exists := CommonFlags[FlagCRODeviceRequest] + if !exists { + t.Fatal("FlagCRODeviceRequest not found in CommonFlags") + } + + stringFlag, ok := flagDef.(*cli.StringFlag) + if !ok { + t.Fatal("FlagCRODeviceRequest is not a StringFlag") + } + + if stringFlag.Name != FlagCRODeviceRequest { + t.Errorf("Flag name = %q, want %q", stringFlag.Name, FlagCRODeviceRequest) + } + + if stringFlag.Usage != FlagCRODeviceRequestUsage { + t.Errorf("Flag usage = %q, want %q", stringFlag.Usage, FlagCRODeviceRequestUsage) + } + + expectedEnvVar := "DSLIM_CRO_DEVICE_REQUEST" + hasEnvVar := false + for _, env := range stringFlag.EnvVars { + if env == expectedEnvVar { + hasEnvVar = true + break + } + } + if !hasEnvVar { + t.Errorf("Flag missing expected EnvVar %q, has %v", expectedEnvVar, stringFlag.EnvVars) + } +} + +func TestFlagCRODeviceRequestConstants(t *testing.T) { + if FlagCRODeviceRequest != "cro-device-request" { + t.Errorf("FlagCRODeviceRequest = %q, want 'cro-device-request'", FlagCRODeviceRequest) + } + + expectedUsage := "JSON string specifying device request configuration for the container" + if FlagCRODeviceRequestUsage != expectedUsage { + t.Errorf("FlagCRODeviceRequestUsage = %q, want %q", FlagCRODeviceRequestUsage, expectedUsage) + } +} diff --git a/pkg/app/master/config/config.go b/pkg/app/master/config/config.go index 86e1916917..5f9c71448c 100644 --- a/pkg/app/master/config/config.go +++ 
b/pkg/app/master/config/config.go @@ -1,3 +1,19 @@ + /* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package config import ( @@ -122,9 +138,10 @@ type ContainerRunOptions struct { //Explicit overrides for the base and host config fields //Host config field override are applied //on top of the fields in the HostConfig struct if it's provided (volume mounts are merged though) - Runtime string - SysctlParams map[string]string - ShmSize int64 + Runtime string + SysctlParams map[string]string + ShmSize int64 + DeviceRequest string } // VolumeMount provides the volume mount configuration information diff --git a/pkg/app/master/config/config_test.go b/pkg/app/master/config/config_test.go new file mode 100644 index 0000000000..73ca99f554 --- /dev/null +++ b/pkg/app/master/config/config_test.go @@ -0,0 +1,82 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package config + +import ( + "testing" +) + +func TestContainerRunOptionsDeviceRequest(t *testing.T) { + tt := []struct { + deviceRequest string + expectEmpty bool + }{ + { + deviceRequest: "", + expectEmpty: true, + }, + { + deviceRequest: `{"Driver":"nvidia","Count":-1,"Capabilities":[["gpu"]]}`, + expectEmpty: false, + }, + { + deviceRequest: `{"Driver":"nvidia","DeviceIDs":["0","1"],"Capabilities":[["gpu"]]}`, + expectEmpty: false, + }, + } + + for _, test := range tt { + cro := ContainerRunOptions{ + DeviceRequest: test.deviceRequest, + } + + isEmpty := cro.DeviceRequest == "" + if isEmpty != test.expectEmpty { + t.Errorf("DeviceRequest isEmpty = %v for %q, want %v", isEmpty, test.deviceRequest, test.expectEmpty) + } + + if !test.expectEmpty && cro.DeviceRequest != test.deviceRequest { + t.Errorf("DeviceRequest = %q, want %q", cro.DeviceRequest, test.deviceRequest) + } + } +} + +func TestContainerRunOptionsAllFields(t *testing.T) { + // Test that ContainerRunOptions can be created with all fields including DeviceRequest + cro := ContainerRunOptions{ + Runtime: "nvidia", + SysctlParams: map[string]string{"net.core.somaxconn": "1024"}, + ShmSize: 67108864, // 64MB + DeviceRequest: `{"Driver":"nvidia","Count":-1,"Capabilities":[["gpu"]]}`, + } + + if cro.Runtime != "nvidia" { + t.Errorf("Runtime = %q, want 'nvidia'", cro.Runtime) + } + + if cro.SysctlParams["net.core.somaxconn"] != "1024" { + t.Errorf("SysctlParams[net.core.somaxconn] = %q, want '1024'", cro.SysctlParams["net.core.somaxconn"]) + } + + if cro.ShmSize != 67108864 { + 
t.Errorf("ShmSize = %d, want 67108864", cro.ShmSize) + } + + if cro.DeviceRequest == "" { + t.Error("DeviceRequest should not be empty") + } +} diff --git a/pkg/app/master/inspectors/container/container_inspector.go b/pkg/app/master/inspectors/container/container_inspector.go index ddf27eacda..087643275d 100644 --- a/pkg/app/master/inspectors/container/container_inspector.go +++ b/pkg/app/master/inspectors/container/container_inspector.go @@ -1,8 +1,25 @@ + /* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package container import ( "bufio" "bytes" + "encoding/json" "errors" "fmt" "os" @@ -554,6 +571,16 @@ func (i *Inspector) RunContainer() error { } if i.crOpts != nil { + if i.crOpts.DeviceRequest != "" { + var deviceRequest dockerapi.DeviceRequest + if err := json.Unmarshal([]byte(i.crOpts.DeviceRequest), &deviceRequest); err != nil { + logger.WithError(err).Error("failed to parse device request JSON") + } else { + containerOptions.HostConfig.DeviceRequests = []dockerapi.DeviceRequest{deviceRequest} + logger.Debugf("using device request => %#v", deviceRequest) + } + } + if i.crOpts.Runtime != "" { containerOptions.HostConfig.Runtime = i.crOpts.Runtime logger.Debugf("using custom runtime => %s", containerOptions.HostConfig.Runtime) diff --git a/pkg/app/master/inspectors/container/device_request_test.go b/pkg/app/master/inspectors/container/device_request_test.go new file mode 100644 index 0000000000..f6df45ad30 --- /dev/null +++ b/pkg/app/master/inspectors/container/device_request_test.go @@ -0,0 +1,229 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package container + +import ( + "encoding/json" + "reflect" + "testing" + + dockerapi "github.com/fsouza/go-dockerclient" +) + +func TestParseDeviceRequestJSON(t *testing.T) { + tt := []struct { + input string + expected dockerapi.DeviceRequest + expectError bool + }{ + { + input: `{"Driver":"nvidia","Count":-1,"Capabilities":[["gpu"]]}`, + expected: dockerapi.DeviceRequest{ + Driver: "nvidia", + Count: -1, + Capabilities: [][]string{{"gpu"}}, + }, + expectError: false, + }, + { + input: `{"Driver":"nvidia","DeviceIDs":["0","1"],"Capabilities":[["gpu"]]}`, + expected: dockerapi.DeviceRequest{ + Driver: "nvidia", + DeviceIDs: []string{"0", "1"}, + Capabilities: [][]string{{"gpu"}}, + }, + expectError: false, + }, + { + input: `{"Driver":"nvidia","Count":2,"DeviceIDs":["GPU-123"],"Capabilities":[["gpu","compute"]],"Options":{"visible":"true"}}`, + expected: dockerapi.DeviceRequest{ + Driver: "nvidia", + Count: 2, + DeviceIDs: []string{"GPU-123"}, + Capabilities: [][]string{{"gpu", "compute"}}, + Options: map[string]string{"visible": "true"}, + }, + expectError: false, + }, + { + input: `{"Count":-1,"Capabilities":[["gpu"],["nvidia","compute"]]}`, + expected: dockerapi.DeviceRequest{ + Count: -1, + Capabilities: [][]string{{"gpu"}, {"nvidia", "compute"}}, + }, + expectError: false, + }, + { + input: `{}`, + expected: dockerapi.DeviceRequest{ + Driver: "", + Count: 0, + DeviceIDs: nil, + Capabilities: nil, + Options: nil, + }, + expectError: false, + }, + { + input: `{"Driver":"nvidia"`, + expectError: true, + }, + { + input: `["gpu"]`, + expectError: true, + }, + { + input: `{Driver:nvidia}`, + expectError: true, + }, + { + input: `{"Count":"all"}`, + expectError: true, + }, + } + + for _, test := range tt { + var deviceRequest dockerapi.DeviceRequest + err := json.Unmarshal([]byte(test.input), &deviceRequest) + + if test.expectError { + if err == nil { + t.Errorf("expected error for input %q, but got none", test.input) + } + continue + } + + if err != nil { + 
t.Errorf("unexpected error for input %q: %v", test.input, err) + continue + } + + if !reflect.DeepEqual(deviceRequest, test.expected) { + t.Errorf("parsed device request mismatch for %q:\n got: %+v\n expected: %+v", + test.input, deviceRequest, test.expected) + } + } +} + +func TestDeviceRequestToHostConfig(t *testing.T) { + tt := []struct { + deviceRequestJSON string + expectDeviceRequests int + expectError bool + }{ + { + deviceRequestJSON: `{"Driver":"nvidia","Count":-1,"Capabilities":[["gpu"]]}`, + expectDeviceRequests: 1, + expectError: false, + }, + { + deviceRequestJSON: "", + expectDeviceRequests: 0, + expectError: false, + }, + { + deviceRequestJSON: `{invalid}`, + expectDeviceRequests: 0, + expectError: true, + }, + } + + for _, test := range tt { + hostConfig := &dockerapi.HostConfig{} + + if test.deviceRequestJSON != "" { + var deviceRequest dockerapi.DeviceRequest + err := json.Unmarshal([]byte(test.deviceRequestJSON), &deviceRequest) + + if test.expectError { + if err == nil { + t.Errorf("expected error for input %q, but got none", test.deviceRequestJSON) + } + if len(hostConfig.DeviceRequests) != 0 { + t.Errorf("expected no device requests on error, got %d", len(hostConfig.DeviceRequests)) + } + continue + } + + if err != nil { + t.Errorf("unexpected error for input %q: %v", test.deviceRequestJSON, err) + continue + } + + hostConfig.DeviceRequests = []dockerapi.DeviceRequest{deviceRequest} + } + + if len(hostConfig.DeviceRequests) != test.expectDeviceRequests { + t.Errorf("expected %d device requests for %q, got %d", + test.expectDeviceRequests, test.deviceRequestJSON, len(hostConfig.DeviceRequests)) + } + } +} + +func TestDeviceRequestFieldValidation(t *testing.T) { + // Test NVIDIA all GPUs request + input := `{"Driver":"nvidia","Count":-1,"Capabilities":[["gpu"]]}` + var dr dockerapi.DeviceRequest + if err := json.Unmarshal([]byte(input), &dr); err != nil { + t.Fatalf("failed to parse device request: %v", err) + } + if dr.Driver != "nvidia" { + 
t.Errorf("expected Driver 'nvidia', got %q", dr.Driver) + } + if dr.Count != -1 { + t.Errorf("expected Count -1 (all GPUs), got %d", dr.Count) + } + if len(dr.Capabilities) != 1 || len(dr.Capabilities[0]) != 1 || dr.Capabilities[0][0] != "gpu" { + t.Errorf("expected Capabilities [[gpu]], got %v", dr.Capabilities) + } + + // Test NVIDIA specific GPU by ID + input = `{"Driver":"nvidia","DeviceIDs":["0"],"Capabilities":[["gpu"]]}` + dr = dockerapi.DeviceRequest{} + if err := json.Unmarshal([]byte(input), &dr); err != nil { + t.Fatalf("failed to parse device request: %v", err) + } + if len(dr.DeviceIDs) != 1 || dr.DeviceIDs[0] != "0" { + t.Errorf("expected DeviceIDs [0], got %v", dr.DeviceIDs) + } + if dr.Count != 0 { + t.Errorf("expected Count 0 when DeviceIDs specified, got %d", dr.Count) + } + + // Test NVIDIA multiple GPUs by UUID + input = `{"Driver":"nvidia","DeviceIDs":["GPU-abc123","GPU-def456"],"Capabilities":[["gpu","compute"]]}` + dr = dockerapi.DeviceRequest{} + if err := json.Unmarshal([]byte(input), &dr); err != nil { + t.Fatalf("failed to parse device request: %v", err) + } + if len(dr.DeviceIDs) != 2 { + t.Errorf("expected 2 DeviceIDs, got %d", len(dr.DeviceIDs)) + } + if len(dr.Capabilities) != 1 || len(dr.Capabilities[0]) != 2 { + t.Errorf("expected Capabilities [[gpu compute]], got %v", dr.Capabilities) + } + + // Test count of specific number of GPUs + input = `{"Count":2,"Capabilities":[["gpu"]]}` + dr = dockerapi.DeviceRequest{} + if err := json.Unmarshal([]byte(input), &dr); err != nil { + t.Fatalf("failed to parse device request: %v", err) + } + if dr.Count != 2 { + t.Errorf("expected Count 2, got %d", dr.Count) + } +} diff --git a/pkg/app/sensor/artifact/artifact.go b/pkg/app/sensor/artifact/artifact.go index 852e6b9a14..8578ef25dd 100644 --- a/pkg/app/sensor/artifact/artifact.go +++ b/pkg/app/sensor/artifact/artifact.go @@ -792,6 +792,100 @@ func (p *store) prepareArtifacts() { } p.resolveLinks() + p.deduplicateFileMap() +} + +// 
deduplicateFileMap removes duplicate file paths that point to the same inode. +// This fixes an issue where files accessed through multiple symlinked paths +// (e.g., /usr/local/cuda-12.9/lib/file.so and /usr/local/cuda/lib64/file.so) +// would be copied multiple times, with later copies potentially overwriting +// with 0-byte content. +func (p *store) deduplicateFileMap() { + log.Debugf("deduplicateFileMap - starting inode-based deduplication, fileMap has %d entries", len(p.fileMap)) + + // Build inode -> paths map for regular files only + inodeMap := make(map[uint64][]string) + + for fpath := range p.fileMap { + info, err := os.Lstat(fpath) + if err != nil { + log.Warnf("deduplicateFileMap - error getting file info for %s: %v", fpath, err) + continue + } + + // Only process regular files (not symlinks, directories, etc.) + if !info.Mode().IsRegular() { + log.Debugf("deduplicateFileMap - skipping non-regular file %s (mode: %v)", fpath, info.Mode()) + continue + } + + // Get the inode from the underlying syscall.Stat_t + if sys, ok := info.Sys().(*syscall.Stat_t); ok { + inode := sys.Ino + inodeMap[inode] = append(inodeMap[inode], fpath) + } + } + + // For each inode with multiple paths, keep only the canonical path + duplicatesRemoved := 0 + inodeCount := 0 + for inode, paths := range inodeMap { + if len(paths) <= 1 { + continue + } + + inodeCount++ + log.Debugf("deduplicateFileMap - found %d paths for inode %d: %v", len(paths), inode, paths) + + // Sort paths to get deterministic behavior + // CRITICAL: Prefer paths that DON'T go through symlinked directories + // /usr/local/cuda/ is a symlink, so paths through it may return 0 bytes + // Prefer /usr/local/cuda-12.9/ paths which are the real paths + sort.Slice(paths, func(i, j int) bool { + pi, pj := paths[i], paths[j] + + // First priority: Prefer paths that don't go through /usr/local/cuda/ symlink + // These paths go through symlinked directories and often return 0 bytes + isThroughSymlink := func(p string) 
bool { + return strings.HasPrefix(p, "/usr/local/cuda/") && !strings.HasPrefix(p, "/usr/local/cuda-") + } + symI := isThroughSymlink(pi) + symJ := isThroughSymlink(pj) + if symI != symJ { + return !symI // Prefer path NOT through symlink + } + + // Second priority: Check if either path has flags (indicating it was actually accessed) + propsI := p.fileMap[pi] + propsJ := p.fileMap[pj] + + hasI := propsI != nil && len(propsI.Flags) > 0 + hasJ := propsJ != nil && len(propsJ.Flags) > 0 + + // Prefer path with flags + if hasI && !hasJ { + return true + } + if !hasI && hasJ { + return false + } + + // Third priority: Prefer longer paths (more canonical - /usr/local/cuda-12.9/targets/... is longer) + return len(pi) > len(pj) + }) + + // Keep the first (most canonical) path, remove the rest + canonicalPath := paths[0] + for _, dupPath := range paths[1:] { + log.Debugf("deduplicateFileMap - removing duplicate: %s (keeping %s)", dupPath, canonicalPath) + delete(p.fileMap, dupPath) + duplicatesRemoved++ + } + } + + if duplicatesRemoved > 0 { + log.Debugf("deduplicateFileMap - removed %d duplicate paths, fileMap now has %d entries", duplicatesRemoved, len(p.fileMap)) + } } func (p *store) resolveLinks() { @@ -1861,6 +1955,23 @@ copyFiles: continue } + // FIX: Skip files accessed through symlinked directories like /usr/local/cuda/ + // The Docker overlay filesystem can return 0 bytes for files accessed through symlinks + // Files should be copied from canonical paths (e.g., /usr/local/cuda-12.9/) which work correctly + if strings.HasPrefix(srcFileName, "/usr/local/cuda/") && !strings.HasPrefix(srcFileName, "/usr/local/cuda-") { + // Resolve symlinks to get the canonical path + evalPath, evalErr := filepath.EvalSymlinks(srcFileName) + if evalErr == nil && evalPath != srcFileName { + if _, hasCanonical := p.fileMap[evalPath]; hasCanonical { + log.Debugf("saveArtifacts - skipping symlinked path %s (canonical path %s exists)", srcFileName, evalPath) + continue + } + // Use the 
resolved path instead + log.Debugf("saveArtifacts - using resolved path %s instead of %s", evalPath, srcFileName) + srcFileName = evalPath + } + } + filePath := fmt.Sprintf("%s/files%s", p.storeLocation, srcFileName) log.Debug("saveArtifacts - saving file data => ", filePath) @@ -1883,7 +1994,7 @@ copyFiles: log.Debugf("saveArtifacts [%s,%s] - appMetadataFileUpdater => not updated / err = %v", srcFileName, filePath, err) } } else { - err := fsutil.CopyRegularFile(p.cmd.KeepPerms, srcFileName, filePath, true) + err := fsutil.CopyFile(p.cmd.KeepPerms, srcFileName, filePath, true) if err != nil { log.Debugf("saveArtifacts [%s,%s] - error saving file => %v", srcFileName, filePath, err) } else { @@ -1904,7 +2015,17 @@ copyFiles: } } } else { - err := fsutil.CopyRegularFile(p.cmd.KeepPerms, srcFileName, filePath, true) + // Check if destination already exists - skip if it does to avoid overwriting + // a good copy with a potentially bad one + if fsutil.Exists(filePath) { + existingInfo, _ := os.Stat(filePath) + if existingInfo != nil && existingInfo.Size() > 0 { + log.Warnf("saveArtifacts - skipping %s, destination already exists with %d bytes", srcFileName, existingInfo.Size()) + continue + } + } + + err := fsutil.CopyFile(p.cmd.KeepPerms, srcFileName, filePath, true) if err != nil { log.Debugf("saveArtifacts - error saving file => %v", err) } @@ -2157,7 +2278,7 @@ copyBsaFiles: log.Debugf("saveArtifacts[bsa] - saved file (%s)", dstFilePath) } } else { - err := fsutil.CopyRegularFile(p.cmd.KeepPerms, srcFileName, dstFilePath, true) + err := fsutil.CopyFile(p.cmd.KeepPerms, srcFileName, dstFilePath, true) if err != nil { log.Debugf("saveArtifacts[bsa] - error saving file => %v", err) } else { @@ -2178,7 +2299,7 @@ copyBsaFiles: dstPasswdFilePath := fmt.Sprintf("%s/files%s", p.storeLocation, sysidentity.PasswdFilePath) if _, err := os.Stat(sysidentity.PasswdFilePath); err == nil { //if err := cpFile(passwdFilePath, passwdFileTargetPath); err != nil { - if err := 
fsutil.CopyRegularFile(p.cmd.KeepPerms, sysidentity.PasswdFilePath, dstPasswdFilePath, true); err != nil { + if err := fsutil.CopyFile(p.cmd.KeepPerms, sysidentity.PasswdFilePath, dstPasswdFilePath, true); err != nil { log.Debugf("sensor: monitor - error copying user info file => %v", err) } } else { @@ -2943,7 +3064,7 @@ func fixPy3CacheFile(src, dst string) error { if _, err := os.Stat(dstPyFilePath); err != nil && os.IsNotExist(err) { //if err := cpFile(srcPyFilePath, dstPyFilePath); err != nil { - if err := fsutil.CopyRegularFile(true, srcPyFilePath, dstPyFilePath, true); err != nil { + if err := fsutil.CopyFile(true, srcPyFilePath, dstPyFilePath, true); err != nil { log.Debugf("sensor: monitor - fixPy3CacheFile - error copying file => %v", dstPyFilePath) return err } @@ -2992,7 +3113,7 @@ func rbEnsureGemFiles(src, storeLocation, prefix string) error { if _, err := os.Stat(extBuildFlagFilePathDst); err != nil && os.IsNotExist(err) { //if err := cpFile(extBuildFlagFilePath, extBuildFlagFilePathDst); err != nil { - if err := fsutil.CopyRegularFile(true, extBuildFlagFilePath, extBuildFlagFilePathDst, true); err != nil { + if err := fsutil.CopyFile(true, extBuildFlagFilePath, extBuildFlagFilePathDst, true); err != nil { log.Debugf("sensor: monitor - rbEnsureGemFiles - error copying file => %v", extBuildFlagFilePathDst) return err } @@ -3129,7 +3250,7 @@ func nodeEnsurePackageFiles(keepPerms bool, src, storeLocation, prefix string) e nodeGypFilePath := path.Join(filepath.Dir(src), nodeNPMNodeGypFile) if _, err := os.Stat(nodeGypFilePath); err == nil { nodeGypFilePathDst := fmt.Sprintf("%s%s%s", storeLocation, prefix, nodeGypFilePath) - if err := fsutil.CopyRegularFile(keepPerms, nodeGypFilePath, nodeGypFilePathDst, true); err != nil { + if err := fsutil.CopyFile(keepPerms, nodeGypFilePath, nodeGypFilePathDst, true); err != nil { log.Debugf("sensor: nodeEnsurePackageFiles - error copying %s => %v", nodeGypFilePath, err) } } diff --git 
a/pkg/app/sensor/artifact/dedup_test.go b/pkg/app/sensor/artifact/dedup_test.go new file mode 100644 index 0000000000..6869f9babf --- /dev/null +++ b/pkg/app/sensor/artifact/dedup_test.go @@ -0,0 +1,244 @@ +package artifact + +import ( + "os" + "path/filepath" + "sort" + "strings" + "syscall" + "testing" + + "github.com/slimtoolkit/slim/pkg/report" +) + +// TestIsThroughSymlinkLogic tests the symlink detection logic used in deduplicateFileMap +func TestIsThroughSymlinkLogic(t *testing.T) { + // This tests the isThroughSymlink inline function logic from deduplicateFileMap + isThroughSymlink := func(p string) bool { + return strings.HasPrefix(p, "/usr/local/cuda/") && !strings.HasPrefix(p, "/usr/local/cuda-") + } + + tt := []struct { + path string + expected bool + }{ + // Paths through symlink (should be avoided) + {"/usr/local/cuda/lib64/libcudart.so", true}, + {"/usr/local/cuda/include/cuda.h", true}, + {"/usr/local/cuda/bin/nvcc", true}, + + // Canonical paths (should be preferred) + {"/usr/local/cuda-12.9/lib64/libcudart.so", false}, + {"/usr/local/cuda-12.9/include/cuda.h", false}, + {"/usr/local/cuda-11.8/bin/nvcc", false}, + + // Other paths (not affected) + {"/usr/lib/libfoo.so", false}, + {"/opt/nvidia/cuda/lib/libbar.so", false}, + {"/home/user/cuda/file.txt", false}, + } + + for _, test := range tt { + result := isThroughSymlink(test.path) + if result != test.expected { + t.Errorf("isThroughSymlink(%q) = %v, want %v", test.path, result, test.expected) + } + } +} + +// TestPathPrioritySorting tests the sorting logic that determines which path to keep +func TestPathPrioritySorting(t *testing.T) { + // Simulate the sorting logic from deduplicateFileMap + isThroughSymlink := func(p string) bool { + return strings.HasPrefix(p, "/usr/local/cuda/") && !strings.HasPrefix(p, "/usr/local/cuda-") + } + + // fileMap simulates the p.fileMap with flags + fileMap := map[string]*report.ArtifactProps{ + "/usr/local/cuda/lib64/libcudart.so": nil, + 
"/usr/local/cuda-12.9/lib64/libcudart.so": {Flags: map[string]bool{"R": true}}, + } + + sortPaths := func(paths []string) { + sort.Slice(paths, func(i, j int) bool { + pi, pj := paths[i], paths[j] + + // First priority: Prefer paths that don't go through symlink + symI := isThroughSymlink(pi) + symJ := isThroughSymlink(pj) + if symI != symJ { + return !symI + } + + // Second priority: Check if either path has flags + propsI := fileMap[pi] + propsJ := fileMap[pj] + hasI := propsI != nil && len(propsI.Flags) > 0 + hasJ := propsJ != nil && len(propsJ.Flags) > 0 + + if hasI && !hasJ { + return true + } + if !hasI && hasJ { + return false + } + + // Third priority: Prefer longer paths + return len(pi) > len(pj) + }) + } + + // Test case 1: symlink path vs canonical path + paths1 := []string{ + "/usr/local/cuda/lib64/libcudart.so", + "/usr/local/cuda-12.9/lib64/libcudart.so", + } + sortPaths(paths1) + if paths1[0] != "/usr/local/cuda-12.9/lib64/libcudart.so" { + t.Errorf("expected canonical path first, got %v", paths1) + } + + // Test case 2: Two canonical paths - prefer one with flags + fileMap2 := map[string]*report.ArtifactProps{ + "/usr/local/cuda-12.9/lib64/libcudart.so.12": {Flags: map[string]bool{"R": true}}, + "/usr/local/cuda-12.9/lib64/libcudart.so": nil, + } + paths2 := []string{ + "/usr/local/cuda-12.9/lib64/libcudart.so", + "/usr/local/cuda-12.9/lib64/libcudart.so.12", + } + sort.Slice(paths2, func(i, j int) bool { + pi, pj := paths2[i], paths2[j] + propsI := fileMap2[pi] + propsJ := fileMap2[pj] + hasI := propsI != nil && len(propsI.Flags) > 0 + hasJ := propsJ != nil && len(propsJ.Flags) > 0 + if hasI && !hasJ { + return true + } + if !hasI && hasJ { + return false + } + return len(pi) > len(pj) + }) + if paths2[0] != "/usr/local/cuda-12.9/lib64/libcudart.so.12" { + t.Errorf("expected path with flags first, got %v", paths2) + } + + // Test case 3: Equal priority - prefer longer path + paths3 := []string{ + "/usr/lib/short.so", + "/usr/lib/subdir/longer.so", 
+ } + sort.Slice(paths3, func(i, j int) bool { + return len(paths3[i]) > len(paths3[j]) + }) + if paths3[0] != "/usr/lib/subdir/longer.so" { + t.Errorf("expected longer path first, got %v", paths3) + } +} + +// TestDeduplicateFileMapWithHardlinks tests deduplication with actual hardlinks +func TestDeduplicateFileMapWithHardlinks(t *testing.T) { + // Create temp directory + tmpDir, err := os.MkdirTemp("", "dedup_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create a file + originalPath := filepath.Join(tmpDir, "original.txt") + if err := os.WriteFile(originalPath, []byte("test content"), 0644); err != nil { + t.Fatalf("failed to create original file: %v", err) + } + + // Create a hardlink to the same file + hardlinkPath := filepath.Join(tmpDir, "hardlink.txt") + if err := os.Link(originalPath, hardlinkPath); err != nil { + t.Fatalf("failed to create hardlink: %v", err) + } + + // Verify both paths have the same inode + origInfo, err := os.Lstat(originalPath) + if err != nil { + t.Fatalf("failed to stat original: %v", err) + } + linkInfo, err := os.Lstat(hardlinkPath) + if err != nil { + t.Fatalf("failed to stat hardlink: %v", err) + } + + origStat, ok := origInfo.Sys().(*syscall.Stat_t) + if !ok { + t.Fatal("failed to get syscall.Stat_t for original") + } + linkStat, ok := linkInfo.Sys().(*syscall.Stat_t) + if !ok { + t.Fatal("failed to get syscall.Stat_t for hardlink") + } + + if origStat.Ino != linkStat.Ino { + t.Fatalf("expected same inode, got %d vs %d", origStat.Ino, linkStat.Ino) + } + + // Test that we can build an inode map + inodeMap := make(map[uint64][]string) + for _, fpath := range []string{originalPath, hardlinkPath} { + info, err := os.Lstat(fpath) + if err != nil { + t.Fatalf("failed to stat %s: %v", fpath, err) + } + if sys, ok := info.Sys().(*syscall.Stat_t); ok { + inodeMap[sys.Ino] = append(inodeMap[sys.Ino], fpath) + } + } + + // Should have one inode with two paths + if 
len(inodeMap) != 1 { + t.Errorf("expected 1 inode, got %d", len(inodeMap)) + } + + for inode, paths := range inodeMap { + if len(paths) != 2 { + t.Errorf("expected 2 paths for inode %d, got %d", inode, len(paths)) + } + } +} + +// TestDeduplicateFileMapSkipsSymlinks tests that symlinks themselves are not deduplicated +func TestDeduplicateFileMapSkipsSymlinks(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "dedup_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create a regular file + regularPath := filepath.Join(tmpDir, "regular.txt") + if err := os.WriteFile(regularPath, []byte("content"), 0644); err != nil { + t.Fatalf("failed to create file: %v", err) + } + + // Create a symlink to it + symlinkPath := filepath.Join(tmpDir, "symlink.txt") + if err := os.Symlink(regularPath, symlinkPath); err != nil { + t.Fatalf("failed to create symlink: %v", err) + } + + // Check that Lstat identifies symlink vs regular file + regInfo, _ := os.Lstat(regularPath) + symInfo, _ := os.Lstat(symlinkPath) + + if !regInfo.Mode().IsRegular() { + t.Error("regular file should be identified as regular") + } + + if symInfo.Mode().IsRegular() { + t.Error("symlink should NOT be identified as regular file") + } + + // The deduplication logic only processes regular files + // Symlinks have different inodes and are handled separately +} + diff --git a/pkg/app/sensor/monitor/composite.go b/pkg/app/sensor/monitor/composite.go index cc51aa3a0e..99accdf33b 100644 --- a/pkg/app/sensor/monitor/composite.go +++ b/pkg/app/sensor/monitor/composite.go @@ -240,6 +240,12 @@ func (m *monitor) Start() error { return err } + log.Warn("COMPOSITE MONITOR: Both FAN and PTRACE monitors started successfully") + log.Warnf("COMPOSITE MONITOR: RTASourcePT flag value: %v", m.cmd.RTASourcePT) + if !m.cmd.RTASourcePT { + log.Error("COMPOSITE MONITOR: RTASourcePT is FALSE - PTRACE SYSCALL TRACING IS DISABLED!") + } + m.startedAt = time.Now() return nil 
diff --git a/pkg/monitor/ptrace/ptrace.go b/pkg/monitor/ptrace/ptrace.go index 31bfbe1404..f0011f5b7e 100644 --- a/pkg/monitor/ptrace/ptrace.go +++ b/pkg/monitor/ptrace/ptrace.go @@ -65,12 +65,13 @@ func Run( } if runOpt.RTASourcePT { - logger.Debug("tracing target app") + logger.Warn("PTRACE IS ENABLED - RTASourcePT=true - Starting trace monitoring") app.Report.Enabled = true go app.process() go app.trace() } else { - logger.Debug("not tracing target app...") + logger.Error("PTRACE IS DISABLED - RTASourcePT=false - NOT tracing target app!") + logger.Error("This means syscall-level monitoring is OFF and files may be missed!") go func() { logger.Debug("not tracing target app - start app") if err := app.start(); err != nil { @@ -239,6 +240,7 @@ func newApp( func (app *App) trace() { logger := app.logger.WithField("op", "trace") logger.Debug("call") + logger.Warn("PTRACE MONITOR IS STARTING - trace() called") defer logger.Debug("exit") runtime.LockOSThread() @@ -271,6 +273,13 @@ func (app *App) processFileActivity(e *syscallEvent) { return } + // Debug: Log when we process check/open file syscalls + if p.SyscallType() == CheckFileType || p.SyscallType() == OpenFileType { + okStatus := p.OKReturnStatus(e.retVal) + logger.Tracef("PTRACE: syscall=%s type=%s path=%s retVal=%d okStatus=%v", + p.SyscallName(), p.SyscallType(), e.pathParam, int32(e.retVal), okStatus) + } + if (p.SyscallType() == CheckFileType || p.SyscallType() == OpenFileType) && p.OKReturnStatus(e.retVal) { @@ -460,6 +469,23 @@ drain: app.Report.SyscallNum = uint32(len(app.Report.SyscallStats)) app.Report.FSActivity = app.FileActivity() + // CRITICAL DEBUG: Log FSActivity results + logger.Warnf("PTRACE REPORT FINALIZED: Tracked %d files in FSActivity", len(app.Report.FSActivity)) + logger.Warnf("PTRACE REPORT: Total syscalls executed: %d", app.Report.SyscallCount) + if len(app.Report.FSActivity) == 0 { + logger.Error("WARNING: PTRACE FSActivity is EMPTY - No files were tracked!") + logger.Error("This 
suggests ptrace syscall interception may not be working properly") + } else { + // Sample some tracked files for verification + sampleCount := 0 + for fpath := range app.Report.FSActivity { + if sampleCount < 5 { + logger.Warnf("PTRACE tracked file sample: %s", fpath) + sampleCount++ + } + } + } + app.StateCh <- state app.ReportCh <- &app.Report } @@ -1213,7 +1239,12 @@ func (ref *checkFileSyscallProcessor) OKCall(cstate *syscallState) bool { } func (ref *checkFileSyscallProcessor) OKReturnStatus(retVal uint64) bool { - return retVal == 0 + // Accept successful stat calls (0) and also failed attempts that indicate + // the application was looking for the file. This is important for Python + // imports which check multiple locations before finding the right file. + // Track ENOENT (file not found) and ENOTDIR (not a directory) in addition to success. + intRetVal := getIntVal(retVal) + return intRetVal == 0 || intRetVal == -2 || intRetVal == -20 // 0=success, -2=ENOENT, -20=ENOTDIR } func (ref *checkFileSyscallProcessor) EventOnCall() bool { diff --git a/pkg/monitor/ptrace/ptrace_test.go b/pkg/monitor/ptrace/ptrace_test.go new file mode 100644 index 0000000000..ec124ef45e --- /dev/null +++ b/pkg/monitor/ptrace/ptrace_test.go @@ -0,0 +1,118 @@ +package ptrace + +import ( + "testing" +) + +func TestGetIntVal(t *testing.T) { + tt := []struct { + input uint64 + expected int + }{ + {input: 0, expected: 0}, + {input: 0xFFFFFFFE, expected: -2}, // ENOENT + {input: 0xFFFFFFEC, expected: -20}, // ENOTDIR + {input: 1, expected: 1}, + {input: 0xFFFFFFFF, expected: -1}, // Generic error + } + + for _, test := range tt { + result := getIntVal(test.input) + if result != test.expected { + t.Errorf("getIntVal(0x%x) = %d, want %d", test.input, result, test.expected) + } + } +} + +func TestCheckFileSyscallProcessorOKReturnStatus(t *testing.T) { + processor := &checkFileSyscallProcessor{ + syscallProcessorCore: &syscallProcessorCore{}, + } + + tt := []struct { + retVal uint64 + 
expected bool + desc string + }{ + { + retVal: 0, + expected: true, + desc: "success (0)", + }, + { + retVal: 0xFFFFFFFE, // -2 as uint64 + expected: true, + desc: "ENOENT (-2) - file not found, should be tracked", + }, + { + retVal: 0xFFFFFFEC, // -20 as uint64 + expected: true, + desc: "ENOTDIR (-20) - not a directory, should be tracked", + }, + { + retVal: 0xFFFFFFFF, // -1 as uint64 + expected: false, + desc: "EPERM (-1) - should not be tracked", + }, + { + retVal: 0xFFFFFFFD, // -3 as uint64 + expected: false, + desc: "ESRCH (-3) - should not be tracked", + }, + { + retVal: 0xFFFFFFED, // -19 as uint64 + expected: false, + desc: "ENODEV (-19) - should not be tracked", + }, + { + retVal: 1, + expected: false, + desc: "positive return value - should not be tracked", + }, + } + + for _, test := range tt { + result := processor.OKReturnStatus(test.retVal) + if result != test.expected { + t.Errorf("OKReturnStatus(0x%x) [%s] = %v, want %v", + test.retVal, test.desc, result, test.expected) + } + } +} + +func TestCheckFileSyscallProcessorFailedReturnStatus(t *testing.T) { + processor := &checkFileSyscallProcessor{ + syscallProcessorCore: &syscallProcessorCore{}, + } + + tt := []struct { + retVal uint64 + expected bool + desc string + }{ + { + retVal: 0, + expected: false, + desc: "success (0) - not failed", + }, + { + retVal: 0xFFFFFFFE, // -2 (ENOENT) + expected: true, + desc: "ENOENT (-2) - failed", + }, + { + retVal: 0xFFFFFFFF, // -1 (EPERM) + expected: true, + desc: "EPERM (-1) - failed", + }, + } + + for _, test := range tt { + result := processor.FailedReturnStatus(test.retVal) + if result != test.expected { + t.Errorf("FailedReturnStatus(0x%x) [%s] = %v, want %v", + test.retVal, test.desc, result, test.expected) + } + } +} + diff --git a/pkg/util/fsutil/fsutil.go b/pkg/util/fsutil/fsutil.go index ed03e8c3f7..fcdbfb32af 100644 --- a/pkg/util/fsutil/fsutil.go +++ b/pkg/util/fsutil/fsutil.go @@ -518,6 +518,7 @@ func cloneDirPath(src, dst string) { func 
CopyRegularFile(clone bool, src, dst string, makeDir bool) error { log.Debugf("CopyRegularFile(%v,%v,%v,%v)", clone, src, dst, makeDir) //'clone' should be true only for the dst files that need to clone the dir properties from src + s, err := os.Open(src) if err != nil { return err @@ -574,20 +575,30 @@ func CopyRegularFile(clone bool, src, dst string, makeDir bool) error { return err } - if srcFileInfo.Size() > 0 { - written, err := io.Copy(d, s) - if err != nil { - d.Close() - return err - } + written, err := io.Copy(d, s) + if err != nil { + d.Close() + return err + } - if written != srcFileInfo.Size() { - log.Debugf("CopyRegularFile(%v,%v,%v) - copy data mismatch - %v/%v", - src, dst, makeDir, written, srcFileInfo.Size()) - d.Close() - return fmt.Errorf("%s -> %s: partial copy - %d/%d", - src, dst, written, srcFileInfo.Size()) - } + // Log if we copied a file that appeared 0-byte but had content + if srcFileInfo.Size() == 0 && written > 0 { + log.Debugf("CopyRegularFile(%v,%v) - file appeared as 0-byte but copied %d bytes", src, dst, written) + } + + // Log error for 0-byte copy of non-empty file (likely overlay FS issue) + if srcFileInfo.Size() > 0 && written == 0 { + log.Errorf("CopyRegularFile(%v,%v) - expected %d bytes but copied 0 (possible overlay FS issue)", + src, dst, srcFileInfo.Size()) + } + + // Verify the copy if we expected content + if srcFileInfo.Size() > 0 && written != srcFileInfo.Size() { + log.Debugf("CopyRegularFile(%v,%v,%v) - copy data mismatch - %v/%v", + src, dst, makeDir, written, srcFileInfo.Size()) + d.Close() + return fmt.Errorf("%s -> %s: partial copy - %d/%d", + src, dst, written, srcFileInfo.Size()) } //Need to close dst file before chmod works the right way diff --git a/pkg/util/fsutil/fsutil_test.go b/pkg/util/fsutil/fsutil_test.go new file mode 100644 index 0000000000..b8691959f5 --- /dev/null +++ b/pkg/util/fsutil/fsutil_test.go @@ -0,0 +1,216 @@ +package fsutil + +import ( + "os" + "path/filepath" + "testing" +) + +func 
TestCopyRegularFile(t *testing.T) { + tt := []struct { + name string + content string + clone bool + makeDir bool + expectError bool + }{ + { + name: "copy file with content", + content: "hello world", + clone: false, + makeDir: true, + expectError: false, + }, + { + name: "copy empty file", + content: "", + clone: false, + makeDir: true, + expectError: false, + }, + { + name: "copy without makeDir to existing dir", + content: "test content", + clone: false, + makeDir: false, + expectError: false, + }, + } + + for _, test := range tt { + // Create temp directory for test + tmpDir, err := os.MkdirTemp("", "fsutil_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create source file + srcPath := filepath.Join(tmpDir, "src", "testfile.txt") + if err := os.MkdirAll(filepath.Dir(srcPath), 0755); err != nil { + t.Fatalf("failed to create src dir: %v", err) + } + if err := os.WriteFile(srcPath, []byte(test.content), 0644); err != nil { + t.Fatalf("failed to create src file: %v", err) + } + + // Set up destination path + var dstPath string + if test.makeDir { + dstPath = filepath.Join(tmpDir, "dst", "testfile.txt") + } else { + // Create dst dir first for non-makeDir tests + dstDir := filepath.Join(tmpDir, "dst") + if err := os.MkdirAll(dstDir, 0755); err != nil { + t.Fatalf("failed to create dst dir: %v", err) + } + dstPath = filepath.Join(dstDir, "testfile.txt") + } + + // Copy file + err = CopyRegularFile(test.clone, srcPath, dstPath, test.makeDir) + + if test.expectError { + if err == nil { + t.Errorf("test %q: expected error but got none", test.name) + } + continue + } + + if err != nil { + t.Errorf("test %q: unexpected error: %v", test.name, err) + continue + } + + // Verify destination file exists and has correct content + dstContent, err := os.ReadFile(dstPath) + if err != nil { + t.Errorf("test %q: failed to read dst file: %v", test.name, err) + continue + } + + if string(dstContent) != test.content { 
+ t.Errorf("test %q: content mismatch, got %q, want %q", test.name, string(dstContent), test.content) + } + } +} + +func TestCopyRegularFilePreservesSize(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "fsutil_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create source file with known content + content := "This is test content with known size" + srcPath := filepath.Join(tmpDir, "src.txt") + if err := os.WriteFile(srcPath, []byte(content), 0644); err != nil { + t.Fatalf("failed to create src file: %v", err) + } + + srcInfo, err := os.Stat(srcPath) + if err != nil { + t.Fatalf("failed to stat src file: %v", err) + } + + // Copy file + dstPath := filepath.Join(tmpDir, "dst.txt") + if err := CopyRegularFile(false, srcPath, dstPath, true); err != nil { + t.Fatalf("CopyRegularFile failed: %v", err) + } + + // Verify sizes match + dstInfo, err := os.Stat(dstPath) + if err != nil { + t.Fatalf("failed to stat dst file: %v", err) + } + + if dstInfo.Size() != srcInfo.Size() { + t.Errorf("size mismatch: src=%d, dst=%d", srcInfo.Size(), dstInfo.Size()) + } +} + +func TestCopyRegularFileMissingSource(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "fsutil_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + srcPath := filepath.Join(tmpDir, "nonexistent.txt") + dstPath := filepath.Join(tmpDir, "dst.txt") + + err = CopyRegularFile(false, srcPath, dstPath, true) + if err == nil { + t.Error("expected error for missing source file, got none") + } +} + +func TestCopyFile(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "fsutil_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create source file + content := "test content for CopyFile" + srcPath := filepath.Join(tmpDir, "src.txt") + if err := os.WriteFile(srcPath, []byte(content), 0644); err != nil { + t.Fatalf("failed to create src file: %v", 
err) + } + + // Copy file + dstPath := filepath.Join(tmpDir, "dst.txt") + if err := CopyFile(false, srcPath, dstPath, true); err != nil { + t.Fatalf("CopyFile failed: %v", err) + } + + // Verify content + dstContent, err := os.ReadFile(dstPath) + if err != nil { + t.Fatalf("failed to read dst file: %v", err) + } + + if string(dstContent) != content { + t.Errorf("content mismatch: got %q, want %q", string(dstContent), content) + } +} + +func TestCopyFileWithSymlink(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "fsutil_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create source file + content := "symlink test content" + srcPath := filepath.Join(tmpDir, "src.txt") + if err := os.WriteFile(srcPath, []byte(content), 0644); err != nil { + t.Fatalf("failed to create src file: %v", err) + } + + // Create symlink to source + symlinkPath := filepath.Join(tmpDir, "src_link.txt") + if err := os.Symlink(srcPath, symlinkPath); err != nil { + t.Fatalf("failed to create symlink: %v", err) + } + + // Copy via symlink - CopyFile should handle this + dstPath := filepath.Join(tmpDir, "dst.txt") + if err := CopyFile(false, symlinkPath, dstPath, true); err != nil { + t.Fatalf("CopyFile via symlink failed: %v", err) + } + + // Verify content was copied (not the symlink itself) + dstContent, err := os.ReadFile(dstPath) + if err != nil { + t.Fatalf("failed to read dst file: %v", err) + } + + if string(dstContent) != content { + t.Errorf("content mismatch: got %q, want %q", string(dstContent), content) + } +}